From 2d6d7f94859e31a26bb15bb6f29acb9ec1615764 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Apr 2025 18:18:56 +0800 Subject: [PATCH] =?UTF-8?q?use=20delayed=E2=80=91ops=20pipeline=20for=20ge?= =?UTF-8?q?nerics=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generics.go | 231 +++++++++++++++++++++++++---------------- tests/generics_test.go | 15 +-- 2 files changed, 150 insertions(+), 96 deletions(-) diff --git a/generics.go b/generics.go index f40c73be..5930a6ce 100644 --- a/generics.go +++ b/generics.go @@ -56,10 +56,18 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } +type op func(*DB) *DB + func G[T any](db *DB, opts ...clause.Expression) Interface[T] { v := &g[T]{ - db: db.Session(&Session{NewDB: true}).Clauses(opts...), - opts: opts, + db: db.Session(&Session{NewDB: true}), + ops: make([]op, 0, 5), + } + + if len(opts) > 0 { + v.ops = append(v.ops, func(db *DB) *DB { + return db.Clauses(opts...) + }) } v.createG = &createG[T]{ @@ -72,142 +80,187 @@ func G[T any](db *DB, opts ...clause.Expression) Interface[T] { type g[T any] struct { *createG[T] - db *DB - opts []clause.Expression + db *DB + ops []op } -func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { - g.db = g.db.Raw(sql, values...) - return &g.execG +func (g *g[T]) apply(ctx context.Context) *DB { + db := g.db.Session(&Session{NewDB: true, Context: ctx}).getInstance() + for _, op := range g.ops { + db = op(db) + } + return db } -func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return g.db.WithContext(ctx).Exec(sql, values...).Error +func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { + return execG[T]{g: &g[T]{ + db: c.db, + ops: append(c.ops, func(db *DB) *DB { + return db.Raw(sql, values...) + }), + }} +} + +func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { + return c.apply(ctx).Exec(sql, values...).Error } type createG[T any] struct { chainG[T] } -func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { - g.g.db = g.g.db.Table(name, args...) - return g +func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Table(name, args...) + })} } -func (g *createG[T]) Create(ctx context.Context, r *T) error { - return g.g.db.WithContext(ctx).Create(r).Error +func (c createG[T]) Create(ctx context.Context, r *T) error { + return c.g.apply(ctx).Create(r).Error } -func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { - return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error +func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { + return c.g.apply(ctx).CreateInBatches(r, batchSize).Error } type chainG[T any] struct { execG[T] } -func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { - for _, fc := range scopes { - fc(g.g.db.Statement) +func (c chainG[T]) with(op op) chainG[T] { + return chainG[T]{ + execG: execG[T]{g: &g[T]{ + db: c.g.db, + ops: append(c.g.ops, op), + }}, } - return g } -func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Where(query, args...) - return g +func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { + return c.with(func(db *DB) *DB { + for _, fc := range scopes { + fc(db.Statement) + } + return db + }) } -func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Not(query, args...) - return g +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) } -func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Or(query, args...) - return g +func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Where(query, args...) + }) } -func (g *chainG[T]) Limit(offset int) ChainInterface[T] { - g.g.db = g.g.db.Limit(offset) - return g +func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Not(query, args...) + }) } -func (g *chainG[T]) Offset(offset int) ChainInterface[T] { - g.g.db = g.g.db.Offset(offset) - return g +func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Or(query, args...) + }) } -func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Joins(query, args...) - return g +func (c chainG[T]) Limit(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Limit(offset) + }) } -func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.InnerJoins(query, args...) - return g +func (c chainG[T]) Offset(offset int) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Offset(offset) + }) } -func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Select(query, args...) - return g +func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Joins(query, args...) + }) } -func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] { - g.g.db = g.g.db.Omit(columns...) - return g +func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.InnerJoins(query, args...) + }) } -func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { - g.g.db = g.g.db.MapColumns(m) - return g +func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Select(query, args...) + }) } -func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Distinct(args...) - return g +func (c chainG[T]) Omit(columns ...string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Omit(columns...) + }) } -func (g *chainG[T]) Group(name string) ChainInterface[T] { - g.g.db = g.g.db.Group(name) - return g +func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.MapColumns(m) + }) } -func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Having(query, args...) - return g +func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Distinct(args...) + }) } -func (g *chainG[T]) Order(value interface{}) ChainInterface[T] { - g.g.db = g.g.db.Order(value) - return g +func (c chainG[T]) Group(name string) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Group(name) + }) } -func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { - g.g.db = g.g.db.Preload(query, args...) - return g +func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Having(query, args...) + }) } -func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { +func (c chainG[T]) Order(value interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Order(value) + }) +} + +func (c chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Preload(query, args...) + }) +} + +func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { r := new(T) - res := g.g.db.WithContext(ctx).Delete(r) + res := c.g.apply(ctx).Delete(r) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { +func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { var r T - res := g.g.db.WithContext(ctx).Model(r).Update(name, value) + res := c.g.apply(ctx).Model(r).Update(name, value) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { - res := g.g.db.WithContext(ctx).Updates(t) +func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { + res := c.g.apply(ctx).Updates(t) return int(res.RowsAffected), res.Error } -func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { +func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { var r T - err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error + err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error return } @@ -215,47 +268,47 @@ type execG[T any] struct { g *g[T] } -func (g *execG[T]) First(ctx context.Context) (T, error) { +func (g execG[T]) First(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).First(&r).Error + err := g.g.apply(ctx).First(&r).Error return r, err } -func (g *execG[T]) Scan(ctx context.Context, result interface{}) error { +func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(&result).Error return err } -func (g *execG[T]) Last(ctx context.Context) (T, error) { +func (g execG[T]) Last(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).Last(&r).Error + err := g.g.apply(ctx).Last(&r).Error return r, err } -func (g *execG[T]) Take(ctx context.Context) (T, error) { +func (g execG[T]) Take(ctx context.Context) (T, error) { var r T - err := g.g.db.WithContext(ctx).Take(&r).Error + err := g.g.apply(ctx).Take(&r).Error return r, err } -func (g *execG[T]) Find(ctx context.Context) ([]T, error) { +func (g execG[T]) Find(ctx context.Context) ([]T, error) { var r []T - err := g.g.db.WithContext(ctx).Find(&r).Error + err := g.g.apply(ctx).Find(&r).Error return r, err } -func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { +func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { var data []T - return g.g.db.WithContext(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { + return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { return fc(data, batch) }).Error } -func (g *execG[T]) Row(ctx context.Context) *sql.Row { - return g.g.db.WithContext(ctx).Row() +func (g execG[T]) Row(ctx context.Context) *sql.Row { + return g.g.apply(ctx).Row() } -func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { - return g.g.db.WithContext(ctx).Rows() +func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { + return g.g.apply(ctx).Rows() } diff --git a/tests/generics_test.go b/tests/generics_test.go index 9e047a55..bceb2f9e 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -12,11 +12,10 @@ import ( ) func TestGenericsCreate(t *testing.T) { - generic := gorm.G[User](DB) ctx := context.Background() user := User{Name: "TestGenericsCreate", Age: 18} - err := generic.Create(ctx, &user) + err := gorm.G[User](DB).Create(ctx, &user) if err != nil { t.Fatalf("Create failed: %v", err) } @@ -60,7 +59,7 @@ func TestGenericsCreate(t *testing.T) { mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { - t.Errorf("failed to find map results, got %v", mapResult) + t.Errorf("failed to find map results, got %v, err %v", mapResult, err) } } @@ -92,6 +91,7 @@ func TestGenericsCreateInBatches(t *testing.T) { found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) if len(found) != len(batch) { + fmt.Println(found) t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) } @@ -278,12 +278,13 @@ func TestGenericsScopes(t *testing.T) { func TestGenericsJoinsAndPreload(t *testing.T) { ctx := context.Background() + db := gorm.G[User](DB) u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} - DB.Create(&u) + db.Create(ctx, &u) // LEFT JOIN + WHERE - result, err := gorm.G[User](DB).Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) + result, err := db.Joins("Company").Where("Company.name = ?", u.Company.Name).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -292,7 +293,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // INNER JOIN + Inline WHERE - result2, err := gorm.G[User](DB).InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) + result2, err := db.InnerJoins("Company", "Company.name = ?", u.Company.Name).First(ctx) if err != nil { t.Fatalf("InnerJoins failed: %v", err) } @@ -301,7 +302,7 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } // Preload - result3, err := gorm.G[User](DB).Preload("Company").Where("name = ?", u.Name).First(ctx) + result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) }