support for named parameter query
This commit is contained in:
parent
5b3e40ac12
commit
0b0956a4ef
@ -771,3 +771,33 @@ func TestPluckWithSelect(t *testing.T) {
|
|||||||
t.Errorf("Should correctly pluck with select, got: %s", userAges)
|
t.Errorf("Should correctly pluck with select, got: %s", userAges)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamedParameters(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "kanoonsantikul", Age: 35})
|
||||||
|
DB.Save(&User{Name: "kanoongorm", Age: 30})
|
||||||
|
|
||||||
|
var user1, user2, user3 User
|
||||||
|
|
||||||
|
err := DB.Model(&User{}).Where("name LIKE ?UserName AND age <= ?MaxAge", map[string]interface{}{
|
||||||
|
"MaxAge": 30,
|
||||||
|
"UserName": "%kanoon%"}).First(&user1).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(&User{}).Where("name LIKE ?firstName AND name LIKE ?lastName",
|
||||||
|
map[string]interface{}{"firstName": "kanoon%"},
|
||||||
|
map[string]interface{}{"lastName": "%santikul"}).First(&user2).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if user2.Age != 35 {
|
||||||
|
t.Error("Should merge Map parameters as one argument")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(&User{}).Where("age in (?ages)",
|
||||||
|
map[string]interface{}{"ages": []int64{21, 22, 23, 24, 25}}).First(&user3).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
71
scope.go
71
scope.go
@ -464,6 +464,7 @@ var (
|
|||||||
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
|
isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number
|
||||||
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
|
comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ")
|
||||||
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
|
countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$")
|
||||||
|
namedArgRegexp = regexp.MustCompile("\\?([_a-zA-Z\\d]+)")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (scope *Scope) quoteIfPossible(str string) string {
|
func (scope *Scope) quoteIfPossible(str string) string {
|
||||||
@ -600,17 +601,12 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
replacements := []string{}
|
getReplacementsForSliceArgs := func(arg interface{}) (replacement string, err error) {
|
||||||
args := clause["args"].([]interface{})
|
|
||||||
for _, arg := range args {
|
|
||||||
var err error
|
|
||||||
switch reflect.ValueOf(arg).Kind() {
|
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, err = scanner.Value()
|
arg, err = scanner.Value()
|
||||||
replacements = append(replacements, scope.AddToVars(arg))
|
replacement = scope.AddToVars(arg)
|
||||||
} else if b, ok := arg.([]byte); ok {
|
} else if b, ok := arg.([]byte); ok {
|
||||||
replacements = append(replacements, scope.AddToVars(b))
|
replacement = scope.AddToVars(b)
|
||||||
} else if as, ok := arg.([][]interface{}); ok {
|
} else if as, ok := arg.([][]interface{}); ok {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for _, a := range as {
|
for _, a := range as {
|
||||||
@ -625,23 +621,51 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(tempMarks) > 0 {
|
if len(tempMarks) > 0 {
|
||||||
replacements = append(replacements, strings.Join(tempMarks, ","))
|
replacement = strings.Join(tempMarks, ",")
|
||||||
}
|
}
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
replacements = append(replacements, strings.Join(tempMarks, ","))
|
replacement = strings.Join(tempMarks, ",")
|
||||||
} else {
|
} else {
|
||||||
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
|
replacement = scope.AddToVars(Expr("NULL"))
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
getReplacementForArg := func(arg interface{}) (replacement string, err error) {
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, err = valuer.Value()
|
arg, err = valuer.Value()
|
||||||
}
|
}
|
||||||
|
|
||||||
replacements = append(replacements, scope.AddToVars(arg))
|
replacement = scope.AddToVars(arg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
replacements := []string{}
|
||||||
|
|
||||||
|
args := clause["args"].([]interface{})
|
||||||
|
namedArgs := map[string]interface{}{}
|
||||||
|
|
||||||
|
for _, arg := range args {
|
||||||
|
var err error
|
||||||
|
var replacement string
|
||||||
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
if tempArgs, ok := arg.(map[string]interface{}); ok {
|
||||||
|
for argName, value := range tempArgs {
|
||||||
|
namedArgs["?"+argName] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
|
replacement, err = getReplacementsForSliceArgs(arg)
|
||||||
|
replacements = append(replacements, replacement)
|
||||||
|
default:
|
||||||
|
replacement, err = getReplacementForArg(arg)
|
||||||
|
replacements = append(replacements, replacement)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -649,6 +673,7 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(replacements) > 0 {
|
||||||
buff := bytes.NewBuffer([]byte{})
|
buff := bytes.NewBuffer([]byte{})
|
||||||
i := 0
|
i := 0
|
||||||
for _, s := range str {
|
for _, s := range str {
|
||||||
@ -661,6 +686,26 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
str = buff.String()
|
str = buff.String()
|
||||||
|
} else if len(namedArgs) > 0 {
|
||||||
|
str = namedArgRegexp.ReplaceAllStringFunc(str, func(match string) string {
|
||||||
|
arg := namedArgs[match]
|
||||||
|
var err error
|
||||||
|
var replacement string
|
||||||
|
|
||||||
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
replacement, err = getReplacementsForSliceArgs(arg)
|
||||||
|
default:
|
||||||
|
replacement, err = getReplacementForArg(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
scope.Err(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return replacement
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user