From f3ff534c54dbfee2b83008ec815b2dd7c4514cd2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Apr 2025 09:44:25 +0800 Subject: [PATCH] Implement Generics API --- generics.go | 254 +++++++++++++++++++++++++++++++++++++++++ tests/generics_test.go | 173 ++++++++++++++++++++++++++++ 2 files changed, 427 insertions(+) create mode 100644 generics.go create mode 100644 tests/generics_test.go diff --git a/generics.go b/generics.go new file mode 100644 index 00000000..c165cc14 --- /dev/null +++ b/generics.go @@ -0,0 +1,254 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" +) + +type Interface[T any] interface { + Raw(sql string, values ...interface{}) ExecInterface[T] + Exec(ctx context.Context, sql string, values ...interface{}) error + CreateInterface[T] +} + +type CreateInterface[T any] interface { + ChainInterface[T] + Table(name string, args ...interface{}) CreateInterface[T] + Create(ctx context.Context, r *T) error + CreateInBatches(ctx context.Context, r *[]T, batchSize int) error +} + +type ChainInterface[T any] interface { + ExecInterface[T] + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query string, args ...interface{}) ChainInterface[T] + InnerJoins(query string, args ...interface{}) ChainInterface[T] + Select(query string, args ...interface{}) ChainInterface[T] + Omit(columns ...string) ChainInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Preload(query string, args ...interface{}) ChainInterface[T] + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) +} + +type ExecInterface[T any] interface { + Scan(ctx context.Context, r interface{}) error + First(context.Context) (T, error) + Last(ctx context.Context) (T, error) + Find(ctx context.Context) ([]T, error) + FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error + Row(ctx context.Context) *sql.Row + Rows(ctx context.Context) (*sql.Rows, error) +} + +func G[T any](db *DB, opts ...clause.Expression) Interface[T] { + v := &g[T]{ + db: db.Session(&Session{NewDB: true}).Clauses(opts...), + opts: opts, + } + + v.createG = &createG[T]{ + chainG: chainG[T]{ + execG: execG[T]{g: v}, + }, + } + return v +} + +type g[T any] struct { + *createG[T] + db *DB + opts []clause.Expression +} + +func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { + g.db = g.db.Raw(sql, values...) + return &g.execG +} + +func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { + return g.db.WithContext(ctx).Exec(sql, values...).Error +} + +type createG[T any] struct { + chainG[T] +} + +func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { + g.g.db = g.g.db.Table(name, args...) + return g +} + +func (g *createG[T]) Create(ctx context.Context, r *T) error { + return g.g.db.WithContext(ctx).Create(r).Error +} + +func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { + return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error +} + +type chainG[T any] struct { + execG[T] +} + +func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { + for _, fc := range scopes { + fc(g.g.db.Statement) + } + return g +} + +func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Where(query, args...) + return g +} + +func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Not(query, args...) + return g +} + +func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Or(query, args...) + return g +} + +func (g *chainG[T]) Limit(offset int) ChainInterface[T] { + g.g.db = g.g.db.Limit(offset) + return g +} + +func (g *chainG[T]) Offset(offset int) ChainInterface[T] { + g.g.db = g.g.db.Offset(offset) + return g +} + +func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Joins(query, args...) + return g +} + +func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.InnerJoins(query, args...) + return g +} + +func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Select(query, args...) + return g +} + +func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] { + g.g.db = g.g.db.Omit(columns...) + return g +} + +func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { + g.g.db = g.g.db.MapColumns(m) + return g +} + +func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Distinct(args...) + return g +} + +func (g *chainG[T]) Group(name string) ChainInterface[T] { + g.g.db = g.g.db.Group(name) + return g +} + +func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Having(query, args...) + return g +} + +func (g *chainG[T]) Order(value interface{}) ChainInterface[T] { + g.g.db = g.g.db.Order(value) + return g +} + +func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { + g.g.db = g.g.db.Preload(query, args...) + return g +} + +func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { + r := new(T) + res := g.g.db.WithContext(ctx).Delete(r) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { + var r T + res := g.g.db.WithContext(ctx).Model(r).Update(name, value) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { + res := g.g.db.WithContext(ctx).Updates(t) + return int(res.RowsAffected), res.Error +} + +func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { + var r T + err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error + return +} + +type execG[T any] struct { + g *g[T] +} + +func (g *execG[T]) First(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).First(&r).Error + return r, err +} + +func (g *execG[T]) Scan(ctx context.Context, result interface{}) error { + var r T + err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error + return err +} + +func (g *execG[T]) Last(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).Last(&r).Error + return r, err +} + +func (g *execG[T]) Find(ctx context.Context) ([]T, error) { + var r []T + err := g.g.db.WithContext(ctx).Find(&r).Error + return r, err +} + +func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { + var data []T + return g.g.db.WithContext(ctx).FindInBatches(data, batchSize, func(tx *DB, batch int) error { + return fc(data, batch) + }).Error +} + +func (g *execG[T]) Row(ctx context.Context) *sql.Row { + return g.g.db.WithContext(ctx).Row() +} + +func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { + return g.g.db.WithContext(ctx).Rows() +} diff --git a/tests/generics_test.go b/tests/generics_test.go new file mode 100644 index 00000000..4d42e953 --- /dev/null +++ b/tests/generics_test.go @@ -0,0 +1,173 @@ +package tests_test + +import ( + "context" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestGenericsCreate(t *testing.T) { + generic := gorm.G[User](DB) + ctx := context.Background() + + user := User{Name: "TestGenericsCreate"} + err := generic.Create(ctx, &user) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if user.ID == 0 { + t.Fatalf("no primary key found for %v", user) + } + + if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != user.Name || u.ID != user.ID { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + + result := struct { + ID int + Name string + }{} + if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { + t.Fatalf("failed to scan user, got error: %v", err) + } else if result.Name != user.Name || uint(result.ID) != user.ID { + t.Errorf("found invalid user, got %v, expect %v", result, user) + } +} + +func TestGenericsCreateInBatches(t *testing.T) { + batch := []User{ + {Name: "GenericsCreateInBatches1"}, + {Name: "GenericsCreateInBatches2"}, + {Name: "GenericsCreateInBatches3"}, + } + ctx := context.Background() + + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + for _, u := range batch { + if u.ID == 0 { + t.Fatalf("no primary key found for %v", u) + } + } + + count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*") + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 3 { + t.Errorf("expected 3 records, got %d", count) + } + + found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) + if len(found) != len(batch) { + t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) + } +} + +func TestGenericsExecAndUpdate(t *testing.T) { + ctx := context.Background() + + name := "GenericsExec" + if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil { + t.Fatalf("Exec insert failed: %v", err) + } + + u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != name || u.ID == 0 { + t.Errorf("found invalid user, got %v", u) + } + + name += "Update" + rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name) + if rows != 1 { + t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) + } + + nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if nu.Name != name || u.ID != nu.ID { + t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) + } + + rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18}) + if rows != 1 { + t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) + } + + nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx) + if err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID { + t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) + } +} + +func TestGenericsRow(t *testing.T) { + ctx := context.Background() + + user := User{Name: "GenericsRow"} + if err := gorm.G[User](DB).Create(ctx, &user); err != nil { + t.Fatalf("Create failed: %v", err) + } + + row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx) + var name string + if err := row.Scan(&name); err != nil { + t.Fatalf("Row scan failed: %v", err) + } + if name != user.Name { + t.Errorf("expected %s, got %s", user.Name, name) + } + + user2 := User{Name: "GenericsRow2"} + if err := gorm.G[User](DB).Create(ctx, &user2); err != nil { + t.Fatalf("Create failed: %v", err) + } + rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx) + if err != nil { + t.Fatalf("Rows failed: %v", err) + } + + count := 0 + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + count++ + } + if count != 2 { + t.Errorf("expected 2 rows, got %d", count) + } +} + +func TestGenericsDelete(t *testing.T) { + ctx := context.Background() + + u := User{Name: "GenericsDelete"} + if err := gorm.G[User](DB).Create(ctx, &u); err != nil { + t.Fatalf("Create failed: %v", err) + } + + rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx) + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + if rows != 1 { + t.Errorf("expected 1 row deleted, got %d", rows) + } + + _, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx) + if err != gorm.ErrRecordNotFound { + t.Fatalf("User after delete failed: %v", err) + } +}