Codebase list golang-github-jinzhu-gorm / fresh-snapshots/main join_table_handler.go
fresh-snapshots/main

Tree @fresh-snapshots/main (Download .tar.gz)

join_table_handler.go @fresh-snapshots/main · raw · history · blame

package gorm

import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
)

// JoinTableHandlerInterface describes how a many2many relationship is
// persisted in, and queried through, its join table.
type JoinTableHandlerInterface interface {
	// Setup initializes the handler for relationship, using tableName as the
	// join table and source/destination as the two related model types.
	Setup(relationship *Relationship, tableName string, source, destination reflect.Type)
	// Table returns the name of the join table.
	Table(db *DB) string
	// Add records a relationship between source and destination in the join table.
	Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
	// Delete removes join table rows belonging to the given sources.
	Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
	// JoinWith returns a query joined through the join table and scoped to source.
	JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
	// SourceForeignKeys returns the join table's foreign keys for the source model.
	SourceForeignKeys() []JoinTableForeignKey
	// DestinationForeignKeys returns the join table's foreign keys for the destination model.
	DestinationForeignKeys() []JoinTableForeignKey
}

// JoinTableForeignKey pairs one join table column with the field of the
// related model that it references.
type JoinTableForeignKey struct {
	// DBName is the column name inside the join table.
	DBName string

	// AssociationDBName names the referenced field/column on the related model.
	AssociationDBName string
}

// JoinTableSource describes one side of a many2many relationship: the model
// type on that side and the join table foreign keys that refer to it.
type JoinTableSource struct {
	// ModelType is the reflect.Type of the model on this side.
	ModelType reflect.Type

	// ForeignKeys lists the join table columns that reference ModelType.
	ForeignKeys []JoinTableForeignKey
}

// JoinTableHandler is the default join table handler; its pointer type
// provides every method of JoinTableHandlerInterface.
type JoinTableHandler struct {
	// TableName is the raw join table name; Table passes it through
	// DefaultTableNameHandler.
	TableName string `sql:"-"`

	// Source describes the source model side of the relationship.
	Source JoinTableSource `sql:"-"`

	// Destination describes the destination model side of the relationship.
	Destination JoinTableSource `sql:"-"`
}

// SourceForeignKeys returns the join table columns that reference the
// source model, as configured by Setup.
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
	keys := s.Source.ForeignKeys
	return keys
}

// DestinationForeignKeys returns the join table columns that reference the
// destination model, as configured by Setup.
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
	dst := &s.Destination
	return dst.ForeignKeys
}

// Setup initializes the handler for a many2many relationship.
//
// It stores tableName and builds both foreign key lists by position:
// source keys pair relationship.ForeignDBNames (join table columns) with
// relationship.ForeignFieldNames, and destination keys pair
// relationship.AssociationForeignDBNames with
// relationship.AssociationForeignFieldNames. Each DB-name slice is indexed by
// the field-name position, so it must be at least as long as its
// field-name counterpart. Both lists are non-nil even when empty.
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
	// zip combines field names with their join table columns, one key per field.
	zip := func(fieldNames, columnNames []string) []JoinTableForeignKey {
		keys := []JoinTableForeignKey{}
		for i := range fieldNames {
			keys = append(keys, JoinTableForeignKey{
				DBName:            columnNames[i],
				AssociationDBName: fieldNames[i],
			})
		}
		return keys
	}

	s.TableName = tableName
	s.Source = JoinTableSource{
		ModelType:   source,
		ForeignKeys: zip(relationship.ForeignFieldNames, relationship.ForeignDBNames),
	}
	s.Destination = JoinTableSource{
		ModelType:   destination,
		ForeignKeys: zip(relationship.AssociationForeignFieldNames, relationship.AssociationForeignDBNames),
	}
}

// Table returns the join table name, resolved by DefaultTableNameHandler
// from db and the configured TableName.
func (s JoinTableHandler) Table(db *DB) string {
	name := s.TableName
	return DefaultTableNameHandler(db, name)
}

// updateConditionMap fills conditionMap with join-table-column -> value
// entries taken from sources.
//
// For each source, only the first entry of joinTableSources whose ModelType
// equals the source's model type is used. Each of its foreign keys whose
// AssociationDBName resolves to a field on the source contributes that
// field's current value under the key's DBName. Later writes to the same
// column overwrite earlier ones.
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
	for _, src := range sources {
		sc := db.NewScope(src)
		srcType := sc.GetModelStruct().ModelType

		// Locate the first join table side describing this model type.
		matched := -1
		for i := range joinTableSources {
			if joinTableSources[i].ModelType == srcType {
				matched = i
				break
			}
		}
		if matched < 0 {
			continue
		}

		for _, fk := range joinTableSources[matched].ForeignKeys {
			field, found := sc.FieldByName(fk.AssociationDBName)
			if !found {
				continue
			}
			conditionMap[fk.DBName] = field.Field.Interface()
		}
	}
}

