misc: correct composing sql replace "?" logic
This commit is contained in:
parent
9fe3aeb2a8
commit
3b41aa8f66
@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formattedValuesLength := len(formattedValues)
|
formattedValuesLength := len(formattedValues)
|
||||||
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
s := sqlRegexp.Split(values[3].(string), -1)
|
||||||
|
for index, value := range s {
|
||||||
sql += value
|
sql += value
|
||||||
if index < formattedValuesLength {
|
if index < formattedValuesLength {
|
||||||
sql += formattedValues[index]
|
sql += formattedValues[index]
|
||||||
|
} else if index != len(s)-1 {
|
||||||
|
sql += "?"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -393,6 +393,13 @@ func TestRow(t *testing.T) {
|
|||||||
if age != 10 {
|
if age != 10 {
|
||||||
t.Errorf("Scan with Row")
|
t.Errorf("Scan with Row")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
age = 0
|
||||||
|
row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row()
|
||||||
|
row.Scan(&age)
|
||||||
|
if age != 10 {
|
||||||
|
t.Errorf("Scan with Row")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRows(t *testing.T) {
|
func TestRows(t *testing.T) {
|
||||||
|
39
scope.go
39
scope.go
@ -261,7 +261,7 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||||||
if skipBindVar {
|
if skipBindVar {
|
||||||
scope.AddToVars(arg)
|
scope.AddToVars(arg)
|
||||||
} else {
|
} else {
|
||||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return exp
|
return exp
|
||||||
@ -280,6 +280,20 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||||||
return dialect.BindVar(len(scope.SQLVars))
|
return dialect.BindVar(len(scope.SQLVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string {
|
||||||
|
quoteCount := 0
|
||||||
|
for idx, c := range sql {
|
||||||
|
if string(c) == "'" {
|
||||||
|
quoteCount += 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if string(c) == "?" && quoteCount%2 == 0 {
|
||||||
|
return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
// SelectAttrs return selected attributes
|
// SelectAttrs return selected attributes
|
||||||
func (scope *Scope) SelectAttrs() []string {
|
func (scope *Scope) SelectAttrs() []string {
|
||||||
if scope.selectAttrs == nil {
|
if scope.selectAttrs == nil {
|
||||||
@ -565,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||||||
switch reflect.ValueOf(arg).Kind() {
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
if bytes, ok := arg.([]byte); ok {
|
if bytes, ok := arg.([]byte); ok {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
} else {
|
} else {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -637,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||||||
switch reflect.ValueOf(arg).Kind() {
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
if bytes, ok := arg.([]byte); ok {
|
if bytes, ok := arg.([]byte); ok {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
} else {
|
} else {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = scanner.Value()
|
arg, _ = scanner.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -674,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
default:
|
default:
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -765,7 +778,7 @@ func (scope *Scope) orderSQL() string {
|
|||||||
} else if expr, ok := order.(*expr); ok {
|
} else if expr, ok := order.(*expr); ok {
|
||||||
exp := expr.expr
|
exp := expr.expr
|
||||||
for _, arg := range expr.args {
|
for _, arg := range expr.args {
|
||||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
orders = append(orders, exp)
|
orders = append(orders, exp)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user