diff --git a/generics.go b/generics.go index fc5dfaff..f40c73be 100644 --- a/generics.go +++ b/generics.go @@ -49,6 +49,7 @@ type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) Last(ctx context.Context) (T, error) + Take(context.Context) (T, error) Find(ctx context.Context) ([]T, error) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error Row(ctx context.Context) *sql.Row @@ -232,6 +233,12 @@ func (g *execG[T]) Last(ctx context.Context) (T, error) { return r, err } +func (g *execG[T]) Take(ctx context.Context) (T, error) { + var r T + err := g.g.db.WithContext(ctx).Take(&r).Error + return r, err +} + func (g *execG[T]) Find(ctx context.Context) ([]T, error) { var r []T err := g.g.db.WithContext(ctx).Find(&r).Error diff --git a/tests/generics_test.go b/tests/generics_test.go index 83686f31..9e047a55 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -3,6 +3,8 @@ package tests_test import ( "context" "fmt" + "reflect" + "sort" "testing" "gorm.io/gorm" @@ -28,6 +30,12 @@ func TestGenericsCreate(t *testing.T) { t.Errorf("found invalid user, got %v, expect %v", u, user) } + if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil { + t.Fatalf("failed to find user, got error: %v", err) + } else if u.Name != user.Name || u.ID != user.ID { + t.Errorf("found invalid user, got %v, expect %v", u, user) + } + if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != user.Name || u.Age != 0 { @@ -50,7 +58,7 @@ func TestGenericsCreate(t *testing.T) { t.Errorf("found invalid user, got %v, expect %v", result, user) } - mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).First(ctx) + mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { t.Errorf("failed to find map results, got %v", mapResult) } @@ -302,4 +310,58 @@ func TestGenericsJoinsAndPreload(t *testing.T) { } } -// Distinct, Group, Having +func TestGenericsDistinct(t *testing.T) { + ctx := context.Background() + + batch := []User{ + {Name: "GenericsDistinctDup"}, + {Name: "GenericsDistinctDup"}, + {Name: "GenericsDistinctUnique"}, + } + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx) + if err != nil { + t.Fatalf("Distinct Find failed: %v", err) + } + + if len(results) != 2 { + t.Errorf("expected 2 distinct names, got %d", len(results)) + } + + var names []string + for _, u := range results { + names = append(names, u.Name) + } + sort.Strings(names) + expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"} + if !reflect.DeepEqual(names, expected) { + t.Errorf("expected names %v, got %v", expected, names) + } +} + +func TestGenericsGroupHaving(t *testing.T) { + ctx := context.Background() + + batch := []User{ + {Name: "GenericsGroupHavingMulti"}, + {Name: "GenericsGroupHavingMulti"}, + {Name: "GenericsGroupHavingSingle"}, + } + if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + grouped, err := gorm.G[User](DB).Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(*) > ?", 1).Find(ctx) + if err != nil { + t.Fatalf("Group+Having Find failed: %v", err) + } + + if len(grouped) != 1 { + t.Errorf("expected 1 group with count>1, got %d", len(grouped)) + } else if grouped[0].Name != "GenericsGroupHavingMulti" { + t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) + } +}