From 47ec74e014593a013e0b0fa434140e41b165907b Mon Sep 17 00:00:00 2001 From: ManReinsp Date: Wed, 26 Jul 2017 12:36:23 +0200 Subject: [PATCH 1/5] Added support for subqueries in Where and Having clauses Queries can be transformed into subqueries by calling .Subquery() on a db object See main_test.go:TestQueryBuilderSubselectInWhere --- main.go | 8 ++++++++ main_test.go | 26 ++++++++++++++++++++++++++ scope.go | 7 ++++++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 0f2fd1f5..9af376c0 100644 --- a/main.go +++ b/main.go @@ -168,6 +168,14 @@ func (s *DB) NewScope(value interface{}) *Scope { return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } +// Subquery returns the query as eprx object +func (s *DB) Subquery() *expr { + scope := s.NewScope(s.Value) + scope.prepareQuerySQL() + + return Expr("("+scope.SQL+")", scope.SQLVars...) +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db diff --git a/main_test.go b/main_test.go index 3b1433cf..27f8eb1f 100644 --- a/main_test.go +++ b/main_test.go @@ -607,6 +607,32 @@ func TestHaving(t *testing.T) { } } +func TestQueryBuilderSubselectInWhere(t *testing.T) { + user := User{Name: "user1", Email: "root@user1.com", Age: 32} + DB.Save(&user) + user = User{Name: "user2", Email: "nobody@user2.com", Age: 16} + DB.Save(&user) + user = User{Name: "user3", Email: "root@user3.com", Age: 64} + DB.Save(&user) + user = User{Name: "user4", Email: "somebody@user3.com", Age: 128} + DB.Save(&user) + + var users []User + DB.Select("*").Where("name IN ?", DB. + Select("name").Table("users").Where("email LIKE ?", "root@%").Subquery()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("age >= ?", DB. + Select("AVG(age)").Table("users").Subquery()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + func DialectHasTzSupport() bool { // NB: mssql and FoundationDB do not support time zones. if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { diff --git a/scope.go b/scope.go index 4fcb84c1..4e259cc1 100644 --- a/scope.go +++ b/scope.go @@ -256,7 +256,11 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + if pgParameterRegexp.MatchString(exp) { + exp = pgParameterRegexp.ReplaceAllLiteralString(exp, scope.AddToVars(arg)) + } else { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } } return exp } @@ -452,6 +456,7 @@ var ( isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") + pgParameterRegexp = regexp.MustCompile(`\$[0-9]+`) // to exchange postgres `$1` style parameter placeholders ) func (scope *Scope) quoteIfPossible(str string) string { From 928f99dbb579ebf15012deecded705c957cfbdcc Mon Sep 17 00:00:00 2001 From: ManReinsp Date: Wed, 26 Jul 2017 14:46:05 +0200 Subject: [PATCH 2/5] Fixed comment spacing --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 4e259cc1..a8ea3459 100644 --- a/scope.go +++ b/scope.go @@ -456,7 +456,7 @@ var ( isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") - pgParameterRegexp = regexp.MustCompile(`\$[0-9]+`) // to exchange postgres `$1` style parameter placeholders + pgParameterRegexp = regexp.MustCompile(`\$[0-9]+`) // to exchange postgres `$1` style parameter placeholders ) func (scope *Scope) quoteIfPossible(str string) string { From 84847fdd470dfdff44a2580c4bbebcc0ee8904fa Mon Sep 17 00:00:00 2001 From: ManReinsp Date: Fri, 28 Jul 2017 08:26:57 +0200 Subject: [PATCH 3/5] Improved questionmark parameter placeholder replacementw --- scope.go | 43 +++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/scope.go b/scope.go index a8ea3459..b0f0a90a 100644 --- a/scope.go +++ b/scope.go @@ -256,11 +256,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - if pgParameterRegexp.MatchString(exp) { - exp = pgParameterRegexp.ReplaceAllLiteralString(exp, scope.AddToVars(arg)) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } + exp = scope.replaceParameterPlaceholder(exp, arg) } return exp } @@ -466,6 +462,21 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +func (scope *Scope) replaceParameterPlaceholderLiteral(sql string, parameter interface{}, addToVars bool) string { + if scope.Dialect().GetName() == "postgres" && pgParameterRegexp.MatchString(sql) { + sql = pgParameterRegexp.ReplaceAllLiteralString(sql, "?") + } + if val, ok := parameter.(string); ok && !addToVars { + return strings.Replace(sql, "?", val, 1) + } + + return strings.Replace(sql, "?", scope.AddToVars(parameter), 1) +} + +func (scope *Scope) replaceParameterPlaceholder(sql string, parameter interface{}) string { + return scope.replaceParameterPlaceholderLiteral(sql, parameter, true) +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -555,22 +566,22 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.replaceParameterPlaceholder(str, bytes) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholderLiteral(str, strings.Join(tempMarks, ","), false) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.replaceParameterPlaceholder(str, Expr("NULL")) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(str, arg) } } return @@ -627,21 +638,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.replaceParameterPlaceholder(str, bytes) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.replaceParameterPlaceholder(str, Expr("NULL")) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(notEqualSQL, arg) } } return @@ -664,12 +675,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholder(str, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(str, arg) } } return @@ -755,7 +766,7 @@ func (scope *Scope) orderSQL() string { } else if expr, ok := order.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.replaceParameterPlaceholder(exp, arg) } orders = append(orders, exp) } From 45370c33329aaa64ba3e6600c5cdc85445b46955 Mon Sep 17 00:00:00 2001 From: ManReinsp Date: Fri, 28 Jul 2017 11:05:29 +0200 Subject: [PATCH 4/5] Refactoring, adding Having Subquery support, allowing db.T for tablenames --- main.go | 2 +- main_test.go | 19 +++++++++++++++++++ search.go | 8 ++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/main.go b/main.go index 9af376c0..73556760 100644 --- a/main.go +++ b/main.go @@ -226,7 +226,7 @@ func (s *DB) Group(query string) *DB { } // Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query string, values ...interface{}) *DB { +func (s *DB) Having(query interface{}, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } diff --git a/main_test.go b/main_test.go index 27f8eb1f..f6519dd6 100644 --- a/main_test.go +++ b/main_test.go @@ -633,6 +633,25 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { } } +func TestQueryBuilderSubselectInHaving(t *testing.T) { + user := User{Name: "user1", Email: "root@user1.com", Age: 64} + DB.Save(&user) + user = User{Name: "user2", Email: "root@user2.com", Age: 128} + DB.Save(&user) + user = User{Name: "user3", Email: "root@user1.com", Age: 64} + DB.Save(&user) + user = User{Name: "user4", Email: "root@user2.com", Age: 128} + DB.Save(&user) + + var users []User + DB.Debug().Select("AVG(age) as avgage").Group("email").Having("avgage > ?", DB. + Select("AVG(age)").Table("users").Subquery()).Find(&users) + + if len(users) != 1 { + t.Errorf("One user group should be found, instead found %d", len(users)) + } +} + func DialectHasTzSupport() bool { // NB: mssql and FoundationDB do not support time zones. if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { diff --git a/search.go b/search.go index 23dac2c3..2e273584 100644 --- a/search.go +++ b/search.go @@ -104,8 +104,12 @@ func (s *search) Group(query string) *search { return s } -func (s *search) Having(query string, values ...interface{}) *search { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Having(query interface{}, values ...interface{}) *search { + if val, ok := query.(*expr); ok { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) + } else { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) + } return s } From b1634f73588d413bfe3bff1df52fa6f6d81f471e Mon Sep 17 00:00:00 2001 From: ManReinsp Date: Fri, 28 Jul 2017 11:39:38 +0200 Subject: [PATCH 5/5] Fixed tests --- main_test.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/main_test.go b/main_test.go index f6519dd6..7f839e8e 100644 --- a/main_test.go +++ b/main_test.go @@ -608,13 +608,13 @@ func TestHaving(t *testing.T) { } func TestQueryBuilderSubselectInWhere(t *testing.T) { - user := User{Name: "user1", Email: "root@user1.com", Age: 32} + user := User{Name: "ruser1", Email: "root@user1.com", Age: 32} DB.Save(&user) - user = User{Name: "user2", Email: "nobody@user2.com", Age: 16} + user = User{Name: "ruser2", Email: "nobody@user2.com", Age: 16} DB.Save(&user) - user = User{Name: "user3", Email: "root@user3.com", Age: 64} + user = User{Name: "ruser3", Email: "root@user3.com", Age: 64} DB.Save(&user) - user = User{Name: "user4", Email: "somebody@user3.com", Age: 128} + user = User{Name: "ruser4", Email: "somebody@user3.com", Age: 128} DB.Save(&user) var users []User @@ -625,7 +625,7 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { t.Errorf("Two users should be found, instead found %d", len(users)) } - DB.Select("*").Where("age >= ?", DB. + DB.Select("*").Where("email LIKE ?", "root%").Where("age >= ?", DB. Select("AVG(age)").Table("users").Subquery()).Find(&users) if len(users) != 2 { @@ -634,18 +634,18 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { } func TestQueryBuilderSubselectInHaving(t *testing.T) { - user := User{Name: "user1", Email: "root@user1.com", Age: 64} + user := User{Name: "ruser1", Email: "root@user1.com", Age: 64} DB.Save(&user) - user = User{Name: "user2", Email: "root@user2.com", Age: 128} + user = User{Name: "ruser2", Email: "root@user2.com", Age: 128} DB.Save(&user) - user = User{Name: "user3", Email: "root@user1.com", Age: 64} + user = User{Name: "ruser3", Email: "root@user1.com", Age: 64} DB.Save(&user) - user = User{Name: "user4", Email: "root@user2.com", Age: 128} + user = User{Name: "ruser4", Email: "root@user2.com", Age: 128} DB.Save(&user) var users []User - DB.Debug().Select("AVG(age) as avgage").Group("email").Having("avgage > ?", DB. - Select("AVG(age)").Table("users").Subquery()).Find(&users) + DB.Select("AVG(age) as avgage").Where("email LIKE ?", "root%").Group("email").Having("avgage > ?", DB. + Select("AVG(age)").Where("email LIKE ?", "root%").Table("users").Subquery()).Find(&users) if len(users) != 1 { t.Errorf("One user group should be found, instead found %d", len(users))