finish generic version Preload
This commit is contained in:
parent
ba94e4eb2f
commit
4694673526
@ -152,7 +152,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
|
102
generics.go
102
generics.go
@ -31,7 +31,8 @@ type ChainInterface[T any] interface {
|
||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Limit(offset int) ChainInterface[T]
|
||||
Offset(offset int) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||
Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T]
|
||||
Select(query string, args ...interface{}) ChainInterface[T]
|
||||
Omit(columns ...string) ChainInterface[T]
|
||||
MapColumns(m map[string]string) ChainInterface[T]
|
||||
@ -39,7 +40,6 @@ type ChainInterface[T any] interface {
|
||||
Group(name string) ChainInterface[T]
|
||||
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Order(value interface{}) ChainInterface[T]
|
||||
Preload(query string, args ...interface{}) ChainInterface[T]
|
||||
|
||||
Build(builder clause.Builder)
|
||||
|
||||
@ -60,12 +60,24 @@ type ExecInterface[T any] interface {
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type QueryInterface interface {
|
||||
Select(...string) QueryInterface
|
||||
Omit(...string) QueryInterface
|
||||
Where(query interface{}, args ...interface{}) QueryInterface
|
||||
Not(query interface{}, args ...interface{}) QueryInterface
|
||||
Or(query interface{}, args ...interface{}) QueryInterface
|
||||
type JoinBuilder interface {
|
||||
Select(...string) JoinBuilder
|
||||
Omit(...string) JoinBuilder
|
||||
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||
}
|
||||
|
||||
type PreloadBuilder interface {
|
||||
Select(...string) PreloadBuilder
|
||||
Omit(...string) PreloadBuilder
|
||||
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Limit(offset int) PreloadBuilder
|
||||
Offset(offset int) PreloadBuilder
|
||||
Order(value interface{}) PreloadBuilder
|
||||
Scopes(scopes ...func(db *Statement)) PreloadBuilder
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
@ -198,42 +210,90 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
})
|
||||
}
|
||||
|
||||
type query struct {
|
||||
type joinBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
|
||||
func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
|
||||
func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
|
||||
func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Select(columns ...string) QueryInterface {
|
||||
func (q joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Omit(columns ...string) QueryInterface {
|
||||
func (q joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
|
||||
type preloadBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
q.db.Limit(limit)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
q.db.Offset(offset)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
q.db.Order(value)
|
||||
return q
|
||||
}
|
||||
func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder {
|
||||
for _, fc := range scopes {
|
||||
fc(q.db.Statement)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
if jt.Table == "" {
|
||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||
}
|
||||
|
||||
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
||||
q := joinBuilder{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
||||
if args != nil {
|
||||
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
|
||||
}
|
||||
@ -323,9 +383,15 @@ func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] {
|
||||
func (c chainG[T]) Preload(association string, args func(db PreloadBuilder) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Preload(query, args...)
|
||||
return db.Preload(association, func(db *DB) *DB {
|
||||
q := preloadBuilder{db: db}
|
||||
if args != nil {
|
||||
args(q)
|
||||
}
|
||||
return q.db
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -286,8 +286,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
// Inner JOIN + WHERE
|
||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
@ -297,8 +298,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
// Inner JOIN + WHERE with map
|
||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
return db.Where(map[string]any{"name": u.Company.Name})
|
||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
db.Where(map[string]any{"name": u.Company.Name})
|
||||
return nil
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
@ -317,11 +319,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
// Left JOIN + Alias WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
@ -332,11 +335,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
|
||||
// Raw Subquery JOIN + WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
},
|
||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||
if err != nil {
|
||||
@ -348,11 +352,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
|
||||
// Raw Subquery JOIN + WHERE + Select
|
||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
|
||||
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
},
|
||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||
if err != nil {
|
||||
@ -363,12 +368,29 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
// Preload
|
||||
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
||||
result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
if result3.Name != u.Name || result3.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
db.Where("name = ?", u.Company.Name)
|
||||
return nil
|
||||
}).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
} else if result.Company.Name != "" {
|
||||
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user