finish generic version Preload

This commit is contained in:
Jinzhu 2025-05-20 22:46:24 +08:00
parent ba94e4eb2f
commit 4694673526
3 changed files with 120 additions and 32 deletions

View File

@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
return gorm.ErrInvalidData return gorm.ErrInvalidData
} }
} else { } else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true})
tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {

View File

@ -31,7 +31,8 @@ type ChainInterface[T any] interface {
Or(query interface{}, args ...interface{}) ChainInterface[T] Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T] Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T] Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T]
@ -39,7 +40,6 @@ type ChainInterface[T any] interface {
Group(name string) ChainInterface[T] Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T]
Preload(query string, args ...interface{}) ChainInterface[T]
Build(builder clause.Builder) Build(builder clause.Builder)
@ -60,12 +60,24 @@ type ExecInterface[T any] interface {
Rows(ctx context.Context) (*sql.Rows, error) Rows(ctx context.Context) (*sql.Rows, error)
} }
type QueryInterface interface { type JoinBuilder interface {
Select(...string) QueryInterface Select(...string) JoinBuilder
Omit(...string) QueryInterface Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) QueryInterface Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) QueryInterface Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) QueryInterface Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
Scopes(scopes ...func(db *Statement)) PreloadBuilder
} }
type op func(*DB) *DB type op func(*DB) *DB
@ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
}) })
} }
type query struct { type joinBuilder struct {
db *DB db *DB
} }
func (q query) Where(query interface{}, args ...interface{}) QueryInterface { func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...) q.db.Where(query, args...)
return q return q
} }
func (q query) Or(query interface{}, args ...interface{}) QueryInterface { func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...) q.db.Where(query, args...)
return q return q
} }
func (q query) Not(query interface{}, args ...interface{}) QueryInterface { func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...) q.db.Where(query, args...)
return q return q
} }
func (q query) Select(columns ...string) QueryInterface { func (q joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns) q.db.Select(columns)
return q return q
} }
func (q query) Omit(columns ...string) QueryInterface { func (q joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...) q.db.Omit(columns...)
return q return q
} }
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { type preloadBuilder struct {
db *DB
}
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
for _, fc := range scopes {
fc(q.db.Statement)
}
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB { return c.with(func(db *DB) *DB {
if jt.Table == "" { if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
} }
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
if args != nil { if args != nil {
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
} }
@ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
}) })
} }
func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB { return c.with(func(db *DB) *DB {
return db.Preload(query, args...) return db.Preload(association, func(db *DB) *DB {
q := preloadBuilder{db: db}
if args != nil {
args(q)
}
return q.db
})
}) })
} }

View File

@ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
// Inner JOIN + WHERE // Inner JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
return db.Where("?.name = ?", joinTable, u.Company.Name) db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).First(ctx) }).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
@ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
} }
// Inner JOIN + WHERE with map // Inner JOIN + WHERE with map
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
return db.Where(map[string]any{"name": u.Company.Name}) db.Where(map[string]any{"name": u.Company.Name})
return nil
}).First(ctx) }).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
@ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
} }
// Left JOIN + Alias WHERE // Left JOIN + Alias WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" { if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name) t.Fatalf("Join table should be t, but got %v", joinTable.Name)
} }
return db.Where("?.name = ?", joinTable, u.Company.Name) db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).Where(map[string]any{"name": u.Name}).First(ctx) }).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
@ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
// Raw Subquery JOIN + WHERE // Raw Subquery JOIN + WHERE
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" { if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name) t.Fatalf("Join table should be t, but got %v", joinTable.Name)
} }
return db.Where("?.name = ?", joinTable, u.Company.Name) db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}, },
).Where(map[string]any{"name": u2.Name}).First(ctx) ).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil { if err != nil {
@ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
// Raw Subquery JOIN + WHERE + Select // Raw Subquery JOIN + WHERE + Select
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" { if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name) t.Fatalf("Join table should be t, but got %v", joinTable.Name)
} }
return db.Where("?.name = ?", joinTable, u.Company.Name) db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}, },
).Where(map[string]any{"name": u2.Name}).First(ctx) ).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil { if err != nil {
@ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
} }
// Preload // Preload
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Preload failed: %v", err)
} }
if result3.Name != u.Name || result3.Company.Name != u.Company.Name { if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result) t.Fatalf("Preload expected %s, got %+v", u.Name, result)
}
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
db.Where("name = ?", u.Company.Name)
return nil
}).Find(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
for _, result := range results {
if result.Name == u.Name {
if result.Company.Name != u.Company.Name {
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
} else if result.Company.Name != "" {
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
} }
} }