added arbitrary join conditions as a tag option for many2many relationships

This commit is contained in:
Matt Schiros 2019-01-17 11:57:22 -08:00
parent 9f1a7f5351
commit 6ecce561e6
3 changed files with 14 additions and 2 deletions

View File

@ -18,7 +18,7 @@ func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
//we are only preloading relations, dont touch base model
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
return

View File

@ -42,6 +42,7 @@ type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
ArbitraryJoinConditions []string `sql:"-"`
}
// SourceForeignKeys return source foreign keys
@ -66,7 +67,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
AssociationDBName: dbName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
s.Destination.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.AssociationForeignFieldNames {
@ -75,6 +76,8 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
AssociationDBName: dbName,
})
}
s.ArbitraryJoinConditions = relationship.ArbitraryJoinConditions
}
// Table return join table's table name
@ -174,6 +177,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
joinConditions = append(joinConditions, s.ArbitraryJoinConditions...)
var foreignDBNames []string
var foreignFieldNames []string

View File

@ -133,6 +133,7 @@ type Relationship struct {
AssociationForeignFieldNames []string
AssociationForeignDBNames []string
JoinTableHandler JoinTableHandlerInterface
ArbitraryJoinConditions []string
}
func getForeignField(column string, fields []*StructField) *StructField {
@ -339,6 +340,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
// Check for arbitrary join conditions supplied in the tag
if arbitraryConditions,_ := field.TagSettingsGet("ARBITRARY_JOIN_CONDITIONS"); arbitraryConditions != "" {
relationship.ArbitraryJoinConditions = strings.Split(arbitraryConditions, ",")
}
joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler