diff --git a/generics.go b/generics.go index 5930a6ce..9dd1af7d 100644 --- a/generics.go +++ b/generics.go @@ -127,6 +127,11 @@ type chainG[T any] struct { execG[T] } +func (c chainG[T]) getInstance() *DB { + var r T + return c.g.apply(context.Background()).Model(r).getInstance() +} + func (c chainG[T]) with(op op) chainG[T] { return chainG[T]{ execG: execG[T]{g: &g[T]{ diff --git a/statement.go b/statement.go index 39e05d09..11791c3a 100644 --- a/statement.go +++ b/statement.go @@ -205,19 +205,20 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { writer.WriteString("(NULL)") } - case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() - if v.Statement.SQL.Len() > 0 { + case interface{ getInstance() *DB }: + cv := v.getInstance() + subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if cv.Statement.SQL.Len() > 0 { var ( vars = subdb.Statement.Vars - sql = v.Statement.SQL.String() + sql = cv.Statement.SQL.String() ) subdb.Statement.Vars = make([]interface{}, 0, len(vars)) for _, vv := range vars { subdb.Statement.Vars = append(subdb.Statement.Vars, vv) bindvar := strings.Builder{} - v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + cv.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) sql = strings.Replace(sql, bindvar.String(), "?", 1) } diff --git a/tests/generics_test.go b/tests/generics_test.go index 036f1cf9..1587d090 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -366,3 +366,35 @@ func TestGenericsGroupHaving(t *testing.T) { t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) } } + +func TestGenericsSubQuery(t *testing.T) { + ctx := context.Background() + users := []User{ + {Name: "GenericsSubquery_1", Age: 10}, + {Name: "GenericsSubquery_2", Age: 20}, + {Name: "GenericsSubquery_3", Age: 30}, + {Name: "GenericsSubquery_4", Age: 40}, + } + + if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { + t.Fatalf("CreateInBatches failed: %v", err) + } + + results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx) + if err != nil { + t.Fatalf("got error: %v", err) + } + + if len(results) != 4 { + t.Errorf("Four users should be found, instead found %d", len(results)) + } + + results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx) + if err != nil { + t.Fatalf("got error: %v", err) + } + + if len(results) != 3 { + t.Errorf("Three users should be found, instead found %d", len(results)) + } +} diff --git a/tests/go.mod b/tests/go.mod index c87ccca1..2d647b08 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.18.0 +go 1.23.0 require ( github.com/google/uuid v1.6.0