diff --git a/create_test.go b/create_test.go index 36472914..8eafc216 100644 --- a/create_test.go +++ b/create_test.go @@ -12,7 +12,7 @@ import ( func TestCreate(t *testing.T) { float := 35.03554004971999 now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} + user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{0, 'f', 'a', 'k', '4'}, Latitude: float} if !DB.NewRecord(user) || !DB.NewRecord(&user) { t.Error("User should be new record before create") @@ -29,7 +29,7 @@ func TestCreate(t *testing.T) { var newUser User DB.First(&newUser, user.Id) - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { + if !reflect.DeepEqual(newUser.PasswordHash, []byte{0, 'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") } diff --git a/dialect.go b/dialect.go index e879588b..f10ccee6 100644 --- a/dialect.go +++ b/dialect.go @@ -16,6 +16,9 @@ type Dialect interface { // SetDB set db for dialect SetDB(db SQLCommon) + // StringifyVar return escape string of var + StringifyVar(value interface{}) (string, bool) + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name diff --git a/dialect_common.go b/dialect_common.go index a99627f2..17d6c186 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -30,6 +30,10 @@ func (s *commonDialect) SetDB(db SQLCommon) { s.db = db } +func (commonDialect) StringifyVar(value interface{}) (string, bool) { + return "", false +} + func (commonDialect) BindVar(i int) string { return "$$$" // ? } diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..8349664e 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,10 +1,14 @@ package gorm import ( + "database/sql/driver" + "encoding/hex" "fmt" "reflect" "strings" "time" + + dPostgres "github.com/jinzhu/gorm/dialects/postgres" ) type postgres struct { @@ -20,6 +24,88 @@ func (postgres) GetName() string { return "postgres" } +func literal(s string) string { + p := "" + + if strings.Contains(s, `\`) { + p = "E" + } + + s = strings.Replace(s, `'`, `''`, -1) + s = strings.Replace(s, `\`, `\\`, -1) + return p + `'` + s + `'` +} + +func isNil(value interface{}) (ret bool) { + ret = false + defer func() { + if e := recover(); e != nil { + // DO NOTHING + } + }() + + if value == nil { + ret = true + return + } + + if reflect.ValueOf(value).IsNil() { + ret = true + return + } + + return +} + +func (p postgres) StringifyVar(value interface{}) (ret string, ok bool) { + ok = true + if isNil(value) { + ret = "NULL" + return + } + + if reflect.TypeOf(value).Kind() == reflect.Ptr { + ret, ok = p.StringifyVar(reflect.ValueOf(value).Elem().Interface()) + return + } + + switch value.(type) { + case string: + s := value.(string) + ret = literal(s) + return + case time.Time: + s := value.(time.Time) + ret = literal(s.Format(time.RFC3339Nano)) + return + case dPostgres.Hstore: + s := value.(dPostgres.Hstore) + if v, err := s.Value(); err == nil { + ret = literal(string(v.([]byte))) + return + } + case []byte: + s := value.([]byte) + ret = "decode(" + literal(hex.EncodeToString(s)) + ", 'hex')" + return + } + + if s, ok2 := value.(driver.Valuer); ok2 { + if v, err := s.Value(); err == nil { + ret, ok = p.StringifyVar(v) + return + } + } + + if s, ok2 := value.(fmt.Stringer); ok2 { + ret = literal(s.String()) + return + } + + ret = fmt.Sprintf("%v", value) + return +} + func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..74c31db1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -49,6 +49,10 @@ func (s *mssql) SetDB(db gorm.SQLCommon) { s.db = db } +func (mssql) StringifyVar(value interface{}) (string, bool) { + return "", false +} + func (mssql) BindVar(i int) string { return "$$$" // ? } diff --git a/main_test.go b/main_test.go index 34f96a86..fee8675b 100644 --- a/main_test.go +++ b/main_test.go @@ -727,6 +727,7 @@ func TestHstore(t *testing.T) { "bankAccountId": &bankAccountId, "phoneNumber": &phoneNumber, "opinion": &opinion, + "nil": nil, } d := Details{Bulk: bulk} DB.Save(&d) @@ -738,8 +739,15 @@ func TestHstore(t *testing.T) { for k := range bulk { if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") + res, _ := bulk[k] + if res == nil || r == nil { + if res != r { + t.Errorf("Details should be equal") + } + } else { + if *res != *r { + t.Errorf("Details should be equal") + } } } else { t.Errorf("Details should be existed") diff --git a/scope.go b/scope.go index 51ebd5a0..4d453608 100644 --- a/scope.go +++ b/scope.go @@ -267,12 +267,17 @@ func (scope *Scope) AddToVars(value interface{}) string { return exp } + dialect := scope.Dialect() + if str, ok := dialect.StringifyVar(value); ok { + return str + } + scope.SQLVars = append(scope.SQLVars, value) if skipBindVar { return "?" } - return scope.Dialect().BindVar(len(scope.SQLVars)) + return dialect.BindVar(len(scope.SQLVars)) } // SelectAttrs return selected attributes