support for named parameter query
This commit is contained in:
parent
5b3e40ac12
commit
0b0956a4ef
@ -771,3 +771,33 @@ func TestPluckWithSelect(t *testing.T) {
|
||||
t.Errorf("Should correctly pluck with select, got: %s", userAges)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedParameters(t *testing.T) {
|
||||
DB.Save(&User{Name: "kanoonsantikul", Age: 35})
|
||||
DB.Save(&User{Name: "kanoongorm", Age: 30})
|
||||
|
||||
var user1, user2, user3 User
|
||||
|
||||
err := DB.Model(&User{}).Where("name LIKE ?UserName AND age <= ?MaxAge", map[string]interface{}{
|
||||
"MaxAge": 30,
|
||||
"UserName": "%kanoon%"}).First(&user1).Error
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("name LIKE ?firstName AND name LIKE ?lastName",
|
||||
map[string]interface{}{"firstName": "kanoon%"},
|
||||
map[string]interface{}{"lastName": "%santikul"}).First(&user2).Error
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if user2.Age != 35 {
|
||||
t.Error("Should merge Map parameters as one argument")
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("age in (?ages)",
|
||||
map[string]interface{}{"ages": []int64{21, 22, 23, 24, 25}}).First(&user3).Error
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
135
scope.go
135
scope.go
@ -464,6 +464,7 @@ var (
|
||||
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
|
||||
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
|
||||
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
|
||||
namedArgRegexp = regexp.MustCompile("\\?([_a-zA-Z\\d]+)")
|
||||
)
|
||||
|
||||
func (scope *Scope) quoteIfPossible(str string) string {
|
||||
@ -600,48 +601,71 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
||||
return
|
||||
}
|
||||
|
||||
getReplacementsForSliceArgs := func(arg interface{}) (replacement string, err error) {
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, err = scanner.Value()
|
||||
replacement = scope.AddToVars(arg)
|
||||
} else if b, ok := arg.([]byte); ok {
|
||||
replacement = scope.AddToVars(b)
|
||||
} else if as, ok := arg.([][]interface{}); ok {
|
||||
var tempMarks []string
|
||||
for _, a := range as {
|
||||
var arrayMarks []string
|
||||
for _, v := range a {
|
||||
arrayMarks = append(arrayMarks, scope.AddToVars(v))
|
||||
}
|
||||
|
||||
if len(arrayMarks) > 0 {
|
||||
tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
|
||||
}
|
||||
}
|
||||
|
||||
if len(tempMarks) > 0 {
|
||||
replacement = strings.Join(tempMarks, ",")
|
||||
}
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
replacement = strings.Join(tempMarks, ",")
|
||||
} else {
|
||||
replacement = scope.AddToVars(Expr("NULL"))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
getReplacementForArg := func(arg interface{}) (replacement string, err error) {
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, err = valuer.Value()
|
||||
}
|
||||
|
||||
replacement = scope.AddToVars(arg)
|
||||
return
|
||||
}
|
||||
|
||||
replacements := []string{}
|
||||
|
||||
args := clause["args"].([]interface{})
|
||||
namedArgs := map[string]interface{}{}
|
||||
|
||||
for _, arg := range args {
|
||||
var err error
|
||||
var replacement string
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Map:
|
||||
if tempArgs, ok := arg.(map[string]interface{}); ok {
|
||||
for argName, value := range tempArgs {
|
||||
namedArgs["?"+argName] = value
|
||||
}
|
||||
}
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, err = scanner.Value()
|
||||
replacements = append(replacements, scope.AddToVars(arg))
|
||||
} else if b, ok := arg.([]byte); ok {
|
||||
replacements = append(replacements, scope.AddToVars(b))
|
||||
} else if as, ok := arg.([][]interface{}); ok {
|
||||
var tempMarks []string
|
||||
for _, a := range as {
|
||||
var arrayMarks []string
|
||||
for _, v := range a {
|
||||
arrayMarks = append(arrayMarks, scope.AddToVars(v))
|
||||
}
|
||||
|
||||
if len(arrayMarks) > 0 {
|
||||
tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
|
||||
}
|
||||
}
|
||||
|
||||
if len(tempMarks) > 0 {
|
||||
replacements = append(replacements, strings.Join(tempMarks, ","))
|
||||
}
|
||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||
var tempMarks []string
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||
}
|
||||
replacements = append(replacements, strings.Join(tempMarks, ","))
|
||||
} else {
|
||||
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
|
||||
}
|
||||
replacement, err = getReplacementsForSliceArgs(arg)
|
||||
replacements = append(replacements, replacement)
|
||||
default:
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, err = valuer.Value()
|
||||
}
|
||||
|
||||
replacements = append(replacements, scope.AddToVars(arg))
|
||||
replacement, err = getReplacementForArg(arg)
|
||||
replacements = append(replacements, replacement)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -649,18 +673,39 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
||||
}
|
||||
}
|
||||
|
||||
buff := bytes.NewBuffer([]byte{})
|
||||
i := 0
|
||||
for _, s := range str {
|
||||
if s == '?' && len(replacements) > i {
|
||||
buff.WriteString(replacements[i])
|
||||
i++
|
||||
} else {
|
||||
buff.WriteRune(s)
|
||||
if len(replacements) > 0 {
|
||||
buff := bytes.NewBuffer([]byte{})
|
||||
i := 0
|
||||
for _, s := range str {
|
||||
if s == '?' && len(replacements) > i {
|
||||
buff.WriteString(replacements[i])
|
||||
i++
|
||||
} else {
|
||||
buff.WriteRune(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
str = buff.String()
|
||||
str = buff.String()
|
||||
} else if len(namedArgs) > 0 {
|
||||
str = namedArgRegexp.ReplaceAllStringFunc(str, func(match string) string {
|
||||
arg := namedArgs[match]
|
||||
var err error
|
||||
var replacement string
|
||||
|
||||
switch reflect.ValueOf(arg).Kind() {
|
||||
case reflect.Slice:
|
||||
replacement, err = getReplacementsForSliceArgs(arg)
|
||||
default:
|
||||
replacement, err = getReplacementForArg(arg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
scope.Err(err)
|
||||
}
|
||||
|
||||
return replacement
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user