Fix unordered joins, close #3267
This commit is contained in:
		
							parent
							
								
									2b510d6423
								
							
						
					
					
						commit
						3a97639880
					
				| @ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		joins := []clause.Join{} | 		joins := []clause.Join{} | ||||||
| 		for name, conds := range db.Statement.Joins { | 		for _, join := range db.Statement.Joins { | ||||||
| 			if db.Statement.Schema == nil { | 			if db.Statement.Schema == nil { | ||||||
| 				joins = append(joins, clause.Join{ | 				joins = append(joins, clause.Join{ | ||||||
| 					Expression: clause.Expr{SQL: name, Vars: conds}, | 					Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, | ||||||
| 				}) | 				}) | ||||||
| 			} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { | 			} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { | ||||||
| 				tableAliasName := relation.Name | 				tableAliasName := relation.Name | ||||||
| 
 | 
 | ||||||
| 				for _, s := range relation.FieldSchema.DBNames { | 				for _, s := range relation.FieldSchema.DBNames { | ||||||
| @ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 				}) | 				}) | ||||||
| 			} else { | 			} else { | ||||||
| 				joins = append(joins, clause.Join{ | 				joins = append(joins, clause.Join{ | ||||||
| 					Expression: clause.Expr{SQL: name, Vars: conds}, | 					Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, | ||||||
| 				}) | 				}) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { | |||||||
| //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | ||||||
| func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	if tx.Statement.Joins == nil { | 	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) | ||||||
| 		tx.Statement.Joins = map[string][]interface{}{} |  | ||||||
| 	} |  | ||||||
| 	tx.Statement.Joins[query] = args |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										13
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								statement.go
									
									
									
									
									
								
							| @ -29,7 +29,7 @@ type Statement struct { | |||||||
| 	Distinct             bool | 	Distinct             bool | ||||||
| 	Selects              []string // selected columns
 | 	Selects              []string // selected columns
 | ||||||
| 	Omits                []string // omit columns
 | 	Omits                []string // omit columns
 | ||||||
| 	Joins                map[string][]interface{} | 	Joins                []join | ||||||
| 	Preloads             map[string][]interface{} | 	Preloads             map[string][]interface{} | ||||||
| 	Settings             sync.Map | 	Settings             sync.Map | ||||||
| 	ConnPool             ConnPool | 	ConnPool             ConnPool | ||||||
| @ -44,6 +44,11 @@ type Statement struct { | |||||||
| 	assigns              []interface{} | 	assigns              []interface{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type join struct { | ||||||
|  | 	Name  string | ||||||
|  | 	Conds []interface{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // StatementModifier statement modifier interface
 | // StatementModifier statement modifier interface
 | ||||||
| type StatementModifier interface { | type StatementModifier interface { | ||||||
| 	ModifyStatement(*Statement) | 	ModifyStatement(*Statement) | ||||||
| @ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement { | |||||||
| 		Distinct:             stmt.Distinct, | 		Distinct:             stmt.Distinct, | ||||||
| 		Selects:              stmt.Selects, | 		Selects:              stmt.Selects, | ||||||
| 		Omits:                stmt.Omits, | 		Omits:                stmt.Omits, | ||||||
| 		Joins:                map[string][]interface{}{}, |  | ||||||
| 		Preloads:             map[string][]interface{}{}, | 		Preloads:             map[string][]interface{}{}, | ||||||
| 		ConnPool:             stmt.ConnPool, | 		ConnPool:             stmt.ConnPool, | ||||||
| 		Schema:               stmt.Schema, | 		Schema:               stmt.Schema, | ||||||
| @ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement { | |||||||
| 		newStmt.Preloads[k] = p | 		newStmt.Preloads[k] = p | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for k, j := range stmt.Joins { | 	if len(stmt.Joins) > 0 { | ||||||
| 		newStmt.Joins[k] = j | 		newStmt.Joins = make([]join, len(stmt.Joins)) | ||||||
|  | 		copy(newStmt.Joins, stmt.Joins) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	stmt.Settings.Range(func(k, v interface{}) bool { | 	stmt.Settings.Range(func(k, v interface{}) bool { | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"regexp" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| @ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { | |||||||
| 	if db5.Error != nil { | 	if db5.Error != nil { | ||||||
| 		t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) | 		t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	dryDB := DB.Session(&gorm.Session{DryRun: true}) | ||||||
|  | 	stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement | ||||||
|  | 
 | ||||||
|  | 	if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { | ||||||
|  | 		t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestJoinsWithSelect(t *testing.T) { | func TestJoinsWithSelect(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu