support for named parameter query

This commit is contained in:
Niti Santikul 2019-10-20 17:35:09 +07:00
parent 5b3e40ac12
commit 0b0956a4ef
2 changed files with 120 additions and 45 deletions

View File

@ -771,3 +771,33 @@ func TestPluckWithSelect(t *testing.T) {
t.Errorf("Should correctly pluck with select, got: %s", userAges)
}
}
func TestNamedParameters(t *testing.T) {
DB.Save(&User{Name: "kanoonsantikul", Age: 35})
DB.Save(&User{Name: "kanoongorm", Age: 30})
var user1, user2, user3 User
err := DB.Model(&User{}).Where("name LIKE ?UserName AND age <= ?MaxAge", map[string]interface{}{
"MaxAge": 30,
"UserName": "%kanoon%"}).First(&user1).Error
if err != nil {
t.Error(err)
}
err = DB.Model(&User{}).Where("name LIKE ?firstName AND name LIKE ?lastName",
map[string]interface{}{"firstName": "kanoon%"},
map[string]interface{}{"lastName": "%santikul"}).First(&user2).Error
if err != nil {
t.Error(err)
}
if user2.Age != 35 {
t.Error("Should merge Map parameters as one argument")
}
err = DB.Model(&User{}).Where("age in (?ages)",
map[string]interface{}{"ages": []int64{21, 22, 23, 24, 25}}).First(&user3).Error
if err != nil {
t.Error(err)
}
}

View File

@ -464,6 +464,7 @@ var (
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
namedArgRegexp = regexp.MustCompile("\\?([_a-zA-Z\\d]+)")
)
func (scope *Scope) quoteIfPossible(str string) string {
@ -600,17 +601,12 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
return
}
replacements := []string{}
args := clause["args"].([]interface{})
for _, arg := range args {
var err error
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
getReplacementsForSliceArgs := func(arg interface{}) (replacement string, err error) {
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
replacements = append(replacements, scope.AddToVars(arg))
replacement = scope.AddToVars(arg)
} else if b, ok := arg.([]byte); ok {
replacements = append(replacements, scope.AddToVars(b))
replacement = scope.AddToVars(b)
} else if as, ok := arg.([][]interface{}); ok {
var tempMarks []string
for _, a := range as {
@ -625,23 +621,51 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
}
if len(tempMarks) > 0 {
replacements = append(replacements, strings.Join(tempMarks, ","))
replacement = strings.Join(tempMarks, ",")
}
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
replacements = append(replacements, strings.Join(tempMarks, ","))
replacement = strings.Join(tempMarks, ",")
} else {
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
replacement = scope.AddToVars(Expr("NULL"))
}
default:
return
}
getReplacementForArg := func(arg interface{}) (replacement string, err error) {
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = valuer.Value()
}
replacements = append(replacements, scope.AddToVars(arg))
replacement = scope.AddToVars(arg)
return
}
replacements := []string{}
args := clause["args"].([]interface{})
namedArgs := map[string]interface{}{}
for _, arg := range args {
var err error
var replacement string
switch reflect.ValueOf(arg).Kind() {
case reflect.Map:
if tempArgs, ok := arg.(map[string]interface{}); ok {
for argName, value := range tempArgs {
namedArgs["?"+argName] = value
}
}
case reflect.Slice: // For where("id in (?)", []int64{1,2})
replacement, err = getReplacementsForSliceArgs(arg)
replacements = append(replacements, replacement)
default:
replacement, err = getReplacementForArg(arg)
replacements = append(replacements, replacement)
}
if err != nil {
@ -649,6 +673,7 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
}
}
if len(replacements) > 0 {
buff := bytes.NewBuffer([]byte{})
i := 0
for _, s := range str {
@ -661,6 +686,26 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
}
str = buff.String()
} else if len(namedArgs) > 0 {
str = namedArgRegexp.ReplaceAllStringFunc(str, func(match string) string {
arg := namedArgs[match]
var err error
var replacement string
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice:
replacement, err = getReplacementsForSliceArgs(arg)
default:
replacement, err = getReplacementForArg(arg)
}
if err != nil {
scope.Err(err)
}
return replacement
})
}
return
}