From 0b0956a4ef3501dab1f1a8f378fbdb74cca804cb Mon Sep 17 00:00:00 2001 From: Niti Santikul Date: Sun, 20 Oct 2019 17:35:09 +0700 Subject: [PATCH] support for named parameter query --- query_test.go | 30 +++++++++++ scope.go | 135 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 120 insertions(+), 45 deletions(-) diff --git a/query_test.go b/query_test.go index 15bf8b3c..29039e1f 100644 --- a/query_test.go +++ b/query_test.go @@ -771,3 +771,33 @@ func TestPluckWithSelect(t *testing.T) { t.Errorf("Should correctly pluck with select, got: %s", userAges) } } + +func TestNamedParameters(t *testing.T) { + DB.Save(&User{Name: "kanoonsantikul", Age: 35}) + DB.Save(&User{Name: "kanoongorm", Age: 30}) + + var user1, user2, user3 User + + err := DB.Model(&User{}).Where("name LIKE ?UserName AND age <= ?MaxAge", map[string]interface{}{ + "MaxAge": 30, + "UserName": "%kanoon%"}).First(&user1).Error + if err != nil { + t.Error(err) + } + + err = DB.Model(&User{}).Where("name LIKE ?firstName AND name LIKE ?lastName", + map[string]interface{}{"firstName": "kanoon%"}, + map[string]interface{}{"lastName": "%santikul"}).First(&user2).Error + if err != nil { + t.Error(err) + } + if user2.Age != 35 { + t.Error("Should merge Map parameters as one argument") + } + + err = DB.Model(&User{}).Where("age in (?ages)", + map[string]interface{}{"ages": []int64{21, 22, 23, 24, 25}}).First(&user3).Error + if err != nil { + t.Error(err) + } +} diff --git a/scope.go b/scope.go index eb7525b8..40740c5a 100644 --- a/scope.go +++ b/scope.go @@ -464,6 +464,7 @@ var ( isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") + namedArgRegexp = regexp.MustCompile("\\?([_a-zA-Z\\d]+)") ) func (scope *Scope) quoteIfPossible(str string) string { @@ -600,48 +601,71 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) return } + getReplacementsForSliceArgs := func(arg interface{}) (replacement string, err error) { + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + replacement = scope.AddToVars(arg) + } else if b, ok := arg.([]byte); ok { + replacement = scope.AddToVars(b) + } else if as, ok := arg.([][]interface{}); ok { + var tempMarks []string + for _, a := range as { + var arrayMarks []string + for _, v := range a { + arrayMarks = append(arrayMarks, scope.AddToVars(v)) + } + + if len(arrayMarks) > 0 { + tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) + } + } + + if len(tempMarks) > 0 { + replacement = strings.Join(tempMarks, ",") + } + } else if values := reflect.ValueOf(arg); values.Len() > 0 { + var tempMarks []string + for i := 0; i < values.Len(); i++ { + tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) + } + replacement = strings.Join(tempMarks, ",") + } else { + replacement = scope.AddToVars(Expr("NULL")) + } + + return + } + + getReplacementForArg := func(arg interface{}) (replacement string, err error) { + if valuer, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = valuer.Value() + } + + replacement = scope.AddToVars(arg) + return + } + replacements := []string{} + args := clause["args"].([]interface{}) + namedArgs := map[string]interface{}{} + for _, arg := range args { var err error + var replacement string switch reflect.ValueOf(arg).Kind() { + case reflect.Map: + if tempArgs, ok := arg.(map[string]interface{}); ok { + for argName, value := range tempArgs { + namedArgs["?"+argName] = value + } + } case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } + replacement, err = getReplacementsForSliceArgs(arg) + replacements = append(replacements, replacement) default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) + replacement, err = getReplacementForArg(arg) + replacements = append(replacements, replacement) } if err != nil { @@ -649,18 +673,39 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) } } - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' && len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) + if len(replacements) > 0 { + buff := bytes.NewBuffer([]byte{}) + i := 0 + for _, s := range str { + if s == '?' && len(replacements) > i { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteRune(s) + } } - } - str = buff.String() + str = buff.String() + } else if len(namedArgs) > 0 { + str = namedArgRegexp.ReplaceAllStringFunc(str, func(match string) string { + arg := namedArgs[match] + var err error + var replacement string + + switch reflect.ValueOf(arg).Kind() { + case reflect.Slice: + replacement, err = getReplacementsForSliceArgs(arg) + default: + replacement, err = getReplacementForArg(arg) + } + + if err != nil { + scope.Err(err) + } + + return replacement + }) + } return }