misc: correct composing sql replace "?" logic

This commit is contained in:
wmin0 2017-12-26 13:26:41 +08:00
parent 9fe3aeb2a8
commit 3b41aa8f66
3 changed files with 37 additions and 14 deletions

View File

@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
} }
} else { } else {
formattedValuesLength := len(formattedValues) formattedValuesLength := len(formattedValues)
for index, value := range sqlRegexp.Split(values[3].(string), -1) { s := sqlRegexp.Split(values[3].(string), -1)
for index, value := range s {
sql += value sql += value
if index < formattedValuesLength { if index < formattedValuesLength {
sql += formattedValues[index] sql += formattedValues[index]
} else if index != len(s)-1 {
sql += "?"
} }
} }
} }

View File

@ -393,6 +393,13 @@ func TestRow(t *testing.T) {
if age != 10 { if age != 10 {
t.Errorf("Scan with Row") t.Errorf("Scan with Row")
} }
age = 0
row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row()
row.Scan(&age)
if age != 10 {
t.Errorf("Scan with Row")
}
} }
func TestRows(t *testing.T) { func TestRows(t *testing.T) {

View File

@ -261,7 +261,7 @@ func (scope *Scope) AddToVars(value interface{}) string {
if skipBindVar { if skipBindVar {
scope.AddToVars(arg) scope.AddToVars(arg)
} else { } else {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
} }
} }
return exp return exp
@ -280,6 +280,20 @@ func (scope *Scope) AddToVars(value interface{}) string {
return dialect.BindVar(len(scope.SQLVars)) return dialect.BindVar(len(scope.SQLVars))
} }
func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string {
quoteCount := 0
for idx, c := range sql {
if string(c) == "'" {
quoteCount += 1
continue
}
if string(c) == "?" && quoteCount%2 == 0 {
return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:])
}
}
return sql
}
// SelectAttrs return selected attributes // SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string { func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil { if scope.selectAttrs == nil {
@ -565,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
switch reflect.ValueOf(arg).Kind() { switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok { if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
} else if values := reflect.ValueOf(arg); values.Len() > 0 { } else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string var tempMarks []string
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
} else { } else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
} }
default: default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
} }
} }
return return
@ -637,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
switch reflect.ValueOf(arg).Kind() { switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
if bytes, ok := arg.([]byte); ok { if bytes, ok := arg.([]byte); ok {
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
} else if values := reflect.ValueOf(arg); values.Len() > 0 { } else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string var tempMarks []string
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
} else { } else {
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
} }
default: default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value() arg, _ = scanner.Value()
} }
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg))
} }
} }
return return
@ -674,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
} }
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
default: default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
str = strings.Replace(str, "?", scope.AddToVars(arg), 1) str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
} }
} }
return return
@ -765,7 +778,7 @@ func (scope *Scope) orderSQL() string {
} else if expr, ok := order.(*expr); ok { } else if expr, ok := order.(*expr); ok {
exp := expr.expr exp := expr.expr
for _, arg := range expr.args { for _, arg := range expr.args {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
} }
orders = append(orders, exp) orders = append(orders, exp)
} }