// Add creates the relationship between source and destination in the join
// table.
//
// It issues a single "INSERT INTO ... SELECT ... WHERE NOT EXISTS (...)"
// statement, so a row is inserted only if an identical row is not already
// present. handler supplies the join table name through its Table method.
// Columns are emitted in sorted order so the generated SQL text is stable
// across calls. It returns an error if neither source nor destination yields
// any join table column (the statement would otherwise be invalid SQL), or the
// error from executing the statement.
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
	var (
		scope        = db.NewScope("")
		conditionMap = map[string]interface{}{}
	)

	// Collect join table column -> value pairs for each side. A side only
	// contributes when the value's model type matches that side.
	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)

	if len(conditionMap) == 0 {
		// An empty column list and empty WHERE clause can never be valid SQL.
		return errors.New("no join table columns found for source and destination")
	}

	// Map iteration order is random; sort the columns so identical inputs
	// always produce identical SQL text (helps statement caches and logs).
	keys := make([]string, 0, len(conditionMap))
	for key := range conditionMap {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var (
		assignColumns = make([]string, 0, len(keys))
		binVars       = make([]string, 0, len(keys))
		conditions    = make([]string, 0, len(keys))
		values        = make([]interface{}, 0, 2*len(keys))
	)
	for _, key := range keys {
		quotedKey := scope.Quote(key)
		assignColumns = append(assignColumns, quotedKey)
		binVars = append(binVars, `?`)
		conditions = append(conditions, fmt.Sprintf("%v = ?", quotedKey))
		values = append(values, conditionMap[key])
	}

	// Placeholders appear twice, first in the SELECT list and then in the
	// NOT EXISTS sub-query, in the same column order, so repeat the values.
	values = append(values, values...)

	quotedTable := scope.Quote(handler.Table(db))
	sql := fmt.Sprintf(
		"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
		quotedTable,
		strings.Join(assignColumns, ","),
		strings.Join(binVars, ","),
		scope.Dialect().SelectFromDummyTable(),
		quotedTable,
		strings.Join(conditions, " AND "),
	)

	return db.Exec(sql, values...).Error
}

// Delete removes join table rows for the given sources.
//
// Each source whose model type matches the join table's source or
// destination model adds "column = value" equality conditions for its
// foreign key columns. All conditions are ANDed, in sorted column order so
// the generated SQL is deterministic. handler supplies the join table name
// through its Table method.
//
// NOTE(review): when no source matches either side, no conditions are added
// here and the delete is scoped only by whatever conditions the caller
// already placed on db. Callers may depend on this, so do not turn the
// empty case into an error without checking them.
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
	var (
		scope        = db.NewScope(nil)
		conditions   []string
		values       []interface{}
		conditionMap = map[string]interface{}{}
	)

	s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)

	// Ranging over the map directly would give a random column order; sort it
	// so the WHERE clause text is stable.
	keys := make([]string, 0, len(conditionMap))
	for key := range conditionMap {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	for _, key := range keys {
		conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
		values = append(values, conditionMap[key])
	}

	return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}

// JoinWith returns a query over the destination model, joined through the
// join table and restricted to rows related to source.
//
// source's model type must equal the handler's source model type; otherwise
// db.Error is set to "wrong source type for join table handler" and db is
// returned. The query INNER JOINs the join table to the destination table on
// the destination foreign keys. It then filters the join table's source key
// columns with an IN over the key values collected from source. If no key
// values are found, it uses the always-false condition "1 <> 1", so the query
// matches nothing.
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
	var (
		scope           = db.NewScope(source)
		tableName       = handler.Table(db)
		quotedTableName = scope.Quote(tableName)
	)

	if s.Source.ModelType != scope.GetModelStruct().ModelType {
		db.Error = errors.New("wrong source type for join table handler")
		return db
	}

	// Join conditions: join table column = destination table column.
	destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
	joinConditions := make([]string, 0, len(s.Destination.ForeignKeys))
	for _, foreignKey := range s.Destination.ForeignKeys {
		joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
	}

	// Join table source columns, plus the source model fields that supply
	// their values.
	var foreignDBNames, foreignFieldNames []string
	for _, foreignKey := range s.Source.ForeignKeys {
		foreignDBNames = append(foreignDBNames, foreignKey.DBName)
		if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
			foreignFieldNames = append(foreignFieldNames, field.Name)
		}
	}

	foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)

	condString := "1 <> 1"
	if len(foreignFieldValues) > 0 {
		// Table-qualified but still unquoted column names.
		// NOTE(review): this relies on toQueryCondition quoting each entry
		// (its body is not visible here) — confirm.
		qualifiedForeignDBNames := make([]string, 0, len(foreignDBNames))
		for _, dbName := range foreignDBNames {
			qualifiedForeignDBNames = append(qualifiedForeignDBNames, tableName+"."+dbName)
		}

		condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, qualifiedForeignDBNames), toQueryMarks(foreignFieldValues))
	}

	return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
		Where(condString, toQueryValues(foreignFieldValues)...)
}