Merge 3b41aa8f6636908f0f2e0ab75e4e0ade64bac885 into 87fc1b24737a885147240041293603eceb844356
This commit is contained in:
commit
001b074d5f
@ -12,7 +12,7 @@ import (
|
|||||||
func TestCreate(t *testing.T) {
|
func TestCreate(t *testing.T) {
|
||||||
float := 35.03554004971999
|
float := 35.03554004971999
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
|
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{0, 'f', 'a', 'k', '4'}, Latitude: float}
|
||||||
|
|
||||||
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
||||||
t.Error("User should be new record before create")
|
t.Error("User should be new record before create")
|
||||||
@ -29,7 +29,7 @@ func TestCreate(t *testing.T) {
|
|||||||
var newUser User
|
var newUser User
|
||||||
DB.First(&newUser, user.Id)
|
DB.First(&newUser, user.Id)
|
||||||
|
|
||||||
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
|
if !reflect.DeepEqual(newUser.PasswordHash, []byte{0, 'f', 'a', 'k', '4'}) {
|
||||||
t.Errorf("User's PasswordHash should be saved ([]byte)")
|
t.Errorf("User's PasswordHash should be saved ([]byte)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ type Dialect interface {
|
|||||||
// SetDB set db for dialect
|
// SetDB set db for dialect
|
||||||
SetDB(db SQLCommon)
|
SetDB(db SQLCommon)
|
||||||
|
|
||||||
|
// StringifyVar return escape string of var
|
||||||
|
StringifyVar(value interface{}) (string, bool)
|
||||||
|
|
||||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
BindVar(i int) string
|
BindVar(i int) string
|
||||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
|
@ -30,6 +30,10 @@ func (s *commonDialect) SetDB(db SQLCommon) {
|
|||||||
s.db = db
|
s.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (commonDialect) StringifyVar(value interface{}) (string, bool) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
func (commonDialect) BindVar(i int) string {
|
func (commonDialect) BindVar(i int) string {
|
||||||
return "$$$" // ?
|
return "$$$" // ?
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
dPostgres "github.com/jinzhu/gorm/dialects/postgres"
|
||||||
)
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
@ -20,6 +24,88 @@ func (postgres) GetName() string {
|
|||||||
return "postgres"
|
return "postgres"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func literal(s string) string {
|
||||||
|
p := ""
|
||||||
|
|
||||||
|
if strings.Contains(s, `\`) {
|
||||||
|
p = "E"
|
||||||
|
}
|
||||||
|
|
||||||
|
s = strings.Replace(s, `'`, `''`, -1)
|
||||||
|
s = strings.Replace(s, `\`, `\\`, -1)
|
||||||
|
return p + `'` + s + `'`
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNil(value interface{}) (ret bool) {
|
||||||
|
ret = false
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
// DO NOTHING
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if value == nil {
|
||||||
|
ret = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.ValueOf(value).IsNil() {
|
||||||
|
ret = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p postgres) StringifyVar(value interface{}) (ret string, ok bool) {
|
||||||
|
ok = true
|
||||||
|
if isNil(value) {
|
||||||
|
ret = "NULL"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.TypeOf(value).Kind() == reflect.Ptr {
|
||||||
|
ret, ok = p.StringifyVar(reflect.ValueOf(value).Elem().Interface())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.(type) {
|
||||||
|
case string:
|
||||||
|
s := value.(string)
|
||||||
|
ret = literal(s)
|
||||||
|
return
|
||||||
|
case time.Time:
|
||||||
|
s := value.(time.Time)
|
||||||
|
ret = literal(s.Format(time.RFC3339Nano))
|
||||||
|
return
|
||||||
|
case dPostgres.Hstore:
|
||||||
|
s := value.(dPostgres.Hstore)
|
||||||
|
if v, err := s.Value(); err == nil {
|
||||||
|
ret = literal(string(v.([]byte)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case []byte:
|
||||||
|
s := value.([]byte)
|
||||||
|
ret = "decode(" + literal(hex.EncodeToString(s)) + ", 'hex')"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, ok2 := value.(driver.Valuer); ok2 {
|
||||||
|
if v, err := s.Value(); err == nil {
|
||||||
|
ret, ok = p.StringifyVar(v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, ok2 := value.(fmt.Stringer); ok2 {
|
||||||
|
ret = literal(s.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = fmt.Sprintf("%v", value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (postgres) BindVar(i int) string {
|
func (postgres) BindVar(i int) string {
|
||||||
return fmt.Sprintf("$%v", i)
|
return fmt.Sprintf("$%v", i)
|
||||||
}
|
}
|
||||||
|
@ -49,6 +49,10 @@ func (s *mssql) SetDB(db gorm.SQLCommon) {
|
|||||||
s.db = db
|
s.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mssql) StringifyVar(value interface{}) (string, bool) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
func (mssql) BindVar(i int) string {
|
func (mssql) BindVar(i int) string {
|
||||||
return "$$$" // ?
|
return "$$$" // ?
|
||||||
}
|
}
|
||||||
|
@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
formattedValuesLength := len(formattedValues)
|
formattedValuesLength := len(formattedValues)
|
||||||
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
s := sqlRegexp.Split(values[3].(string), -1)
|
||||||
|
for index, value := range s {
|
||||||
sql += value
|
sql += value
|
||||||
if index < formattedValuesLength {
|
if index < formattedValuesLength {
|
||||||
sql += formattedValues[index]
|
sql += formattedValues[index]
|
||||||
|
} else if index != len(s)-1 {
|
||||||
|
sql += "?"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
19
main_test.go
19
main_test.go
@ -393,6 +393,13 @@ func TestRow(t *testing.T) {
|
|||||||
if age != 10 {
|
if age != 10 {
|
||||||
t.Errorf("Scan with Row")
|
t.Errorf("Scan with Row")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
age = 0
|
||||||
|
row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row()
|
||||||
|
row.Scan(&age)
|
||||||
|
if age != 10 {
|
||||||
|
t.Errorf("Scan with Row")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRows(t *testing.T) {
|
func TestRows(t *testing.T) {
|
||||||
@ -727,6 +734,7 @@ func TestHstore(t *testing.T) {
|
|||||||
"bankAccountId": &bankAccountId,
|
"bankAccountId": &bankAccountId,
|
||||||
"phoneNumber": &phoneNumber,
|
"phoneNumber": &phoneNumber,
|
||||||
"opinion": &opinion,
|
"opinion": &opinion,
|
||||||
|
"nil": nil,
|
||||||
}
|
}
|
||||||
d := Details{Bulk: bulk}
|
d := Details{Bulk: bulk}
|
||||||
DB.Save(&d)
|
DB.Save(&d)
|
||||||
@ -738,8 +746,15 @@ func TestHstore(t *testing.T) {
|
|||||||
|
|
||||||
for k := range bulk {
|
for k := range bulk {
|
||||||
if r, ok := d2.Bulk[k]; ok {
|
if r, ok := d2.Bulk[k]; ok {
|
||||||
if res, _ := bulk[k]; *res != *r {
|
res, _ := bulk[k]
|
||||||
t.Errorf("Details should be equal")
|
if res == nil || r == nil {
|
||||||
|
if res != r {
|
||||||
|
t.Errorf("Details should be equal")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if *res != *r {
|
||||||
|
t.Errorf("Details should be equal")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Errorf("Details should be existed")
|
t.Errorf("Details should be existed")
|
||||||
|
46
scope.go
46
scope.go
@ -261,18 +261,37 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||||||
if skipBindVar {
|
if skipBindVar {
|
||||||
scope.AddToVars(arg)
|
scope.AddToVars(arg)
|
||||||
} else {
|
} else {
|
||||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return exp
|
return exp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dialect := scope.Dialect()
|
||||||
|
if str, ok := dialect.StringifyVar(value); ok {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
scope.SQLVars = append(scope.SQLVars, value)
|
scope.SQLVars = append(scope.SQLVars, value)
|
||||||
|
|
||||||
if skipBindVar {
|
if skipBindVar {
|
||||||
return "?"
|
return "?"
|
||||||
}
|
}
|
||||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
return dialect.BindVar(len(scope.SQLVars))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string {
|
||||||
|
quoteCount := 0
|
||||||
|
for idx, c := range sql {
|
||||||
|
if string(c) == "'" {
|
||||||
|
quoteCount += 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if string(c) == "?" && quoteCount%2 == 0 {
|
||||||
|
return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectAttrs return selected attributes
|
// SelectAttrs return selected attributes
|
||||||
@ -560,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||||||
switch reflect.ValueOf(arg).Kind() {
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
if bytes, ok := arg.([]byte); ok {
|
if bytes, ok := arg.([]byte); ok {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
} else {
|
} else {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -632,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||||||
switch reflect.ValueOf(arg).Kind() {
|
switch reflect.ValueOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
if bytes, ok := arg.([]byte); ok {
|
if bytes, ok := arg.([]byte); ok {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes))
|
||||||
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
|
||||||
var tempMarks []string
|
var tempMarks []string
|
||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
} else {
|
} else {
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL")))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = scanner.Value()
|
arg, _ = scanner.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1)
|
str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -669,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
|||||||
for i := 0; i < values.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
|
str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ","))
|
||||||
default:
|
default:
|
||||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -760,7 +778,7 @@ func (scope *Scope) orderSQL() string {
|
|||||||
} else if expr, ok := order.(*expr); ok {
|
} else if expr, ok := order.(*expr); ok {
|
||||||
exp := expr.expr
|
exp := expr.expr
|
||||||
for _, arg := range expr.args {
|
for _, arg := range expr.args {
|
||||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg))
|
||||||
}
|
}
|
||||||
orders = append(orders, exp)
|
orders = append(orders, exp)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user