Implement Generics API
This commit is contained in:
		
							parent
							
								
									a9d27293de
								
							
						
					
					
						commit
						f3ff534c54
					
				
							
								
								
									
										254
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,254 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| ) | ||||
| 
 | ||||
| type Interface[T any] interface { | ||||
| 	Raw(sql string, values ...interface{}) ExecInterface[T] | ||||
| 	Exec(ctx context.Context, sql string, values ...interface{}) error | ||||
| 	CreateInterface[T] | ||||
| } | ||||
| 
 | ||||
| type CreateInterface[T any] interface { | ||||
| 	ChainInterface[T] | ||||
| 	Table(name string, args ...interface{}) CreateInterface[T] | ||||
| 	Create(ctx context.Context, r *T) error | ||||
| 	CreateInBatches(ctx context.Context, r *[]T, batchSize int) error | ||||
| } | ||||
| 
 | ||||
| type ChainInterface[T any] interface { | ||||
| 	ExecInterface[T] | ||||
| 	Scopes(scopes ...func(db *Statement)) ChainInterface[T] | ||||
| 	Where(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Not(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Or(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Limit(offset int) ChainInterface[T] | ||||
| 	Offset(offset int) ChainInterface[T] | ||||
| 	Joins(query string, args ...interface{}) ChainInterface[T] | ||||
| 	InnerJoins(query string, args ...interface{}) ChainInterface[T] | ||||
| 	Select(query string, args ...interface{}) ChainInterface[T] | ||||
| 	Omit(columns ...string) ChainInterface[T] | ||||
| 	MapColumns(m map[string]string) ChainInterface[T] | ||||
| 	Distinct(args ...interface{}) ChainInterface[T] | ||||
| 	Group(name string) ChainInterface[T] | ||||
| 	Having(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Order(value interface{}) ChainInterface[T] | ||||
| 	Preload(query string, args ...interface{}) ChainInterface[T] | ||||
| 
 | ||||
| 	Delete(ctx context.Context) (rowsAffected int, err error) | ||||
| 	Update(ctx context.Context, name string, value any) (rowsAffected int, err error) | ||||
| 	Updates(ctx context.Context, t T) (rowsAffected int, err error) | ||||
| 	Count(ctx context.Context, column string) (result int64, err error) | ||||
| } | ||||
| 
 | ||||
| type ExecInterface[T any] interface { | ||||
| 	Scan(ctx context.Context, r interface{}) error | ||||
| 	First(context.Context) (T, error) | ||||
| 	Last(ctx context.Context) (T, error) | ||||
| 	Find(ctx context.Context) ([]T, error) | ||||
| 	FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error | ||||
| 	Row(ctx context.Context) *sql.Row | ||||
| 	Rows(ctx context.Context) (*sql.Rows, error) | ||||
| } | ||||
| 
 | ||||
| func G[T any](db *DB, opts ...clause.Expression) Interface[T] { | ||||
| 	v := &g[T]{ | ||||
| 		db:   db.Session(&Session{NewDB: true}).Clauses(opts...), | ||||
| 		opts: opts, | ||||
| 	} | ||||
| 
 | ||||
| 	v.createG = &createG[T]{ | ||||
| 		chainG: chainG[T]{ | ||||
| 			execG: execG[T]{g: v}, | ||||
| 		}, | ||||
| 	} | ||||
| 	return v | ||||
| } | ||||
| 
 | ||||
| type g[T any] struct { | ||||
| 	*createG[T] | ||||
| 	db   *DB | ||||
| 	opts []clause.Expression | ||||
| } | ||||
| 
 | ||||
| func (g *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { | ||||
| 	g.db = g.db.Raw(sql, values...) | ||||
| 	return &g.execG | ||||
| } | ||||
| 
 | ||||
| func (g *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { | ||||
| 	return g.db.WithContext(ctx).Exec(sql, values...).Error | ||||
| } | ||||
| 
 | ||||
| type createG[T any] struct { | ||||
| 	chainG[T] | ||||
| } | ||||
| 
 | ||||
| func (g *createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { | ||||
| 	g.g.db = g.g.db.Table(name, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *createG[T]) Create(ctx context.Context, r *T) error { | ||||
| 	return g.g.db.WithContext(ctx).Create(r).Error | ||||
| } | ||||
| 
 | ||||
| func (g *createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { | ||||
| 	return g.g.db.WithContext(ctx).CreateInBatches(r, batchSize).Error | ||||
| } | ||||
| 
 | ||||
| type chainG[T any] struct { | ||||
| 	execG[T] | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { | ||||
| 	for _, fc := range scopes { | ||||
| 		fc(g.g.db.Statement) | ||||
| 	} | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Where(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Not(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Or(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Limit(offset int) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Limit(offset) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Offset(offset int) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Offset(offset) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Joins(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.InnerJoins(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Select(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Omit(columns ...string) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Omit(columns...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.MapColumns(m) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Distinct(args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Group(name string) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Group(name) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Having(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Order(value interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Order(value) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Preload(query string, args ...interface{}) ChainInterface[T] { | ||||
| 	g.g.db = g.g.db.Preload(query, args...) | ||||
| 	return g | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { | ||||
| 	r := new(T) | ||||
| 	res := g.g.db.WithContext(ctx).Delete(r) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { | ||||
| 	var r T | ||||
| 	res := g.g.db.WithContext(ctx).Model(r).Update(name, value) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { | ||||
| 	res := g.g.db.WithContext(ctx).Updates(t) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (g *chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { | ||||
| 	var r T | ||||
| 	err = g.g.db.WithContext(ctx).Model(r).Select(column).Count(&result).Error | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type execG[T any] struct { | ||||
| 	g *g[T] | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) First(ctx context.Context) (T, error) { | ||||
| 	var r T | ||||
| 	err := g.g.db.WithContext(ctx).First(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) Scan(ctx context.Context, result interface{}) error { | ||||
| 	var r T | ||||
| 	err := g.g.db.WithContext(ctx).Model(r).Find(&result).Error | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) Last(ctx context.Context) (T, error) { | ||||
| 	var r T | ||||
| 	err := g.g.db.WithContext(ctx).Last(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) Find(ctx context.Context) ([]T, error) { | ||||
| 	var r []T | ||||
| 	err := g.g.db.WithContext(ctx).Find(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { | ||||
| 	var data []T | ||||
| 	return g.g.db.WithContext(ctx).FindInBatches(data, batchSize, func(tx *DB, batch int) error { | ||||
| 		return fc(data, batch) | ||||
| 	}).Error | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) Row(ctx context.Context) *sql.Row { | ||||
| 	return g.g.db.WithContext(ctx).Row() | ||||
| } | ||||
| 
 | ||||
| func (g *execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { | ||||
| 	return g.g.db.WithContext(ctx).Rows() | ||||
| } | ||||
							
								
								
									
										173
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,173 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestGenericsCreate(t *testing.T) { | ||||
| 	generic := gorm.G[User](DB) | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	user := User{Name: "TestGenericsCreate"} | ||||
| 	err := generic.Create(ctx, &user) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 	if user.ID == 0 { | ||||
| 		t.Fatalf("no primary key found for %v", user) | ||||
| 	} | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != user.Name || u.ID != user.ID { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| 	result := struct { | ||||
| 		ID   int | ||||
| 		Name string | ||||
| 	}{} | ||||
| 	if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { | ||||
| 		t.Fatalf("failed to scan user, got error: %v", err) | ||||
| 	} else if result.Name != user.Name || uint(result.ID) != user.ID { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", result, user) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsCreateInBatches(t *testing.T) { | ||||
| 	batch := []User{ | ||||
| 		{Name: "GenericsCreateInBatches1"}, | ||||
| 		{Name: "GenericsCreateInBatches2"}, | ||||
| 		{Name: "GenericsCreateInBatches3"}, | ||||
| 	} | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, u := range batch { | ||||
| 		if u.ID == 0 { | ||||
| 			t.Fatalf("no primary key found for %v", u) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Count failed: %v", err) | ||||
| 	} | ||||
| 	if count != 3 { | ||||
| 		t.Errorf("expected 3 records, got %d", count) | ||||
| 	} | ||||
| 
 | ||||
| 	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) | ||||
| 	if len(found) != len(batch) { | ||||
| 		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsExecAndUpdate(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	name := "GenericsExec" | ||||
| 	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil { | ||||
| 		t.Fatalf("Exec insert failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != name || u.ID == 0 { | ||||
| 		t.Errorf("found invalid user, got %v", u) | ||||
| 	} | ||||
| 
 | ||||
| 	name += "Update" | ||||
| 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name) | ||||
| 	if rows != 1 { | ||||
| 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||
| 	} | ||||
| 
 | ||||
| 	nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if nu.Name != name || u.ID != nu.ID { | ||||
| 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18}) | ||||
| 	if rows != 1 { | ||||
| 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||
| 	} | ||||
| 
 | ||||
| 	nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID { | ||||
| 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsRow(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	user := User{Name: "GenericsRow"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &user); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx) | ||||
| 	var name string | ||||
| 	if err := row.Scan(&name); err != nil { | ||||
| 		t.Fatalf("Row scan failed: %v", err) | ||||
| 	} | ||||
| 	if name != user.Name { | ||||
| 		t.Errorf("expected %s, got %s", user.Name, name) | ||||
| 	} | ||||
| 
 | ||||
| 	user2 := User{Name: "GenericsRow2"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &user2); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 	rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Rows failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	count := 0 | ||||
| 	for rows.Next() { | ||||
| 		var name string | ||||
| 		if err := rows.Scan(&name); err != nil { | ||||
| 			t.Fatalf("rows.Scan failed: %v", err) | ||||
| 		} | ||||
| 		count++ | ||||
| 	} | ||||
| 	if count != 2 { | ||||
| 		t.Errorf("expected 2 rows, got %d", count) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsDelete(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	u := User{Name: "GenericsDelete"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &u); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Delete failed: %v", err) | ||||
| 	} | ||||
| 	if rows != 1 { | ||||
| 		t.Errorf("expected 1 row deleted, got %d", rows) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx) | ||||
| 	if err != gorm.ErrRecordNotFound { | ||||
| 		t.Fatalf("User after delete failed: %v", err) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu