diff --git a/main.go b/main.go index 0f2fd1f5..9af376c0 100644 --- a/main.go +++ b/main.go @@ -168,6 +168,14 @@ func (s *DB) NewScope(value interface{}) *Scope { return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } +// Subquery returns the query as eprx object +func (s *DB) Subquery() *expr { + scope := s.NewScope(s.Value) + scope.prepareQuerySQL() + + return Expr("("+scope.SQL+")", scope.SQLVars...) +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db diff --git a/main_test.go b/main_test.go index 3b1433cf..27f8eb1f 100644 --- a/main_test.go +++ b/main_test.go @@ -607,6 +607,32 @@ func TestHaving(t *testing.T) { } } +func TestQueryBuilderSubselectInWhere(t *testing.T) { + user := User{Name: "user1", Email: "root@user1.com", Age: 32} + DB.Save(&user) + user = User{Name: "user2", Email: "nobody@user2.com", Age: 16} + DB.Save(&user) + user = User{Name: "user3", Email: "root@user3.com", Age: 64} + DB.Save(&user) + user = User{Name: "user4", Email: "somebody@user3.com", Age: 128} + DB.Save(&user) + + var users []User + DB.Select("*").Where("name IN ?", DB. + Select("name").Table("users").Where("email LIKE ?", "root@%").Subquery()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("age >= ?", DB. + Select("AVG(age)").Table("users").Subquery()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + func DialectHasTzSupport() bool { // NB: mssql and FoundationDB do not support time zones. if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { diff --git a/scope.go b/scope.go index 4fcb84c1..4e259cc1 100644 --- a/scope.go +++ b/scope.go @@ -256,7 +256,11 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + if pgParameterRegexp.MatchString(exp) { + exp = pgParameterRegexp.ReplaceAllLiteralString(exp, scope.AddToVars(arg)) + } else { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } } return exp } @@ -452,6 +456,7 @@ var ( isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") + pgParameterRegexp = regexp.MustCompile(`\$[0-9]+`) // to exchange postgres `$1` style parameter placeholders ) func (scope *Scope) quoteIfPossible(str string) string {