Merge 68fccb87b0b527c32e3697dd3469f7a38ddee323 into 725aa5b5ff4c0687b06d9a01096b8e4cf96b6c9e
This commit is contained in:
		
						commit
						7c6506113a
					
				| @ -380,14 +380,68 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) executeScopes() (tx *DB) { | func (db *DB) executeScopes() (tx *DB) { | ||||||
|  | 	if len(db.Statement.scopes) == 0 { | ||||||
|  | 		return db | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	scopes := db.Statement.scopes | 	scopes := db.Statement.scopes | ||||||
| 	db.Statement.scopes = nil | 	db.Statement.scopes = nil | ||||||
|  | 	originClause := db.Statement.Clauses | ||||||
|  | 
 | ||||||
|  | 	// use clean db in scope
 | ||||||
|  | 	cleanDB := db.Session(&Session{}) | ||||||
|  | 	cleanDB.Statement.Clauses = map[string]clause.Clause{} | ||||||
|  | 
 | ||||||
|  | 	txs := make([]*DB, 0, len(scopes)) | ||||||
| 	for _, scope := range scopes { | 	for _, scope := range scopes { | ||||||
| 		db = scope(db) | 		txs = append(txs, scope(cleanDB)) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	db.Statement.Clauses = originClause | ||||||
|  | 	db.mergeClauses(txs) | ||||||
| 	return db | 	return db | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (db *DB) mergeClauses(txs []*DB) { | ||||||
|  | 	if len(txs) == 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tx := range txs { | ||||||
|  | 		stmt := tx.Statement | ||||||
|  | 		if stmt != nil { | ||||||
|  | 			stmtClause := stmt.Clauses | ||||||
|  | 			// merge clauses
 | ||||||
|  | 			if cs, ok := stmtClause["WHERE"]; ok { | ||||||
|  | 				if where, ok := cs.Expression.(clause.Where); ok { | ||||||
|  | 					db.Statement.AddClause(where) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// cover other expr
 | ||||||
|  | 			if stmt.TableExpr != nil { | ||||||
|  | 				db.Statement.TableExpr = stmt.TableExpr | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if stmt.Table != "" { | ||||||
|  | 				db.Statement.Table = stmt.Table | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if stmt.Model != nil { | ||||||
|  | 				db.Statement.Model = stmt.Model | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if stmt.Selects != nil { | ||||||
|  | 				db.Statement.Selects = stmt.Selects | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if stmt.Omits != nil { | ||||||
|  | 				db.Statement.Omits = stmt.Omits | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Preload preload associations with given conditions
 | // Preload preload associations with given conditions
 | ||||||
| //
 | //
 | ||||||
| //	// get all users, and preload all non-cancelled orders
 | //	// get all users, and preload all non-cancelled orders
 | ||||||
|  | |||||||
| @ -335,7 +335,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 		case clause.Expression: | 		case clause.Expression: | ||||||
| 			conds = append(conds, v) | 			conds = append(conds, v) | ||||||
| 		case *DB: | 		case *DB: | ||||||
| 			v.executeScopes() | 			v = v.executeScopes() | ||||||
| 
 | 
 | ||||||
| 			if cs, ok := v.Statement.Clauses["WHERE"]; ok { | 			if cs, ok := v.Statement.Clauses["WHERE"]; ok { | ||||||
| 				if where, ok := cs.Expression.(clause.Where); ok { | 				if where, ok := cs.Expression.(clause.Where); ok { | ||||||
| @ -346,6 +346,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
|  | 
 | ||||||
| 					conds = append(conds, clause.And(where.Exprs...)) | 					conds = append(conds, clause.And(where.Exprs...)) | ||||||
| 				} else if cs.Expression != nil { | 				} else if cs.Expression != nil { | ||||||
| 					conds = append(conds, cs.Expression) | 					conds = append(conds, cs.Expression) | ||||||
|  | |||||||
| @ -90,7 +90,31 @@ func TestComplexScopes(t *testing.T) { | |||||||
| 				).Find(&Language{}) | 				).Find(&Language{}) | ||||||
| 			}, | 			}, | ||||||
| 			expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, | 			expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, | ||||||
| 		}, { | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "group_cond", | ||||||
|  | 			queryFn: func(tx *gorm.DB) *gorm.DB { | ||||||
|  | 				return tx.Scopes( | ||||||
|  | 					func(d *gorm.DB) *gorm.DB { return d.Table("languages1") }, | ||||||
|  | 					func(d *gorm.DB) *gorm.DB { return d.Table("languages2") }, | ||||||
|  | 					func(d *gorm.DB) *gorm.DB { | ||||||
|  | 						return d.Where( | ||||||
|  | 							d.Where("a = 1").Or("b = 2"), | ||||||
|  | 						) | ||||||
|  | 					}, | ||||||
|  | 					func(d *gorm.DB) *gorm.DB { | ||||||
|  | 						return d.Select("f1, f2") | ||||||
|  | 					}, | ||||||
|  | 					func(d *gorm.DB) *gorm.DB { | ||||||
|  | 						return d.Where( | ||||||
|  | 							d.Where("c = 3"), | ||||||
|  | 						) | ||||||
|  | 					}, | ||||||
|  | 				).Find(&Language{}) | ||||||
|  | 			}, | ||||||
|  | 			expected: `SELECT f1, f2 FROM "languages2" WHERE (a = 1 OR b = 2) AND c = 3`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
| 			name: "depth_1_pre_cond", | 			name: "depth_1_pre_cond", | ||||||
| 			queryFn: func(tx *gorm.DB) *gorm.DB { | 			queryFn: func(tx *gorm.DB) *gorm.DB { | ||||||
| 				return tx.Where("z = 0").Scopes( | 				return tx.Where("z = 0").Scopes( | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Cr.
						Cr.