Add joins support
This commit is contained in:
		
							parent
							
								
									a992c1ea38
								
							
						
					
					
						commit
						50aa9be4f1
					
				| @ -1,6 +1,7 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| @ -9,8 +10,76 @@ import ( | ||||
| 
 | ||||
| func Query(db *gorm.DB) { | ||||
| 	if db.Statement.SQL.String() == "" { | ||||
| 		db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 		clauseSelect := clause.Select{} | ||||
| 
 | ||||
| 		if len(db.Statement.Selects) > 0 { | ||||
| 			for _, name := range db.Statement.Selects { | ||||
| 				if f := db.Statement.Schema.LookUpField(name); f != nil { | ||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 						Name: f.DBName, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(db.Statement.Joins) != 0 { | ||||
| 			joins := []clause.Join{} | ||||
| 
 | ||||
| 			if len(db.Statement.Selects) == 0 { | ||||
| 				for _, dbName := range db.Statement.Schema.DBNames { | ||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 						Name: dbName, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for name, conds := range db.Statement.Joins { | ||||
| 				if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { | ||||
| 					for _, s := range relation.FieldSchema.DBNames { | ||||
| 						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 							Table: relation.FieldSchema.Table, | ||||
| 							Name:  s, | ||||
| 						}) | ||||
| 					} | ||||
| 
 | ||||
| 					var exprs []clause.Expression | ||||
| 					for _, ref := range relation.References { | ||||
| 						if ref.OwnPrimaryKey { | ||||
| 							exprs = append(exprs, clause.Expr{ | ||||
| 								SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), | ||||
| 							}) | ||||
| 						} else { | ||||
| 							if ref.PrimaryValue == "" { | ||||
| 								exprs = append(exprs, clause.Expr{ | ||||
| 									SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), | ||||
| 								}) | ||||
| 							} else { | ||||
| 								exprs = append(exprs, clause.Expr{ | ||||
| 									SQL:  fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), | ||||
| 									Vars: []interface{}{ref.PrimaryValue}, | ||||
| 								}) | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					joins = append(joins, clause.Join{ | ||||
| 						Type:  clause.LeftJoin, | ||||
| 						Table: clause.Table{Name: relation.FieldSchema.Table}, | ||||
| 						ON:    clause.Where{Exprs: exprs}, | ||||
| 					}) | ||||
| 				} else { | ||||
| 					joins = append(joins, clause.Join{ | ||||
| 						Expression: clause.Expr{SQL: name, Vars: conds}, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			db.Statement.AddClause(clause.From{Joins: joins}) | ||||
| 		} else { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.AddClauseIfNotExists(clauseSelect) | ||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -134,6 +134,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { | ||||
| //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | ||||
| func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if tx.Statement.Joins == nil { | ||||
| 		tx.Statement.Joins = map[string][]interface{}{} | ||||
| 	} | ||||
| 	tx.Statement.Joins[query] = args | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -211,8 +215,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { | ||||
| 
 | ||||
| // Preload preload associations with given conditions
 | ||||
| //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | ||||
| func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { | ||||
| func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if tx.Statement.Preloads == nil { | ||||
| 		tx.Statement.Preloads = map[string][]interface{}{} | ||||
| 	} | ||||
| 	tx.Statement.Preloads[query] = args | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -15,9 +15,13 @@ type Join struct { | ||||
| 	Table      Table | ||||
| 	ON         Where | ||||
| 	Using      []string | ||||
| 	Expression Expression | ||||
| } | ||||
| 
 | ||||
| func (join Join) Build(builder Builder) { | ||||
| 	if join.Expression != nil { | ||||
| 		join.Expression.Build(builder) | ||||
| 	} else { | ||||
| 		if join.Type != "" { | ||||
| 			builder.WriteString(string(join.Type)) | ||||
| 			builder.WriteByte(' ') | ||||
| @ -39,4 +43,5 @@ func (join Join) Build(builder Builder) { | ||||
| 			} | ||||
| 			builder.WriteByte(')') | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										10
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								statement.go
									
									
									
									
									
								
							| @ -24,6 +24,8 @@ type Statement struct { | ||||
| 	Clauses              map[string]clause.Clause | ||||
| 	Selects              []string // selected columns
 | ||||
| 	Omits                []string // omit columns
 | ||||
| 	Joins                map[string][]interface{} | ||||
| 	Preloads             map[string][]interface{} | ||||
| 	Settings             sync.Map | ||||
| 	ConnPool             ConnPool | ||||
| 	Schema               *schema.Schema | ||||
| @ -265,6 +267,14 @@ func (stmt *Statement) reinit() { | ||||
| 		delete(stmt.Clauses, k) | ||||
| 	} | ||||
| 
 | ||||
| 	for k := range stmt.Joins { | ||||
| 		delete(stmt.Joins, k) | ||||
| 	} | ||||
| 
 | ||||
| 	for k := range stmt.Preloads { | ||||
| 		delete(stmt.Preloads, k) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt.Settings.Range(func(k, _ interface{}) bool { | ||||
| 		stmt.Settings.Delete(k) | ||||
| 		return true | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu