Refactor JoinTableHandler
This commit is contained in:
		
							parent
							
								
									c13e2f18f8
								
							
						
					
					
						commit
						6ba0c1661f
					
				
							
								
								
									
										103
									
								
								join_table.go
									
									
									
									
									
								
							
							
						
						
									
										103
									
								
								join_table.go
									
									
									
									
									
								
							@ -2,6 +2,7 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -13,70 +14,114 @@ type JoinTableHandlerInterface interface {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type JoinTableSource struct {
 | 
					type JoinTableSource struct {
 | 
				
			||||||
	ForeignKey       string
 | 
						ModelType   reflect.Type
 | 
				
			||||||
	ForeignKeyPrefix string
 | 
						ForeignKeys []struct {
 | 
				
			||||||
	ModelStruct
 | 
							DBName            string
 | 
				
			||||||
 | 
							AssociationDBName string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type JoinTableHandler struct {
 | 
					type JoinTableHandler struct {
 | 
				
			||||||
	TableName string
 | 
						TableName   string          `sql:"-"`
 | 
				
			||||||
	Source1   JoinTableSource
 | 
						Source      JoinTableSource `sql:"-"`
 | 
				
			||||||
	Source2   JoinTableSource
 | 
						Destination JoinTableSource `sql:"-"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (jt JoinTableHandler) Table(*DB) string {
 | 
					func (s JoinTableHandler) Table(*DB) string {
 | 
				
			||||||
	return jt.TableName
 | 
						return s.TableName
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (jt JoinTableHandler) GetValueMap(db *DB, sources ...interface{}) map[string]interface{} {
 | 
					func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
 | 
				
			||||||
	values := map[string]interface{}{}
 | 
						values := map[string]interface{}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, source := range sources {
 | 
						for _, source := range sources {
 | 
				
			||||||
		scope := db.NewScope(source)
 | 
							scope := db.NewScope(source)
 | 
				
			||||||
		for _, primaryField := range scope.GetModelStruct().PrimaryFields {
 | 
							modelType := scope.GetModelStruct().ModelType
 | 
				
			||||||
			if field, ok := scope.Fields()[primaryField.DBName]; ok {
 | 
					
 | 
				
			||||||
				values[primaryField.DBName] = field.Field.Interface()
 | 
							if s.Source.ModelType == modelType {
 | 
				
			||||||
 | 
								for _, foreignKey := range s.Source.ForeignKeys {
 | 
				
			||||||
 | 
									values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if s.Destination.ModelType == modelType {
 | 
				
			||||||
 | 
								for _, foreignKey := range s.Destination.ForeignKeys {
 | 
				
			||||||
 | 
									values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return values
 | 
						return values
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (jt JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
 | 
					func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
 | 
				
			||||||
	scope := db.NewScope("")
 | 
						scope := db.NewScope("")
 | 
				
			||||||
	valueMap := jt.GetValueMap(db, source1, source2)
 | 
						searchMap := s.GetSearchMap(db, source1, source2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var setColumns, setBinVars, queryConditions []string
 | 
						var assignColumns, binVars, conditions []string
 | 
				
			||||||
	var values []interface{}
 | 
						var values []interface{}
 | 
				
			||||||
	for key, value := range valueMap {
 | 
						for key, value := range searchMap {
 | 
				
			||||||
		setColumns = append(setColumns, key)
 | 
							assignColumns = append(assignColumns, key)
 | 
				
			||||||
		setBinVars = append(setBinVars, `?`)
 | 
							binVars = append(binVars, `?`)
 | 
				
			||||||
		queryConditions = append(queryConditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 | 
							conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
 | 
				
			||||||
		values = append(values, value)
 | 
							values = append(values, value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, value := range valueMap {
 | 
						for _, value := range searchMap {
 | 
				
			||||||
		values = append(values, value)
 | 
							values = append(values, value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	quotedTable := jt.Table(db)
 | 
						quotedTable := s.Table(db)
 | 
				
			||||||
	sql := fmt.Sprintf(
 | 
						sql := fmt.Sprintf(
 | 
				
			||||||
		"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
 | 
							"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
 | 
				
			||||||
		quotedTable,
 | 
							quotedTable,
 | 
				
			||||||
		strings.Join(setColumns, ","),
 | 
							strings.Join(assignColumns, ","),
 | 
				
			||||||
		strings.Join(setBinVars, ","),
 | 
							strings.Join(binVars, ","),
 | 
				
			||||||
		scope.Dialect().SelectFromDummyTable(),
 | 
							scope.Dialect().SelectFromDummyTable(),
 | 
				
			||||||
		quotedTable,
 | 
							quotedTable,
 | 
				
			||||||
		strings.Join(queryConditions, " AND "),
 | 
							strings.Join(conditions, " AND "),
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return db.Exec(sql, values...).Error
 | 
						return db.Exec(sql, values...).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (jt JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
 | 
					func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
 | 
				
			||||||
	// return db.Table(jt.Table(db)).Delete("").Error
 | 
						var conditions []string
 | 
				
			||||||
	return nil
 | 
						var values []interface{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for key, value := range s.GetSearchMap(db, sources...) {
 | 
				
			||||||
 | 
							conditions = append(conditions, fmt.Sprintf("%v = ?", key))
 | 
				
			||||||
 | 
							values = append(values, value)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (jt JoinTableHandler) JoinWith(db *DB, sources interface{}) *DB {
 | 
					func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
 | 
				
			||||||
	return db
 | 
						quotedTable := s.Table(db)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scope := db.NewScope(source)
 | 
				
			||||||
 | 
						modelType := scope.GetModelStruct().ModelType
 | 
				
			||||||
 | 
						var joinConditions []string
 | 
				
			||||||
 | 
						var queryConditions []string
 | 
				
			||||||
 | 
						var values []interface{}
 | 
				
			||||||
 | 
						if s.Source.ModelType == modelType {
 | 
				
			||||||
 | 
							for _, foreignKey := range s.Destination.ForeignKeys {
 | 
				
			||||||
 | 
								joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, foreignKey := range s.Source.ForeignKeys {
 | 
				
			||||||
 | 
								queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
 | 
				
			||||||
 | 
								values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if s.Destination.ModelType == modelType {
 | 
				
			||||||
 | 
							for _, foreignKey := range s.Source.ForeignKeys {
 | 
				
			||||||
 | 
								joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), scope.QuotedTableName(), scope.Quote(foreignKey.AssociationDBName)))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, foreignKey := range s.Destination.ForeignKeys {
 | 
				
			||||||
 | 
								queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
 | 
				
			||||||
 | 
								values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", strings.Join(joinConditions, " AND "))).
 | 
				
			||||||
 | 
							Where(strings.Join(queryConditions, " AND "), values...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -15,16 +15,13 @@ type Person struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PersonAddress struct {
 | 
					type PersonAddress struct {
 | 
				
			||||||
 | 
						gorm.JoinTableHandler
 | 
				
			||||||
	PersonID  int
 | 
						PersonID  int
 | 
				
			||||||
	AddressID int
 | 
						AddressID int
 | 
				
			||||||
	DeletedAt time.Time
 | 
						DeletedAt time.Time
 | 
				
			||||||
	CreatedAt time.Time
 | 
						CreatedAt time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
 | 
					 | 
				
			||||||
	return relationship.JoinTable
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
 | 
					func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
 | 
				
			||||||
	return db.Where(map[string]interface{}{
 | 
						return db.Where(map[string]interface{}{
 | 
				
			||||||
		relationship.ForeignDBName:            foreignValue,
 | 
							relationship.ForeignDBName:            foreignValue,
 | 
				
			||||||
@ -41,14 +38,14 @@ func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
 | 
					func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
 | 
				
			||||||
	table := pa.Table(db, relationship)
 | 
						table := pa.Table(db)
 | 
				
			||||||
	return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
 | 
						return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestJoinTable(t *testing.T) {
 | 
					func TestJoinTable(t *testing.T) {
 | 
				
			||||||
	DB.Exec("drop table person_addresses;")
 | 
						DB.Exec("drop table person_addresses;")
 | 
				
			||||||
	DB.AutoMigrate(&Person{})
 | 
						DB.AutoMigrate(&Person{})
 | 
				
			||||||
	DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
 | 
						// DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	address1 := &Address{Address1: "address 1"}
 | 
						address1 := &Address{Address1: "address 1"}
 | 
				
			||||||
	address2 := &Address{Address1: "address 2"}
 | 
						address2 := &Address{Address1: "address 2"}
 | 
				
			||||||
 | 
				
			|||||||
@ -437,7 +437,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) createJoinTable(field *StructField) {
 | 
					func (scope *Scope) createJoinTable(field *StructField) {
 | 
				
			||||||
	if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
 | 
						if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
 | 
				
			||||||
		joinTable := relationship.JoinTableHandler.Table(scope.db)
 | 
							joinTableHandler := relationship.JoinTableHandler
 | 
				
			||||||
 | 
							joinTable := joinTableHandler.Table(scope.db)
 | 
				
			||||||
		if !scope.Dialect().HasTable(scope, joinTable) {
 | 
							if !scope.Dialect().HasTable(scope, joinTable) {
 | 
				
			||||||
			primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
 | 
								primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
 | 
				
			||||||
			scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
 | 
								scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
 | 
				
			||||||
@ -447,7 +448,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
 | 
				
			|||||||
					scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
 | 
										scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
 | 
				
			||||||
			).Error)
 | 
								).Error)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		scope.NewDB().Table(joinTable).AutoMigrate()
 | 
							scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user