finish generic version Preload
This commit is contained in:
		
							parent
							
								
									ba94e4eb2f
								
							
						
					
					
						commit
						4694673526
					
				| @ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | ||||
| 					return gorm.ErrInvalidData | ||||
| 				} | ||||
| 			} else { | ||||
| 				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) | ||||
| 				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true}) | ||||
| 				tx.Statement.ReflectValue = db.Statement.ReflectValue | ||||
| 				tx.Statement.Unscoped = db.Statement.Unscoped | ||||
| 				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { | ||||
|  | ||||
							
								
								
									
										102
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										102
									
								
								generics.go
									
									
									
									
									
								
							| @ -31,7 +31,8 @@ type ChainInterface[T any] interface { | ||||
| 	Or(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Limit(offset int) ChainInterface[T] | ||||
| 	Offset(offset int) ChainInterface[T] | ||||
| 	Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] | ||||
| 	Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] | ||||
| 	Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] | ||||
| 	Select(query string, args ...interface{}) ChainInterface[T] | ||||
| 	Omit(columns ...string) ChainInterface[T] | ||||
| 	MapColumns(m map[string]string) ChainInterface[T] | ||||
| @ -39,7 +40,6 @@ type ChainInterface[T any] interface { | ||||
| 	Group(name string) ChainInterface[T] | ||||
| 	Having(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Order(value interface{}) ChainInterface[T] | ||||
| 	Preload(query string, args ...interface{}) ChainInterface[T] | ||||
| 
 | ||||
| 	Build(builder clause.Builder) | ||||
| 
 | ||||
| @ -60,12 +60,24 @@ type ExecInterface[T any] interface { | ||||
| 	Rows(ctx context.Context) (*sql.Rows, error) | ||||
| } | ||||
| 
 | ||||
| type QueryInterface interface { | ||||
| 	Select(...string) QueryInterface | ||||
| 	Omit(...string) QueryInterface | ||||
| 	Where(query interface{}, args ...interface{}) QueryInterface | ||||
| 	Not(query interface{}, args ...interface{}) QueryInterface | ||||
| 	Or(query interface{}, args ...interface{}) QueryInterface | ||||
| type JoinBuilder interface { | ||||
| 	Select(...string) JoinBuilder | ||||
| 	Omit(...string) JoinBuilder | ||||
| 	Where(query interface{}, args ...interface{}) JoinBuilder | ||||
| 	Not(query interface{}, args ...interface{}) JoinBuilder | ||||
| 	Or(query interface{}, args ...interface{}) JoinBuilder | ||||
| } | ||||
| 
 | ||||
| type PreloadBuilder interface { | ||||
| 	Select(...string) PreloadBuilder | ||||
| 	Omit(...string) PreloadBuilder | ||||
| 	Where(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Not(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Or(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Limit(offset int) PreloadBuilder | ||||
| 	Offset(offset int) PreloadBuilder | ||||
| 	Order(value interface{}) PreloadBuilder | ||||
| 	Scopes(scopes ...func(db *Statement)) PreloadBuilder | ||||
| } | ||||
| 
 | ||||
| type op func(*DB) *DB | ||||
| @ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| type query struct { | ||||
| type joinBuilder struct { | ||||
| 	db *DB | ||||
| } | ||||
| 
 | ||||
| func (q query) Where(query interface{}, args ...interface{}) QueryInterface { | ||||
| func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Or(query interface{}, args ...interface{}) QueryInterface { | ||||
| func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Not(query interface{}, args ...interface{}) QueryInterface { | ||||
| func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Select(columns ...string) QueryInterface { | ||||
| func (q joinBuilder) Select(columns ...string) JoinBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q query) Omit(columns ...string) QueryInterface { | ||||
| func (q joinBuilder) Omit(columns ...string) JoinBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { | ||||
| type preloadBuilder struct { | ||||
| 	db *DB | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Select(columns ...string) PreloadBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Limit(limit int) PreloadBuilder { | ||||
| 	q.db.Limit(limit) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Offset(offset int) PreloadBuilder { | ||||
| 	q.db.Offset(offset) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Order(value interface{}) PreloadBuilder { | ||||
| 	q.db.Order(value) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { | ||||
| 	for _, fc := range scopes { | ||||
| 		fc(q.db.Statement) | ||||
| 	} | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		if jt.Table == "" { | ||||
| 			jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name | ||||
| 		} | ||||
| 
 | ||||
| 		q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} | ||||
| 		q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} | ||||
| 		if args != nil { | ||||
| 			args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) | ||||
| 		} | ||||
| @ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { | ||||
| func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Preload(query, args...) | ||||
| 		return db.Preload(association, func(db *DB) *DB { | ||||
| 			q := preloadBuilder{db: db} | ||||
| 			if args != nil { | ||||
| 				args(q) | ||||
| 			} | ||||
| 			return q.db | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||
| 
 | ||||
| 	// Inner JOIN + WHERE
 | ||||
| 	result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		return nil | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| @ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// Inner JOIN + WHERE with map
 | ||||
| 	result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		return db.Where(map[string]any{"name": u.Company.Name}) | ||||
| 	result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		db.Where(map[string]any{"name": u.Company.Name}) | ||||
| 		return nil | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| @ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// Left JOIN + Alias WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		if joinTable.Name != "t" { | ||||
| 			t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 		} | ||||
| 		return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		return nil | ||||
| 	}).Where(map[string]any{"name": u.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| @ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 
 | ||||
| 	// Raw Subquery JOIN + WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), | ||||
| 		func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 			if joinTable.Name != "t" { | ||||
| 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 			} | ||||
| 			return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| @ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 
 | ||||
| 	// Raw Subquery JOIN + WHERE + Select
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), | ||||
| 		func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { | ||||
| 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 			if joinTable.Name != "t" { | ||||
| 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 			} | ||||
| 			return db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| @ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	// Preload
 | ||||
| 	result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) | ||||
| 	result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| 	if result3.Name != u.Name || result3.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 		t.Fatalf("Preload expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Where("name = ?", u.Company.Name) | ||||
| 		return nil | ||||
| 	}).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if result.Company.Name != u.Company.Name { | ||||
| 				t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||
| 			} | ||||
| 		} else if result.Company.Name != "" { | ||||
| 			t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu