From 9fe3aeb2a89d98a6fe0d2a7a57341e8f4fb45d02 Mon Sep 17 00:00:00 2001 From: wmin0 Date: Thu, 21 Dec 2017 21:58:24 +0800 Subject: [PATCH 1/2] misc: postgres compose sql var --- create_test.go | 4 +- dialect.go | 3 ++ dialect_common.go | 4 ++ dialect_postgres.go | 86 +++++++++++++++++++++++++++++++++++++++++ dialects/mssql/mssql.go | 4 ++ main_test.go | 12 +++++- scope.go | 7 +++- 7 files changed, 115 insertions(+), 5 deletions(-) diff --git a/create_test.go b/create_test.go index 36472914..8eafc216 100644 --- a/create_test.go +++ b/create_test.go @@ -12,7 +12,7 @@ import ( func TestCreate(t *testing.T) { float := 35.03554004971999 now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} + user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{0, 'f', 'a', 'k', '4'}, Latitude: float} if !DB.NewRecord(user) || !DB.NewRecord(&user) { t.Error("User should be new record before create") @@ -29,7 +29,7 @@ func TestCreate(t *testing.T) { var newUser User DB.First(&newUser, user.Id) - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { + if !reflect.DeepEqual(newUser.PasswordHash, []byte{0, 'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") } diff --git a/dialect.go b/dialect.go index e879588b..f10ccee6 100644 --- a/dialect.go +++ b/dialect.go @@ -16,6 +16,9 @@ type Dialect interface { // SetDB set db for dialect SetDB(db SQLCommon) + // StringifyVar return escape string of var + StringifyVar(value interface{}) (string, bool) + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name diff --git a/dialect_common.go b/dialect_common.go index a99627f2..17d6c186 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -30,6 +30,10 @@ func (s *commonDialect) SetDB(db SQLCommon) { s.db = db } +func (commonDialect) StringifyVar(value interface{}) (string, bool) { + return "", false +} + func (commonDialect) BindVar(i int) string { return "$$$" // ? } diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..8349664e 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,10 +1,14 @@ package gorm import ( + "database/sql/driver" + "encoding/hex" "fmt" "reflect" "strings" "time" + + dPostgres "github.com/jinzhu/gorm/dialects/postgres" ) type postgres struct { @@ -20,6 +24,88 @@ func (postgres) GetName() string { return "postgres" } +func literal(s string) string { + p := "" + + if strings.Contains(s, `\`) { + p = "E" + } + + s = strings.Replace(s, `'`, `''`, -1) + s = strings.Replace(s, `\`, `\\`, -1) + return p + `'` + s + `'` +} + +func isNil(value interface{}) (ret bool) { + ret = false + defer func() { + if e := recover(); e != nil { + // DO NOTHING + } + }() + + if value == nil { + ret = true + return + } + + if reflect.ValueOf(value).IsNil() { + ret = true + return + } + + return +} + +func (p postgres) StringifyVar(value interface{}) (ret string, ok bool) { + ok = true + if isNil(value) { + ret = "NULL" + return + } + + if reflect.TypeOf(value).Kind() == reflect.Ptr { + ret, ok = p.StringifyVar(reflect.ValueOf(value).Elem().Interface()) + return + } + + switch value.(type) { + case string: + s := value.(string) + ret = literal(s) + return + case time.Time: + s := value.(time.Time) + ret = literal(s.Format(time.RFC3339Nano)) + return + case dPostgres.Hstore: + s := value.(dPostgres.Hstore) + if v, err := s.Value(); err == nil { + ret = literal(string(v.([]byte))) + return + } + case []byte: + s := value.([]byte) + ret = "decode(" + literal(hex.EncodeToString(s)) + ", 'hex')" + return + } + + if s, ok2 := value.(driver.Valuer); ok2 { + if v, err := s.Value(); err == nil { + ret, ok = p.StringifyVar(v) + return + } + } + + if s, ok2 := value.(fmt.Stringer); ok2 { + ret = literal(s.String()) + return + } + + ret = fmt.Sprintf("%v", value) + return +} + func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..74c31db1 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -49,6 +49,10 @@ func (s *mssql) SetDB(db gorm.SQLCommon) { s.db = db } +func (mssql) StringifyVar(value interface{}) (string, bool) { + return "", false +} + func (mssql) BindVar(i int) string { return "$$$" // ? } diff --git a/main_test.go b/main_test.go index 34f96a86..fee8675b 100644 --- a/main_test.go +++ b/main_test.go @@ -727,6 +727,7 @@ func TestHstore(t *testing.T) { "bankAccountId": &bankAccountId, "phoneNumber": &phoneNumber, "opinion": &opinion, + "nil": nil, } d := Details{Bulk: bulk} DB.Save(&d) @@ -738,8 +739,15 @@ func TestHstore(t *testing.T) { for k := range bulk { if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") + res, _ := bulk[k] + if res == nil || r == nil { + if res != r { + t.Errorf("Details should be equal") + } + } else { + if *res != *r { + t.Errorf("Details should be equal") + } } } else { t.Errorf("Details should be existed") diff --git a/scope.go b/scope.go index 51ebd5a0..4d453608 100644 --- a/scope.go +++ b/scope.go @@ -267,12 +267,17 @@ func (scope *Scope) AddToVars(value interface{}) string { return exp } + dialect := scope.Dialect() + if str, ok := dialect.StringifyVar(value); ok { + return str + } + scope.SQLVars = append(scope.SQLVars, value) if skipBindVar { return "?" } - return scope.Dialect().BindVar(len(scope.SQLVars)) + return dialect.BindVar(len(scope.SQLVars)) } // SelectAttrs return selected attributes From 3b41aa8f6636908f0f2e0ab75e4e0ade64bac885 Mon Sep 17 00:00:00 2001 From: wmin0 Date: Tue, 26 Dec 2017 13:26:41 +0800 Subject: [PATCH 2/2] misc: correct composing sql replace "?" logic --- logger.go | 5 ++++- main_test.go | 7 +++++++ scope.go | 39 ++++++++++++++++++++++++++------------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/logger.go b/logger.go index 4324a2e4..09977089 100644 --- a/logger.go +++ b/logger.go @@ -79,10 +79,13 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { } } else { formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { + s := sqlRegexp.Split(values[3].(string), -1) + for index, value := range s { sql += value if index < formattedValuesLength { sql += formattedValues[index] + } else if index != len(s)-1 { + sql += "?" } } } diff --git a/main_test.go b/main_test.go index fee8675b..6ccc49ac 100644 --- a/main_test.go +++ b/main_test.go @@ -393,6 +393,13 @@ func TestRow(t *testing.T) { if age != 10 { t.Errorf("Scan with Row") } + + age = 0 + row = DB.Debug().Table("users").Where("name != ? AND name != ? AND Age = ? AND name != ?", "???", "???", 10, "???").Select("age").Row() + row.Scan(&age) + if age != 10 { + t.Errorf("Scan with Row") + } } func TestRows(t *testing.T) { diff --git a/scope.go b/scope.go index 4d453608..3d208244 100644 --- a/scope.go +++ b/scope.go @@ -261,7 +261,7 @@ func (scope *Scope) AddToVars(value interface{}) string { if skipBindVar { scope.AddToVars(arg) } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg)) } } return exp @@ -280,6 +280,20 @@ func (scope *Scope) AddToVars(value interface{}) string { return dialect.BindVar(len(scope.SQLVars)) } +func (scope *Scope) ReplaceOnePlaceholder(sql, value string) string { + quoteCount := 0 + for idx, c := range sql { + if string(c) == "'" { + quoteCount += 1 + continue + } + if string(c) == "?" && quoteCount%2 == 0 { + return string([]rune(sql)[:idx]) + value + string([]rune(sql)[idx+1:]) + } + } + return sql +} + // SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { @@ -565,22 +579,21 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg)) } } return @@ -637,21 +650,21 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(bytes)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(Expr("NULL"))) } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(notEqualSQL, scope.AddToVars(arg)) } } return @@ -674,12 +687,12 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + str = scope.ReplaceOnePlaceholder(str, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + str = scope.ReplaceOnePlaceholder(str, scope.AddToVars(arg)) } } return @@ -765,7 +778,7 @@ func (scope *Scope) orderSQL() string { } else if expr, ok := order.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = scope.ReplaceOnePlaceholder(exp, scope.AddToVars(arg)) } orders = append(orders, exp) }