From 895c1178a0d1d837cd986c45eac62f6b10a6add4 Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Thu, 8 Jul 2021 10:04:40 +0200 Subject: [PATCH] Proposal, Add Specific on for Joins queries --- callbacks/query.go | 47 ++++++++++++++++++++++++++-------------------- chainable_api.go | 6 ++++++ statement.go | 1 + 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 3299d015..e5f1250c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,33 +125,40 @@ func BuildQuerySQL(db *gorm.DB) { }) } - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { + if join.On != nil { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } } } + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: []clause.Expression{join.On}}, + }) } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) } else { joins = append(joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/chainable_api.go b/chainable_api.go index 88279044..32943a83 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,6 +177,12 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return } +func (db *DB) JoinsOn(query string, on clause.Expression, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: on}) + return +} + // Group specify the group method on the find func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() diff --git a/statement.go b/statement.go index 93b78c12..89824bc1 100644 --- a/statement.go +++ b/statement.go @@ -50,6 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} + On clause.Expression } // StatementModifier statement modifier interface