improve generics version Joins support
This commit is contained in:
		
							parent
							
								
									05925b2fc0
								
							
						
					
					
						commit
						d073805c86
					
				| @ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 
 | ||||
| 					if isRelations { | ||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							tableAliasName := relation.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 							tableAliasName := join.Alias | ||||
| 
 | ||||
| 							if tableAliasName == "" { | ||||
| 								tableAliasName = relation.Name | ||||
| 								if parentTableName != clause.CurrentTable { | ||||
| 									tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							columnStmt := gorm.Statement{ | ||||
| @ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if join.Expression != nil { | ||||
| 								return clause.Join{ | ||||
| 									Type:       join.JoinType, | ||||
| 									Expression: join.Expression, | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							exprs := make([]clause.Expression, len(relation.References)) | ||||
| 							for idx, ref := range relation.References { | ||||
| 								if ref.OwnPrimaryKey { | ||||
|  | ||||
| @ -19,15 +19,15 @@ type JoinTarget struct { | ||||
| } | ||||
| 
 | ||||
| func Has(name string) JoinTarget { | ||||
| 	return JoinTarget{Type: LeftJoin, Association: name} | ||||
| 	return JoinTarget{Type: InnerJoin, Association: name} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinType) Association(name string) JoinTarget { | ||||
| 	return JoinTarget{Type: jt, Association: name} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinType) Subquery(subquery Expression) JoinTarget { | ||||
| 	return JoinTarget{Type: jt, Subquery: subquery} | ||||
| func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { | ||||
| 	return JoinTarget{Type: jt, Association: name, Subquery: subquery} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinTarget) As(name string) JoinTarget { | ||||
|  | ||||
							
								
								
									
										126
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										126
									
								
								generics.go
									
									
									
									
									
								
							| @ -3,8 +3,11 @@ package gorm | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| ) | ||||
| 
 | ||||
| type Interface[T any] interface { | ||||
| @ -28,7 +31,7 @@ type ChainInterface[T any] interface { | ||||
| 	Or(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Limit(offset int) ChainInterface[T] | ||||
| 	Offset(offset int) ChainInterface[T] | ||||
| 	Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] | ||||
| 	Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] | ||||
| 	Select(query string, args ...interface{}) ChainInterface[T] | ||||
| 	Omit(columns ...string) ChainInterface[T] | ||||
| 	MapColumns(m map[string]string) ChainInterface[T] | ||||
| @ -38,6 +41,8 @@ type ChainInterface[T any] interface { | ||||
| 	Order(value interface{}) ChainInterface[T] | ||||
| 	Preload(query string, args ...interface{}) ChainInterface[T] | ||||
| 
 | ||||
| 	Build(builder clause.Builder) | ||||
| 
 | ||||
| 	Delete(ctx context.Context) (rowsAffected int, err error) | ||||
| 	Update(ctx context.Context, name string, value any) (rowsAffected int, err error) | ||||
| 	Updates(ctx context.Context, t T) (rowsAffected int, err error) | ||||
| @ -55,6 +60,14 @@ type ExecInterface[T any] interface { | ||||
| 	Rows(ctx context.Context) (*sql.Rows, error) | ||||
| } | ||||
| 
 | ||||
| type QueryInterface interface { | ||||
| 	Select(...string) QueryInterface | ||||
| 	Omit(...string) QueryInterface | ||||
| 	Where(query interface{}, args ...interface{}) QueryInterface | ||||
| 	Not(query interface{}, args ...interface{}) QueryInterface | ||||
| 	Or(query interface{}, args ...interface{}) QueryInterface | ||||
| } | ||||
| 
 | ||||
| type op func(*DB) *DB | ||||
| 
 | ||||
| func G[T any](db *DB, opts ...clause.Expression) Interface[T] { | ||||
| @ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { | ||||
| 	// TODO
 | ||||
| 	return nil | ||||
| type query struct { | ||||
| 	db *DB | ||||
| } | ||||
| 
 | ||||
| func (q query) Where(query interface{}, args ...interface{}) QueryInterface { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Or(query interface{}, args ...interface{}) QueryInterface { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Not(query interface{}, args ...interface{}) QueryInterface { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Select(columns ...string) QueryInterface { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Omit(columns ...string) QueryInterface { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		if jt.Table == "" { | ||||
| 			jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name | ||||
| 		} | ||||
| 
 | ||||
| 		q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} | ||||
| 		if args != nil { | ||||
| 			args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) | ||||
| 		} | ||||
| 
 | ||||
| 		j := join{ | ||||
| 			Name:     jt.Association, | ||||
| 			Alias:    jt.Table, | ||||
| 			Selects:  q.db.Statement.Selects, | ||||
| 			Omits:    q.db.Statement.Omits, | ||||
| 			JoinType: jt.Type, | ||||
| 		} | ||||
| 
 | ||||
| 		if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { | ||||
| 			j.On = &where | ||||
| 		} | ||||
| 
 | ||||
| 		if jt.Subquery != nil { | ||||
| 			joinType := j.JoinType | ||||
| 			if joinType == "" { | ||||
| 				joinType = clause.LeftJoin | ||||
| 			} | ||||
| 
 | ||||
| 			expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} | ||||
| 
 | ||||
| 			if j.On != nil { | ||||
| 				expr.SQL += " ON ?" | ||||
| 				expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) | ||||
| 			} | ||||
| 
 | ||||
| 			j.Expression = expr | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.Joins = append(db.Statement.Joins, j) | ||||
| 		return db | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { | ||||
| @ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Build(builder clause.Builder) { | ||||
| 	subdb := c.getInstance() | ||||
| 	subdb.Logger = logger.Discard | ||||
| 	subdb.DryRun = true | ||||
| 
 | ||||
| 	if stmt, ok := builder.(*Statement); ok { | ||||
| 		if subdb.Statement.SQL.Len() > 0 { | ||||
| 			var ( | ||||
| 				vars = subdb.Statement.Vars | ||||
| 				sql  = subdb.Statement.SQL.String() | ||||
| 			) | ||||
| 
 | ||||
| 			subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | ||||
| 			for _, vv := range vars { | ||||
| 				subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | ||||
| 				bindvar := strings.Builder{} | ||||
| 				subdb.BindVarTo(&bindvar, subdb.Statement, vv) | ||||
| 				sql = strings.Replace(sql, bindvar.String(), "?", 1) | ||||
| 			} | ||||
| 
 | ||||
| 			subdb.Statement.SQL.Reset() | ||||
| 			subdb.Statement.Vars = stmt.Vars | ||||
| 			if strings.Contains(sql, "@") { | ||||
| 				clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||
| 			} else { | ||||
| 				clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||
| 			} | ||||
| 		} else { | ||||
| 			subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) | ||||
| 			subdb.callbacks.Query().Execute(subdb) | ||||
| 		} | ||||
| 
 | ||||
| 		builder.WriteString(subdb.Statement.SQL.String()) | ||||
| 		stmt.Vars = subdb.Statement.Vars | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type execG[T any] struct { | ||||
| 	g *g[T] | ||||
| } | ||||
|  | ||||
							
								
								
									
										7
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								scan.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/schema" | ||||
| @ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 							matchedFieldCount[column] = 1 | ||||
| 						} | ||||
| 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | ||||
| 						for _, join := range db.Statement.Joins { | ||||
| 							if join.Alias == names[0] { | ||||
| 								names = append(strings.Split(join.Name, "."), names[len(names)-1]) | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 							subNameCount := len(names) | ||||
| 							// nested relation fields
 | ||||
|  | ||||
							
								
								
									
										43
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								statement.go
									
									
									
									
									
								
							| @ -50,12 +50,14 @@ type Statement struct { | ||||
| } | ||||
| 
 | ||||
| type join struct { | ||||
| 	Name     string | ||||
| 	Conds    []interface{} | ||||
| 	On       *clause.Where | ||||
| 	Selects  []string | ||||
| 	Omits    []string | ||||
| 	JoinType clause.JoinType | ||||
| 	Name       string | ||||
| 	Alias      string | ||||
| 	Conds      []interface{} | ||||
| 	On         *clause.Where | ||||
| 	Selects    []string | ||||
| 	Omits      []string | ||||
| 	Expression clause.Expression | ||||
| 	JoinType   clause.JoinType | ||||
| } | ||||
| 
 | ||||
| // StatementModifier statement modifier interface
 | ||||
| @ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 			arg, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		curTable := stmt.Table | ||||
| 		if curTable == "" { | ||||
| 			curTable = clause.CurrentTable | ||||
| 		} | ||||
| 
 | ||||
| 		switch v := arg.(type) { | ||||
| 		case clause.Expression: | ||||
| 			conds = append(conds, v) | ||||
| @ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 			sort.Strings(keys) | ||||
| 
 | ||||
| 			for _, key := range keys { | ||||
| 				conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 				column := clause.Column{Name: key, Table: curTable} | ||||
| 				conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 			} | ||||
| 		case map[string]interface{}: | ||||
| 			keys := make([]string, 0, len(v)) | ||||
| @ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 
 | ||||
| 			for _, key := range keys { | ||||
| 				reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) | ||||
| 				column := clause.Column{Name: key, Table: curTable} | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					if _, ok := v[key].(driver.Valuer); ok { | ||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 					} else if _, ok := v[key].(Valuer); ok { | ||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 					} else { | ||||
| 						// optimize reflect value length
 | ||||
| 						valueLen := reflectValue.Len() | ||||
| @ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 							values[i] = reflectValue.Index(i).Interface() | ||||
| 						} | ||||
| 
 | ||||
| 						conds = append(conds, clause.IN{Column: key, Values: values}) | ||||
| 						conds = append(conds, clause.IN{Column: column, Values: values}) | ||||
| 					} | ||||
| 				default: | ||||
| 					conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 					conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 				} | ||||
| 			} | ||||
| 		default: | ||||
| @ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 						if selected || (!restricted && field.Readable) { | ||||
| 							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { | ||||
| 								if field.DBName != "" { | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||
| 								} else if field.DataType != "" { | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| @ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 							if selected || (!restricted && field.Readable) { | ||||
| 								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { | ||||
| 									if field.DBName != "" { | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||
| 									} else if field.DataType != "" { | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| @ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 						} | ||||
| 
 | ||||
| 						if len(values) > 0 { | ||||
| 							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) | ||||
| 							conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) | ||||
| 							return []clause.Expression{clause.And(conds...)} | ||||
| 						} | ||||
| 						return nil | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) | ||||
| 				conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) { | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != "" || u.Age != u.Age { | ||||
| 	} else if u.Name != "" || u.Age != user.Age { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| @ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) { | ||||
| 
 | ||||
| 	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) | ||||
| 	if len(found) != len(batch) { | ||||
| 		fmt.Println(found) | ||||
| 		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) | ||||
| 	} | ||||
| 
 | ||||
| @ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} | ||||
| 	db.Create(ctx, &u) | ||||
| 	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}} | ||||
| 	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}} | ||||
| 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||
| 
 | ||||
| 	// LEFT JOIN + WHERE
 | ||||
| 	result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { | ||||
| 	// Inner JOIN + WHERE
 | ||||
| 	result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| @ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// JOIN
 | ||||
| 	result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { | ||||
| 		return nil | ||||
| 	// Inner JOIN + WHERE with map
 | ||||
| 	result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		return db.Where(map[string]any{"name": u.Company.Name}) | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| @ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Left JOIN
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { | ||||
| 		return nil | ||||
| 	}).First(ctx) | ||||
| 	// Left JOIN w/o WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| @ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Left JOIN + Alias WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		if joinTable.Name != "t" { | ||||
| 			t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 		} | ||||
| 		return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 	}).Where(map[string]any{"name": u.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Raw Subquery JOIN + WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), | ||||
| 		func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 			if joinTable.Name != "t" { | ||||
| 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 			} | ||||
| 			return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		}, | ||||
| 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Raw subquery join failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u2.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Preload
 | ||||
| 	result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) | ||||
| 	if err != nil { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu