From 6ecce561e6c8baf41e1fcea6d55238f3f4a46f4c Mon Sep 17 00:00:00 2001 From: Matt Schiros Date: Thu, 17 Jan 2019 11:57:22 -0800 Subject: [PATCH] added arbitrary join conditions as a tag option for many2many relationships --- callback_query.go | 2 +- join_table_handler.go | 6 +++++- model_struct.go | 8 ++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/callback_query.go b/callback_query.go index 593e5d30..7facc42b 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,7 +18,7 @@ func queryCallback(scope *Scope) { if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { return } - + //we are only preloading relations, dont touch base model if _, skip := scope.InstanceGet("gorm:only_preload"); skip { return diff --git a/join_table_handler.go b/join_table_handler.go index a036d46d..d97f6d52 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -42,6 +42,7 @@ type JoinTableHandler struct { TableName string `sql:"-"` Source JoinTableSource `sql:"-"` Destination JoinTableSource `sql:"-"` + ArbitraryJoinConditions []string `sql:"-"` } // SourceForeignKeys return source foreign keys @@ -66,7 +67,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s AssociationDBName: dbName, }) } - + s.Destination = JoinTableSource{ModelType: destination} s.Destination.ForeignKeys = []JoinTableForeignKey{} for idx, dbName := range relationship.AssociationForeignFieldNames { @@ -75,6 +76,8 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s AssociationDBName: dbName, }) } + + s.ArbitraryJoinConditions = relationship.ArbitraryJoinConditions } // Table return join table's table name @@ -174,6 +177,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so for _, foreignKey := range s.Destination.ForeignKeys { joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } + joinConditions = append(joinConditions, s.ArbitraryJoinConditions...) var foreignDBNames []string var foreignFieldNames []string diff --git a/model_struct.go b/model_struct.go index 9e93db63..9d49b299 100644 --- a/model_struct.go +++ b/model_struct.go @@ -133,6 +133,7 @@ type Relationship struct { AssociationForeignFieldNames []string AssociationForeignDBNames []string JoinTableHandler JoinTableHandlerInterface + ArbitraryJoinConditions []string } func getForeignField(column string, fields []*StructField) *StructField { @@ -339,6 +340,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + // Check for arbitrary join conditions supplied in the tag + + if arbitraryConditions,_ := field.TagSettingsGet("ARBITRARY_JOIN_CONDITIONS"); arbitraryConditions != "" { + relationship.ArbitraryJoinConditions = strings.Split(arbitraryConditions, ",") + } + + joinTableHandler := JoinTableHandler{} joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) relationship.JoinTableHandler = &joinTableHandler