Add joins support
This commit is contained in:
		
							parent
							
								
									a992c1ea38
								
							
						
					
					
						commit
						50aa9be4f1
					
				| @ -1,6 +1,7 @@ | |||||||
| package callbacks | package callbacks | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| @ -9,8 +10,76 @@ import ( | |||||||
| 
 | 
 | ||||||
| func Query(db *gorm.DB) { | func Query(db *gorm.DB) { | ||||||
| 	if db.Statement.SQL.String() == "" { | 	if db.Statement.SQL.String() == "" { | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.Select{}) | 		clauseSelect := clause.Select{} | ||||||
| 		db.Statement.AddClauseIfNotExists(clause.From{}) | 
 | ||||||
|  | 		if len(db.Statement.Selects) > 0 { | ||||||
|  | 			for _, name := range db.Statement.Selects { | ||||||
|  | 				if f := db.Statement.Schema.LookUpField(name); f != nil { | ||||||
|  | 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
|  | 						Name: f.DBName, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(db.Statement.Joins) != 0 { | ||||||
|  | 			joins := []clause.Join{} | ||||||
|  | 
 | ||||||
|  | 			if len(db.Statement.Selects) == 0 { | ||||||
|  | 				for _, dbName := range db.Statement.Schema.DBNames { | ||||||
|  | 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
|  | 						Name: dbName, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for name, conds := range db.Statement.Joins { | ||||||
|  | 				if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { | ||||||
|  | 					for _, s := range relation.FieldSchema.DBNames { | ||||||
|  | 						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
|  | 							Table: relation.FieldSchema.Table, | ||||||
|  | 							Name:  s, | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					var exprs []clause.Expression | ||||||
|  | 					for _, ref := range relation.References { | ||||||
|  | 						if ref.OwnPrimaryKey { | ||||||
|  | 							exprs = append(exprs, clause.Expr{ | ||||||
|  | 								SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), | ||||||
|  | 							}) | ||||||
|  | 						} else { | ||||||
|  | 							if ref.PrimaryValue == "" { | ||||||
|  | 								exprs = append(exprs, clause.Expr{ | ||||||
|  | 									SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), | ||||||
|  | 								}) | ||||||
|  | 							} else { | ||||||
|  | 								exprs = append(exprs, clause.Expr{ | ||||||
|  | 									SQL:  fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), | ||||||
|  | 									Vars: []interface{}{ref.PrimaryValue}, | ||||||
|  | 								}) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					joins = append(joins, clause.Join{ | ||||||
|  | 						Type:  clause.LeftJoin, | ||||||
|  | 						Table: clause.Table{Name: relation.FieldSchema.Table}, | ||||||
|  | 						ON:    clause.Where{Exprs: exprs}, | ||||||
|  | 					}) | ||||||
|  | 				} else { | ||||||
|  | 					joins = append(joins, clause.Join{ | ||||||
|  | 						Expression: clause.Expr{SQL: name, Vars: conds}, | ||||||
|  | 					}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			db.Statement.AddClause(clause.From{Joins: joins}) | ||||||
|  | 		} else { | ||||||
|  | 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		db.Statement.AddClauseIfNotExists(clauseSelect) | ||||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -134,6 +134,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { | |||||||
| //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | ||||||
| func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	if tx.Statement.Joins == nil { | ||||||
|  | 		tx.Statement.Joins = map[string][]interface{}{} | ||||||
|  | 	} | ||||||
|  | 	tx.Statement.Joins[query] = args | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -211,8 +215,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB { | |||||||
| 
 | 
 | ||||||
| // Preload preload associations with given conditions
 | // Preload preload associations with given conditions
 | ||||||
| //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | ||||||
| func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { | func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	if tx.Statement.Preloads == nil { | ||||||
|  | 		tx.Statement.Preloads = map[string][]interface{}{} | ||||||
|  | 	} | ||||||
|  | 	tx.Statement.Preloads[query] = args | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -11,32 +11,37 @@ const ( | |||||||
| 
 | 
 | ||||||
| // Join join clause for from
 | // Join join clause for from
 | ||||||
| type Join struct { | type Join struct { | ||||||
| 	Type  JoinType | 	Type       JoinType | ||||||
| 	Table Table | 	Table      Table | ||||||
| 	ON    Where | 	ON         Where | ||||||
| 	Using []string | 	Using      []string | ||||||
|  | 	Expression Expression | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (join Join) Build(builder Builder) { | func (join Join) Build(builder Builder) { | ||||||
| 	if join.Type != "" { | 	if join.Expression != nil { | ||||||
| 		builder.WriteString(string(join.Type)) | 		join.Expression.Build(builder) | ||||||
| 		builder.WriteByte(' ') | 	} else { | ||||||
| 	} | 		if join.Type != "" { | ||||||
| 
 | 			builder.WriteString(string(join.Type)) | ||||||
| 	builder.WriteString("JOIN ") | 			builder.WriteByte(' ') | ||||||
| 	builder.WriteQuoted(join.Table) | 		} | ||||||
| 
 | 
 | ||||||
| 	if len(join.ON.Exprs) > 0 { | 		builder.WriteString("JOIN ") | ||||||
| 		builder.WriteString(" ON ") | 		builder.WriteQuoted(join.Table) | ||||||
| 		join.ON.Build(builder) | 
 | ||||||
| 	} else if len(join.Using) > 0 { | 		if len(join.ON.Exprs) > 0 { | ||||||
| 		builder.WriteString(" USING (") | 			builder.WriteString(" ON ") | ||||||
| 		for idx, c := range join.Using { | 			join.ON.Build(builder) | ||||||
| 			if idx > 0 { | 		} else if len(join.Using) > 0 { | ||||||
| 				builder.WriteByte(',') | 			builder.WriteString(" USING (") | ||||||
| 			} | 			for idx, c := range join.Using { | ||||||
| 			builder.WriteQuoted(c) | 				if idx > 0 { | ||||||
|  | 					builder.WriteByte(',') | ||||||
|  | 				} | ||||||
|  | 				builder.WriteQuoted(c) | ||||||
|  | 			} | ||||||
|  | 			builder.WriteByte(')') | ||||||
| 		} | 		} | ||||||
| 		builder.WriteByte(')') |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								statement.go
									
									
									
									
									
								
							| @ -24,6 +24,8 @@ type Statement struct { | |||||||
| 	Clauses              map[string]clause.Clause | 	Clauses              map[string]clause.Clause | ||||||
| 	Selects              []string // selected columns
 | 	Selects              []string // selected columns
 | ||||||
| 	Omits                []string // omit columns
 | 	Omits                []string // omit columns
 | ||||||
|  | 	Joins                map[string][]interface{} | ||||||
|  | 	Preloads             map[string][]interface{} | ||||||
| 	Settings             sync.Map | 	Settings             sync.Map | ||||||
| 	ConnPool             ConnPool | 	ConnPool             ConnPool | ||||||
| 	Schema               *schema.Schema | 	Schema               *schema.Schema | ||||||
| @ -265,6 +267,14 @@ func (stmt *Statement) reinit() { | |||||||
| 		delete(stmt.Clauses, k) | 		delete(stmt.Clauses, k) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	for k := range stmt.Joins { | ||||||
|  | 		delete(stmt.Joins, k) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for k := range stmt.Preloads { | ||||||
|  | 		delete(stmt.Preloads, k) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	stmt.Settings.Range(func(k, _ interface{}) bool { | 	stmt.Settings.Range(func(k, _ interface{}) bool { | ||||||
| 		stmt.Settings.Delete(k) | 		stmt.Settings.Delete(k) | ||||||
| 		return true | 		return true | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu