finish generic version Preload
This commit is contained in:
parent
ba94e4eb2f
commit
4694673526
@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
|||||||
return gorm.ErrInvalidData
|
return gorm.ErrInvalidData
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||||
|
102
generics.go
102
generics.go
@ -31,7 +31,8 @@ type ChainInterface[T any] interface {
|
|||||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
Limit(offset int) ChainInterface[T]
|
Limit(offset int) ChainInterface[T]
|
||||||
Offset(offset int) ChainInterface[T]
|
Offset(offset int) ChainInterface[T]
|
||||||
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
|
Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||||
|
Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T]
|
||||||
Select(query string, args ...interface{}) ChainInterface[T]
|
Select(query string, args ...interface{}) ChainInterface[T]
|
||||||
Omit(columns ...string) ChainInterface[T]
|
Omit(columns ...string) ChainInterface[T]
|
||||||
MapColumns(m map[string]string) ChainInterface[T]
|
MapColumns(m map[string]string) ChainInterface[T]
|
||||||
@ -39,7 +40,6 @@ type ChainInterface[T any] interface {
|
|||||||
Group(name string) ChainInterface[T]
|
Group(name string) ChainInterface[T]
|
||||||
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
Order(value interface{}) ChainInterface[T]
|
Order(value interface{}) ChainInterface[T]
|
||||||
Preload(query string, args ...interface{}) ChainInterface[T]
|
|
||||||
|
|
||||||
Build(builder clause.Builder)
|
Build(builder clause.Builder)
|
||||||
|
|
||||||
@ -60,12 +60,24 @@ type ExecInterface[T any] interface {
|
|||||||
Rows(ctx context.Context) (*sql.Rows, error)
|
Rows(ctx context.Context) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryInterface interface {
|
type JoinBuilder interface {
|
||||||
Select(...string) QueryInterface
|
Select(...string) JoinBuilder
|
||||||
Omit(...string) QueryInterface
|
Omit(...string) JoinBuilder
|
||||||
Where(query interface{}, args ...interface{}) QueryInterface
|
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||||
Not(query interface{}, args ...interface{}) QueryInterface
|
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||||
Or(query interface{}, args ...interface{}) QueryInterface
|
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreloadBuilder interface {
|
||||||
|
Select(...string) PreloadBuilder
|
||||||
|
Omit(...string) PreloadBuilder
|
||||||
|
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Limit(offset int) PreloadBuilder
|
||||||
|
Offset(offset int) PreloadBuilder
|
||||||
|
Order(value interface{}) PreloadBuilder
|
||||||
|
Scopes(scopes ...func(db *Statement)) PreloadBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
type op func(*DB) *DB
|
type op func(*DB) *DB
|
||||||
@ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type query struct {
|
type joinBuilder struct {
|
||||||
db *DB
|
db *DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
|
func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
|
func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
|
func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
q.db.Where(query, args...)
|
q.db.Where(query, args...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q query) Select(columns ...string) QueryInterface {
|
func (q joinBuilder) Select(columns ...string) JoinBuilder {
|
||||||
q.db.Select(columns)
|
q.db.Select(columns)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q query) Omit(columns ...string) QueryInterface {
|
func (q joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||||
q.db.Omit(columns...)
|
q.db.Omit(columns...)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
|
type preloadBuilder struct {
|
||||||
|
db *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||||
|
q.db.Select(columns)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||||
|
q.db.Omit(columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||||
|
q.db.Limit(limit)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||||
|
q.db.Offset(offset)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||||
|
q.db.Order(value)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
|
||||||
|
for _, fc := range scopes {
|
||||||
|
fc(q.db.Statement)
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||||
return c.with(func(db *DB) *DB {
|
return c.with(func(db *DB) *DB {
|
||||||
if jt.Table == "" {
|
if jt.Table == "" {
|
||||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||||
}
|
}
|
||||||
|
|
||||||
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
||||||
if args != nil {
|
if args != nil {
|
||||||
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
|
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
|
||||||
}
|
}
|
||||||
@ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] {
|
func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] {
|
||||||
return c.with(func(db *DB) *DB {
|
return c.with(func(db *DB) *DB {
|
||||||
return db.Preload(query, args...)
|
return db.Preload(association, func(db *DB) *DB {
|
||||||
|
q := preloadBuilder{db: db}
|
||||||
|
if args != nil {
|
||||||
|
args(q)
|
||||||
|
}
|
||||||
|
return q.db
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||||
|
|
||||||
// Inner JOIN + WHERE
|
// Inner JOIN + WHERE
|
||||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
}).First(ctx)
|
}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Joins failed: %v", err)
|
||||||
@ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Inner JOIN + WHERE with map
|
// Inner JOIN + WHERE with map
|
||||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
return db.Where(map[string]any{"name": u.Company.Name})
|
db.Where(map[string]any{"name": u.Company.Name})
|
||||||
|
return nil
|
||||||
}).First(ctx)
|
}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Joins failed: %v", err)
|
||||||
@ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Left JOIN + Alias WHERE
|
// Left JOIN + Alias WHERE
|
||||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
if joinTable.Name != "t" {
|
if joinTable.Name != "t" {
|
||||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
}
|
}
|
||||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Joins failed: %v", err)
|
||||||
@ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
|
|
||||||
// Raw Subquery JOIN + WHERE
|
// Raw Subquery JOIN + WHERE
|
||||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||||
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
if joinTable.Name != "t" {
|
if joinTable.Name != "t" {
|
||||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
}
|
}
|
||||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
|
|
||||||
// Raw Subquery JOIN + WHERE + Select
|
// Raw Subquery JOIN + WHERE + Select
|
||||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
|
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
|
||||||
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
if joinTable.Name != "t" {
|
if joinTable.Name != "t" {
|
||||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
}
|
}
|
||||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Preload
|
// Preload
|
||||||
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Preload failed: %v", err)
|
||||||
}
|
}
|
||||||
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
|
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
|
||||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Where("name = ?", u.Company.Name)
|
||||||
|
return nil
|
||||||
|
}).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Preload failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||||
|
}
|
||||||
|
} else if result.Company.Name != "" {
|
||||||
|
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user