misc: postgres compose sql var
This commit is contained in:
parent
0a51f6cdc5
commit
9fe3aeb2a8
@ -12,7 +12,7 @@ import (
|
||||
func TestCreate(t *testing.T) {
|
||||
float := 35.03554004971999
|
||||
now := time.Now()
|
||||
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
|
||||
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{0, 'f', 'a', 'k', '4'}, Latitude: float}
|
||||
|
||||
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
||||
t.Error("User should be new record before create")
|
||||
@ -29,7 +29,7 @@ func TestCreate(t *testing.T) {
|
||||
var newUser User
|
||||
DB.First(&newUser, user.Id)
|
||||
|
||||
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
|
||||
if !reflect.DeepEqual(newUser.PasswordHash, []byte{0, 'f', 'a', 'k', '4'}) {
|
||||
t.Errorf("User's PasswordHash should be saved ([]byte)")
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,9 @@ type Dialect interface {
|
||||
// SetDB set db for dialect
|
||||
SetDB(db SQLCommon)
|
||||
|
||||
// StringifyVar return escape string of var
|
||||
StringifyVar(value interface{}) (string, bool)
|
||||
|
||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||
BindVar(i int) string
|
||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
|
@ -30,6 +30,10 @@ func (s *commonDialect) SetDB(db SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (commonDialect) StringifyVar(value interface{}) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (commonDialect) BindVar(i int) string {
|
||||
return "$$$" // ?
|
||||
}
|
||||
|
@ -1,10 +1,14 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dPostgres "github.com/jinzhu/gorm/dialects/postgres"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
@ -20,6 +24,88 @@ func (postgres) GetName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func literal(s string) string {
|
||||
p := ""
|
||||
|
||||
if strings.Contains(s, `\`) {
|
||||
p = "E"
|
||||
}
|
||||
|
||||
s = strings.Replace(s, `'`, `''`, -1)
|
||||
s = strings.Replace(s, `\`, `\\`, -1)
|
||||
return p + `'` + s + `'`
|
||||
}
|
||||
|
||||
func isNil(value interface{}) (ret bool) {
|
||||
ret = false
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
// DO NOTHING
|
||||
}
|
||||
}()
|
||||
|
||||
if value == nil {
|
||||
ret = true
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.ValueOf(value).IsNil() {
|
||||
ret = true
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p postgres) StringifyVar(value interface{}) (ret string, ok bool) {
|
||||
ok = true
|
||||
if isNil(value) {
|
||||
ret = "NULL"
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.TypeOf(value).Kind() == reflect.Ptr {
|
||||
ret, ok = p.StringifyVar(reflect.ValueOf(value).Elem().Interface())
|
||||
return
|
||||
}
|
||||
|
||||
switch value.(type) {
|
||||
case string:
|
||||
s := value.(string)
|
||||
ret = literal(s)
|
||||
return
|
||||
case time.Time:
|
||||
s := value.(time.Time)
|
||||
ret = literal(s.Format(time.RFC3339Nano))
|
||||
return
|
||||
case dPostgres.Hstore:
|
||||
s := value.(dPostgres.Hstore)
|
||||
if v, err := s.Value(); err == nil {
|
||||
ret = literal(string(v.([]byte)))
|
||||
return
|
||||
}
|
||||
case []byte:
|
||||
s := value.([]byte)
|
||||
ret = "decode(" + literal(hex.EncodeToString(s)) + ", 'hex')"
|
||||
return
|
||||
}
|
||||
|
||||
if s, ok2 := value.(driver.Valuer); ok2 {
|
||||
if v, err := s.Value(); err == nil {
|
||||
ret, ok = p.StringifyVar(v)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if s, ok2 := value.(fmt.Stringer); ok2 {
|
||||
ret = literal(s.String())
|
||||
return
|
||||
}
|
||||
|
||||
ret = fmt.Sprintf("%v", value)
|
||||
return
|
||||
}
|
||||
|
||||
func (postgres) BindVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
@ -49,6 +49,10 @@ func (s *mssql) SetDB(db gorm.SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (mssql) StringifyVar(value interface{}) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (mssql) BindVar(i int) string {
|
||||
return "$$$" // ?
|
||||
}
|
||||
|
12
main_test.go
12
main_test.go
@ -727,6 +727,7 @@ func TestHstore(t *testing.T) {
|
||||
"bankAccountId": &bankAccountId,
|
||||
"phoneNumber": &phoneNumber,
|
||||
"opinion": &opinion,
|
||||
"nil": nil,
|
||||
}
|
||||
d := Details{Bulk: bulk}
|
||||
DB.Save(&d)
|
||||
@ -738,8 +739,15 @@ func TestHstore(t *testing.T) {
|
||||
|
||||
for k := range bulk {
|
||||
if r, ok := d2.Bulk[k]; ok {
|
||||
if res, _ := bulk[k]; *res != *r {
|
||||
t.Errorf("Details should be equal")
|
||||
res, _ := bulk[k]
|
||||
if res == nil || r == nil {
|
||||
if res != r {
|
||||
t.Errorf("Details should be equal")
|
||||
}
|
||||
} else {
|
||||
if *res != *r {
|
||||
t.Errorf("Details should be equal")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Details should be existed")
|
||||
|
7
scope.go
7
scope.go
@ -267,12 +267,17 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||
return exp
|
||||
}
|
||||
|
||||
dialect := scope.Dialect()
|
||||
if str, ok := dialect.StringifyVar(value); ok {
|
||||
return str
|
||||
}
|
||||
|
||||
scope.SQLVars = append(scope.SQLVars, value)
|
||||
|
||||
if skipBindVar {
|
||||
return "?"
|
||||
}
|
||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||
return dialect.BindVar(len(scope.SQLVars))
|
||||
}
|
||||
|
||||
// SelectAttrs return selected attributes
|
||||
|
Loading…
x
Reference in New Issue
Block a user