From 3b41aa8f6636908f0f2e0ab75e4e0ade64bac885 Mon Sep 17 00:00:00 2001 From: wmin0 Date: Tue, 26 Dec 2017 13:26:41 +0800 Subject: [PATCH] misc: correct composing sql replace "?" logic --- logger.go | 5 ++++- main_test.go | 7 +++++++ scope.go | 39 ++++++++++++++++++++++++++------------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/logger.go b/logger.go index 4324a2e4..09977089 100644 --- a/logger.go +++ b/logger.go @@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { } } else { formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { + s := sqlRegexp.Split(values[3].(string), -1) + for index, value := range s { sql += value if index < formattedValuesLength { sql += formattedValues[index] + } else if index != len(s)-1 { + sql += "?" } } } diff --git a/main_test.go b/main_test.go index fee8675b..6ccc49ac 100644 --- a/main_test.go +++ b/main_test.go @@ -393,6 +393,13 @@ func TestRow(t *testing.T) { if age != 10 { t.Errorf("Scan with Row") } + + age = 0 + row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row() + row.Scan(&age) + if age != 10 { + t.Errorf("Scan with Row") + } } func TestRows(t *testing.T) { diff --git a/scope.go b/scope.go index 4d453608..3d208244 100644 --- a/scope.go +++ b/scope.go @@ -261,7 +261,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if skipBindVar { scope.AddToVars(arg) } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg)) } } return exp @@ -280,6 +280,20 @@ func (scope *Scope) AddToVars(value interface{}) string { return dialect.BindVar(len(scope.SQLVars)) } +func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string { + quoteCount := 0 + for idx, c := range sql { + if string(c) == "'" { + quoteCount += 1 + continue + } + if string(c) == "?" && quoteCount%2 == 0 { + return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:]) + } + } + return sql +} + // SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { @@ -565,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg)) } } return @@ -637,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL"))) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg)) } } return @@ -674,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg)) } } return @@ -765,7 +778,7 @@ func (scope *Scope) orderSQL() string { } else if expr, ok := order.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg)) } orders = append(orders, exp) }