From 50aa9be4f10d8a1562fef223efcf9fee6a02d256 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 15 Apr 2020 09:14:24 +0800 Subject: [PATCH] Add joins support --- callbacks/query.go | 73 ++++++++++++++++++++++++++++++++++++++++++++-- chainable_api.go | 10 ++++++- clause/joins.go | 51 +++++++++++++++++--------------- statement.go | 10 +++++++ 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 00820bfd..ae22f4d0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "github.com/jinzhu/gorm" @@ -9,8 +10,76 @@ import ( func Query(db *gorm.DB) { if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + clauseSelect := clause.Select{} + + if len(db.Statement.Selects) > 0 { + for _, name := range db.Statement.Selects { + if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: f.DBName, + }) + } + } + } + + if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + + if len(db.Statement.Selects) == 0 { + for _, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: dbName, + }) + } + } + + for name, conds := range db.Statement.Joins { + if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: relation.FieldSchema.Table, + Name: s, + }) + } + + var exprs []clause.Expression + for _, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), + }) + } else { + if ref.PrimaryValue == "" { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), + }) + } else { + exprs = append(exprs, clause.Expr{ + SQL: fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), + Vars: []interface{}{ref.PrimaryValue}, + }) + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.Expr{SQL: name, Vars: conds}, + }) + } + } + + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/chainable_api.go b/chainable_api.go index 7a6e8b7c..6b91c9ad 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -134,6 +134,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + if tx.Statement.Joins == nil { + tx.Statement.Joins = map[string][]interface{}{} + } + tx.Statement.Joins[query] = args return } @@ -211,8 +215,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { // Preload preload associations with given conditions // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args return } diff --git a/clause/joins.go b/clause/joins.go index a78bde39..8d9055cd 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -11,32 +11,37 @@ const ( // Join join clause for from type Join struct { - Type JoinType - Table Table - ON Where - Using []string + Type JoinType + Table Table + ON Where + Using []string + Expression Expression } func (join Join) Build(builder Builder) { - if join.Type != "" { - builder.WriteString(string(join.Type)) - builder.WriteByte(' ') - } - - builder.WriteString("JOIN ") - builder.WriteQuoted(join.Table) - - if len(join.ON.Exprs) > 0 { - builder.WriteString(" ON ") - join.ON.Build(builder) - } else if len(join.Using) > 0 { - builder.WriteString(" USING (") - for idx, c := range join.Using { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(c) + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') } - builder.WriteByte(')') } } diff --git a/statement.go b/statement.go index e45bd8bb..3f2ceca3 100644 --- a/statement.go +++ b/statement.go @@ -24,6 +24,8 @@ type Statement struct { Clauses map[string]clause.Clause Selects []string // selected columns Omits []string // omit columns + Joins map[string][]interface{} + Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -265,6 +267,14 @@ func (stmt *Statement) reinit() { delete(stmt.Clauses, k) } + for k := range stmt.Joins { + delete(stmt.Joins, k) + } + + for k := range stmt.Preloads { + delete(stmt.Preloads, k) + } + stmt.Settings.Range(func(k, _ interface{}) bool { stmt.Settings.Delete(k) return true