From 46946735264c3ae59aab517b097745d373f664ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 20 May 2025 22:46:24 +0800 Subject: [PATCH] finish generic version Preload --- callbacks/preload.go | 2 +- generics.go | 102 +++++++++++++++++++++++++++++++++-------- tests/generics_test.go | 48 +++++++++++++------ 3 files changed, 120 insertions(+), 32 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index fd8214bb..4a6f2b79 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati return gorm.ErrInvalidData } } else { - tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { diff --git a/generics.go b/generics.go index d1d1a6e5..4953c758 100644 --- a/generics.go +++ b/generics.go @@ -31,7 +31,8 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -39,7 +40,6 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] - Preload(query string, args ...interface{}) ChainInterface[T] Build(builder clause.Builder) @@ -60,12 +60,24 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } -type QueryInterface interface { - Select(...string) QueryInterface - Omit(...string) QueryInterface - Where(query interface{}, args ...interface{}) QueryInterface - Not(query interface{}, args ...interface{}) QueryInterface - Or(query interface{}, args ...interface{}) QueryInterface +type JoinBuilder interface { + Select(...string) JoinBuilder + Omit(...string) JoinBuilder + Where(query interface{}, args ...interface{}) JoinBuilder + Not(query interface{}, args ...interface{}) JoinBuilder + Or(query interface{}, args ...interface{}) JoinBuilder +} + +type PreloadBuilder interface { + Select(...string) PreloadBuilder + Omit(...string) PreloadBuilder + Where(query interface{}, args ...interface{}) PreloadBuilder + Not(query interface{}, args ...interface{}) PreloadBuilder + Or(query interface{}, args ...interface{}) PreloadBuilder + Limit(offset int) PreloadBuilder + Offset(offset int) PreloadBuilder + Order(value interface{}) PreloadBuilder + Scopes(scopes ...func(db *Statement)) PreloadBuilder } type op func(*DB) *DB @@ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -type query struct { +type joinBuilder struct { db *DB } -func (q query) Where(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Or(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Not(query interface{}, args ...interface{}) QueryInterface { +func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } -func (q query) Select(columns ...string) QueryInterface { +func (q joinBuilder) Select(columns ...string) JoinBuilder { q.db.Select(columns) return q } -func (q query) Omit(columns ...string) QueryInterface { +func (q joinBuilder) Omit(columns ...string) JoinBuilder { q.db.Omit(columns...) return q } -func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { +type preloadBuilder struct { + db *DB +} + +func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { + q.db.Where(query, args...) + return q +} + +func (q preloadBuilder) Select(columns ...string) PreloadBuilder { + q.db.Select(columns) + return q +} + +func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { + q.db.Omit(columns...) + return q +} + +func (q preloadBuilder) Limit(limit int) PreloadBuilder { + q.db.Limit(limit) + return q +} +func (q preloadBuilder) Offset(offset int) PreloadBuilder { + q.db.Offset(offset) + return q +} +func (q preloadBuilder) Order(value interface{}) PreloadBuilder { + q.db.Order(value) + return q +} +func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { + for _, fc := range scopes { + fc(q.db.Statement) + } + return q +} + +func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { return c.with(func(db *DB) *DB { if jt.Table == "" { jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name } - q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} + q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} if args != nil { args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) } @@ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] { }) } -func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { +func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] { return c.with(func(db *DB) *DB { - return db.Preload(query, args...) + return db.Preload(association, func(db *DB) *DB { + q := preloadBuilder{db: db} + if args != nil { + args(q) + } + return q.db + }) }) } diff --git a/tests/generics_test.go b/tests/generics_test.go index 313b6bae..2efaacdc 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) // Inner JOIN + WHERE - result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { - return db.Where("?.name = ?", joinTable, u.Company.Name) + result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Inner JOIN + WHERE with map - result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { - return db.Where(map[string]any{"name": u.Company.Name}) + result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { + db.Where(map[string]any{"name": u.Company.Name}) + return nil }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Left JOIN + Alias WHERE - result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { // Raw Subquery JOIN + WHERE result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), - func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }, ).Where(map[string]any{"name": u2.Name}).First(ctx) if err != nil { @@ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { // Raw Subquery JOIN + WHERE + Select result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), - func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } - return db.Where("?.name = ?", joinTable, u.Company.Name) + db.Where("?.name = ?", joinTable, u.Company.Name) + return nil }, ).Where(map[string]any{"name": u2.Name}).First(ctx) if err != nil { @@ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Preload - result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) + result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) if err != nil { - t.Fatalf("Joins failed: %v", err) + t.Fatalf("Preload failed: %v", err) } if result3.Name != u.Name || result3.Company.Name != u.Company.Name { - t.Fatalf("Joins expected %s, got %+v", u.Name, result) + t.Fatalf("Preload expected %s, got %+v", u.Name, result) + } + + results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { + db.Where("name = ?", u.Company.Name) + return nil + }).Find(ctx) + if err != nil { + t.Fatalf("Preload failed: %v", err) + } + for _, result := range results { + if result.Name == u.Name { + if result.Company.Name != u.Company.Name { + t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name) + } + } else if result.Company.Name != "" { + t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) + } } }