use delayed‑ops pipeline for generics API
This commit is contained in:
parent
ba27874dcd
commit
2d6d7f9485
229
generics.go
229
generics.go
@ -56,10 +56,18 @@ type ExecInterface[T any] interface {
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
|
||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
v := &g[T]{
|
||||
db: db.Session(&Session{NewDB: true}).Clauses(opts...),
|
||||
opts: opts,
|
||||
db: db.Session(&Session{NewDB: true}),
|
||||
ops: make([]op, 0, 5),
|
||||
}
|
||||
|
||||
if len(opts) > 0 {
|
||||
v.ops = append(v.ops, func(db *DB) *DB {
|
||||
return db.Clauses(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
v.createG = &createG[T]{
|
||||
@ -73,141 +81,186 @@ func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
type g[T any] struct {
|
||||
*createG[T]
|
||||
db *DB
|
||||
opts []clause.Expression
|
||||
ops []op
|
||||
}
|
||||
|
||||
func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||
g.db = g.db.Raw(sql, values...)
|
||||
return &g.execG
|
||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||
db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
for _, op := range g.ops {
|
||||
db = op(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||
return g.db.WithContext(ctx).Exec(sql, values...).Error
|
||||
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||
return execG[T]{g: &g[T]{
|
||||
db: c.db,
|
||||
ops: append(c.ops, func(db *DB) *DB {
|
||||
return db.Raw(sql, values...)
|
||||
}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||
return c.apply(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
type createG[T any] struct {
|
||||
chainG[T]
|
||||
}
|
||||
|
||||
func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||
g.g.db = g.g.db.Table(name, args...)
|
||||
return g
|
||||
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||
return createG[T]{c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})}
|
||||
}
|
||||
|
||||
func (g *createG[T]) Create(ctx context.Context, r *T) error {
|
||||
return g.g.db.WithContext(ctx).Create(r).Error
|
||||
func (c createG[T]) Create(ctx context.Context, r *T) error {
|
||||
return c.g.apply(ctx).Create(r).Error
|
||||
}
|
||||
|
||||
func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||
return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error
|
||||
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
|
||||
}
|
||||
|
||||
type chainG[T any] struct {
|
||||
execG[T]
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||
for _, fc := range scopes {
|
||||
fc(g.g.db.Statement)
|
||||
func (c chainG[T]) with(op op) chainG[T] {
|
||||
return chainG[T]{
|
||||
execG: execG[T]{g: &g[T]{
|
||||
db: c.g.db,
|
||||
ops: append(c.g.ops, op),
|
||||
}},
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Where(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
for _, fc := range scopes {
|
||||
fc(db.Statement)
|
||||
}
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Not(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Or(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Where(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Limit(offset)
|
||||
return g
|
||||
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Not(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Offset(offset)
|
||||
return g
|
||||
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Or(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Joins(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Limit(offset)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.InnerJoins(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Offset(offset)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Select(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Joins(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Omit(columns...)
|
||||
return g
|
||||
func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.InnerJoins(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||
g.g.db = g.g.db.MapColumns(m)
|
||||
return g
|
||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Select(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Distinct(args...)
|
||||
return g
|
||||
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Omit(columns...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Group(name string) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Group(name)
|
||||
return g
|
||||
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.MapColumns(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Having(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Distinct(args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Order(value)
|
||||
return g
|
||||
func (c chainG[T]) Group(name string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Group(name)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] {
|
||||
g.g.db = g.g.db.Preload(query, args...)
|
||||
return g
|
||||
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Having(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Order(value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Preload(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||
r := new(T)
|
||||
res := g.g.db.WithContext(ctx).Delete(r)
|
||||
res := c.g.apply(ctx).Delete(r)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||
var r T
|
||||
res := g.g.db.WithContext(ctx).Model(r).Update(name, value)
|
||||
res := c.g.apply(ctx).Model(r).Update(name, value)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||
res := g.g.db.WithContext(ctx).Updates(t)
|
||||
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||
res := c.g.apply(ctx).Updates(t)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||
var r T
|
||||
err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error
|
||||
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
|
||||
return
|
||||
}
|
||||
|
||||
@ -215,47 +268,47 @@ type execG[T any] struct {
|
||||
g *g[T]
|
||||
}
|
||||
|
||||
func (g *execG[T]) First(ctx context.Context) (T, error) {
|
||||
func (g execG[T]) First(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.db.WithContext(ctx).First(&r).Error
|
||||
err := g.g.apply(ctx).First(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g *execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||
var r T
|
||||
err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error
|
||||
err := g.g.apply(ctx).Model(r).Find(&result).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *execG[T]) Last(ctx context.Context) (T, error) {
|
||||
func (g execG[T]) Last(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.db.WithContext(ctx).Last(&r).Error
|
||||
err := g.g.apply(ctx).Last(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g *execG[T]) Take(ctx context.Context) (T, error) {
|
||||
func (g execG[T]) Take(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.db.WithContext(ctx).Take(&r).Error
|
||||
err := g.g.apply(ctx).Take(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g *execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||
var r []T
|
||||
err := g.g.db.WithContext(ctx).Find(&r).Error
|
||||
err := g.g.apply(ctx).Find(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||
var data []T
|
||||
return g.g.db.WithContext(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||
return fc(data, batch)
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (g *execG[T]) Row(ctx context.Context) *sql.Row {
|
||||
return g.g.db.WithContext(ctx).Row()
|
||||
func (g execG[T]) Row(ctx context.Context) *sql.Row {
|
||||
return g.g.apply(ctx).Row()
|
||||
}
|
||||
|
||||
func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||
return g.g.db.WithContext(ctx).Rows()
|
||||
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||
return g.g.apply(ctx).Rows()
|
||||
}
|
||||
|
@ -12,11 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func TestGenericsCreate(t *testing.T) {
|
||||
generic := gorm.G[User](DB)
|
||||
ctx := context.Background()
|
||||
|
||||
user := User{Name: "TestGenericsCreate", Age: 18}
|
||||
err := generic.Create(ctx, &user)
|
||||
err := gorm.G[User](DB).Create(ctx, &user)
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
@ -60,7 +59,7 @@ func TestGenericsCreate(t *testing.T) {
|
||||
|
||||
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
|
||||
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
|
||||
t.Errorf("failed to find map results, got %v", mapResult)
|
||||
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,6 +91,7 @@ func TestGenericsCreateInBatches(t *testing.T) {
|
||||
|
||||
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||
if len(found) != len(batch) {
|
||||
fmt.Println(found)
|
||||
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||
}
|
||||
|
||||
@ -278,12 +278,13 @@ func TestGenericsScopes(t *testing.T) {
|
||||
|
||||
func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
||||
DB.Create(&u)
|
||||
db.Create(ctx, &u)
|
||||
|
||||
// LEFT JOIN + WHERE
|
||||
result, err := gorm.G[User](DB).Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx)
|
||||
result, err := db.Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
@ -292,7 +293,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
// INNER JOIN + Inline WHERE
|
||||
result2, err := gorm.G[User](DB).InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx)
|
||||
result2, err := db.InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("InnerJoins failed: %v", err)
|
||||
}
|
||||
@ -301,7 +302,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
}
|
||||
|
||||
// Preload
|
||||
result3, err := gorm.G[User](DB).Preload("Company").Where("name = ?", u.Name).First(ctx)
|
||||
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user