diff --git a/scope.go b/scope.go index a8ea3459..b0f0a90a 100644 --- a/scope.go +++ b/scope.go @@ -256,11 +256,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - if pgParameterRegexp.MatchString(exp) { - exp = pgParameterRegexp.ReplaceAllLiteralString(exp, scope.AddToVars(arg)) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } + exp = scope.replaceParameterPlaceholder(exp, arg) } return exp } @@ -466,6 +462,21 @@ func (scope *Scope) quoteIfPossible(str string) string { return str } +func (scope *Scope) replaceParameterPlaceholderLiteral(sql string, parameter interface{}, addToVars bool) string { + if scope.Dialect().GetName() == "postgres" && pgParameterRegexp.MatchString(sql) { + sql = pgParameterRegexp.ReplaceAllLiteralString(sql, "?") + } + if val, ok := parameter.(string); ok && !addToVars { + return strings.Replace(sql, "?", val, 1) + } + + return strings.Replace(sql, "?", scope.AddToVars(parameter), 1) +} + +func (scope *Scope) replaceParameterPlaceholder(sql string, parameter interface{}) string { + return scope.replaceParameterPlaceholderLiteral(sql, parameter, true) +} + func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} @@ -555,22 +566,22 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.replaceParameterPlaceholder(str, bytes) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholderLiteral(str, strings.Join(tempMarks, ","), false) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.replaceParameterPlaceholder(str, Expr("NULL")) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(str, arg) } } return @@ -627,21 +638,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.replaceParameterPlaceholder(str, bytes) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.replaceParameterPlaceholder(str, Expr("NULL")) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(notEqualSQL, arg) } } return @@ -664,12 +675,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.replaceParameterPlaceholder(str, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.replaceParameterPlaceholder(str, arg) } } return @@ -755,7 +766,7 @@ func (scope *Scope) orderSQL() string { } else if expr, ok := order.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.replaceParameterPlaceholder(exp, arg) } orders = append(orders, exp) }