misc: correct composing sql replace "?" logic
This commit is contained in:
		
							parent
							
								
									9fe3aeb2a8
								
							
						
					
					
						commit
						3b41aa8f66
					
				@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				formattedValuesLength := len(formattedValues)
 | 
			
		||||
				for index, value := range sqlRegexp.Split(values[3].(string), -1) {
 | 
			
		||||
				s := sqlRegexp.Split(values[3].(string), -1)
 | 
			
		||||
				for index, value := range s {
 | 
			
		||||
					sql += value
 | 
			
		||||
					if index < formattedValuesLength {
 | 
			
		||||
						sql += formattedValues[index]
 | 
			
		||||
					} else if index != len(s)-1 {
 | 
			
		||||
						sql += "?"
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -393,6 +393,13 @@ func TestRow(t *testing.T) {
 | 
			
		||||
	if age != 10 {
 | 
			
		||||
		t.Errorf("Scan with Row")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	age = 0
 | 
			
		||||
	row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row()
 | 
			
		||||
	row.Scan(&age)
 | 
			
		||||
	if age != 10 {
 | 
			
		||||
		t.Errorf("Scan with Row")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRows(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										39
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								scope.go
									
									
									
									
									
								
							@ -261,7 +261,7 @@ func (scope *Scope) AddToVars(value interface{}) string {
 | 
			
		||||
			if skipBindVar {
 | 
			
		||||
				scope.AddToVars(arg)
 | 
			
		||||
			} else {
 | 
			
		||||
				exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
				exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return exp
 | 
			
		||||
@ -280,6 +280,20 @@ func (scope *Scope) AddToVars(value interface{}) string {
 | 
			
		||||
	return dialect.BindVar(len(scope.SQLVars))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string {
 | 
			
		||||
	quoteCount := 0
 | 
			
		||||
	for idx, c := range sql {
 | 
			
		||||
		if string(c) == "'" {
 | 
			
		||||
			quoteCount += 1
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if string(c) == "?" && quoteCount%2 == 0 {
 | 
			
		||||
			return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return sql
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SelectAttrs return selected attributes
 | 
			
		||||
func (scope *Scope) SelectAttrs() []string {
 | 
			
		||||
	if scope.selectAttrs == nil {
 | 
			
		||||
@ -565,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 | 
			
		||||
		switch reflect.ValueOf(arg).Kind() {
 | 
			
		||||
		case reflect.Slice: // For where("id in (?)", []int64{1,2})
 | 
			
		||||
			if bytes, ok := arg.([]byte); ok {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
 | 
			
		||||
			} else if values := reflect.ValueOf(arg); values.Len() > 0 {
 | 
			
		||||
				var tempMarks []string
 | 
			
		||||
				for i := 0; i < values.Len(); i++ {
 | 
			
		||||
					tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 | 
			
		||||
				}
 | 
			
		||||
				str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
 | 
			
		||||
			} else {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
 | 
			
		||||
				arg, _ = valuer.Value()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
			str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
@ -637,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
 | 
			
		||||
		switch reflect.ValueOf(arg).Kind() {
 | 
			
		||||
		case reflect.Slice: // For where("id in (?)", []int64{1,2})
 | 
			
		||||
			if bytes, ok := arg.([]byte); ok {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
 | 
			
		||||
			} else if values := reflect.ValueOf(arg); values.Len() > 0 {
 | 
			
		||||
				var tempMarks []string
 | 
			
		||||
				for i := 0; i < values.Len(); i++ {
 | 
			
		||||
					tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 | 
			
		||||
				}
 | 
			
		||||
				str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
 | 
			
		||||
			} else {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
 | 
			
		||||
				str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if scanner, ok := interface{}(arg).(driver.Valuer); ok {
 | 
			
		||||
				arg, _ = scanner.Value()
 | 
			
		||||
			}
 | 
			
		||||
			str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
			str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
@ -674,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
 | 
			
		||||
			for i := 0; i < values.Len(); i++ {
 | 
			
		||||
				tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 | 
			
		||||
			}
 | 
			
		||||
			str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
 | 
			
		||||
			str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
 | 
			
		||||
		default:
 | 
			
		||||
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
 | 
			
		||||
				arg, _ = valuer.Value()
 | 
			
		||||
			}
 | 
			
		||||
			str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
			str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
@ -765,7 +778,7 @@ func (scope *Scope) orderSQL() string {
 | 
			
		||||
		} else if expr, ok := order.(*expr); ok {
 | 
			
		||||
			exp := expr.expr
 | 
			
		||||
			for _, arg := range expr.args {
 | 
			
		||||
				exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
				exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
 | 
			
		||||
			}
 | 
			
		||||
			orders = append(orders, exp)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user