misc: postgres compose sql var
This commit is contained in:
		
							parent
							
								
									0a51f6cdc5
								
							
						
					
					
						commit
						9fe3aeb2a8
					
				@ -12,7 +12,7 @@ import (
 | 
			
		||||
func TestCreate(t *testing.T) {
 | 
			
		||||
	float := 35.03554004971999
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
 | 
			
		||||
	user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{0, 'f', 'a', 'k', '4'}, Latitude: float}
 | 
			
		||||
 | 
			
		||||
	if !DB.NewRecord(user) || !DB.NewRecord(&user) {
 | 
			
		||||
		t.Error("User should be new record before create")
 | 
			
		||||
@ -29,7 +29,7 @@ func TestCreate(t *testing.T) {
 | 
			
		||||
	var newUser User
 | 
			
		||||
	DB.First(&newUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
 | 
			
		||||
	if !reflect.DeepEqual(newUser.PasswordHash, []byte{0, 'f', 'a', 'k', '4'}) {
 | 
			
		||||
		t.Errorf("User's PasswordHash should be saved ([]byte)")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,9 @@ type Dialect interface {
 | 
			
		||||
	// SetDB set db for dialect
 | 
			
		||||
	SetDB(db SQLCommon)
 | 
			
		||||
 | 
			
		||||
	// StringifyVar return escape string of var
 | 
			
		||||
	StringifyVar(value interface{}) (string, bool)
 | 
			
		||||
 | 
			
		||||
	// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | 
			
		||||
	BindVar(i int) string
 | 
			
		||||
	// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
 | 
			
		||||
 | 
			
		||||
@ -30,6 +30,10 @@ func (s *commonDialect) SetDB(db SQLCommon) {
 | 
			
		||||
	s.db = db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) StringifyVar(value interface{}) (string, bool) {
 | 
			
		||||
	return "", false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) BindVar(i int) string {
 | 
			
		||||
	return "$$$" // ?
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,14 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	dPostgres "github.com/jinzhu/gorm/dialects/postgres"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type postgres struct {
 | 
			
		||||
@ -20,6 +24,88 @@ func (postgres) GetName() string {
 | 
			
		||||
	return "postgres"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func literal(s string) string {
 | 
			
		||||
	p := ""
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(s, `\`) {
 | 
			
		||||
		p = "E"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s = strings.Replace(s, `'`, `''`, -1)
 | 
			
		||||
	s = strings.Replace(s, `\`, `\\`, -1)
 | 
			
		||||
	return p + `'` + s + `'`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isNil(value interface{}) (ret bool) {
 | 
			
		||||
	ret = false
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if e := recover(); e != nil {
 | 
			
		||||
			// DO NOTHING
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if value == nil {
 | 
			
		||||
		ret = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reflect.ValueOf(value).IsNil() {
 | 
			
		||||
		ret = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p postgres) StringifyVar(value interface{}) (ret string, ok bool) {
 | 
			
		||||
	ok = true
 | 
			
		||||
	if isNil(value) {
 | 
			
		||||
		ret = "NULL"
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reflect.TypeOf(value).Kind() == reflect.Ptr {
 | 
			
		||||
		ret, ok = p.StringifyVar(reflect.ValueOf(value).Elem().Interface())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		s := value.(string)
 | 
			
		||||
		ret = literal(s)
 | 
			
		||||
		return
 | 
			
		||||
	case time.Time:
 | 
			
		||||
		s := value.(time.Time)
 | 
			
		||||
		ret = literal(s.Format(time.RFC3339Nano))
 | 
			
		||||
		return
 | 
			
		||||
	case dPostgres.Hstore:
 | 
			
		||||
		s := value.(dPostgres.Hstore)
 | 
			
		||||
		if v, err := s.Value(); err == nil {
 | 
			
		||||
			ret = literal(string(v.([]byte)))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case []byte:
 | 
			
		||||
		s := value.([]byte)
 | 
			
		||||
		ret = "decode(" + literal(hex.EncodeToString(s)) + ", 'hex')"
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s, ok2 := value.(driver.Valuer); ok2 {
 | 
			
		||||
		if v, err := s.Value(); err == nil {
 | 
			
		||||
			ret, ok = p.StringifyVar(v)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s, ok2 := value.(fmt.Stringer); ok2 {
 | 
			
		||||
		ret = literal(s.String())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ret = fmt.Sprintf("%v", value)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (postgres) BindVar(i int) string {
 | 
			
		||||
	return fmt.Sprintf("$%v", i)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -49,6 +49,10 @@ func (s *mssql) SetDB(db gorm.SQLCommon) {
 | 
			
		||||
	s.db = db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mssql) StringifyVar(value interface{}) (string, bool) {
 | 
			
		||||
	return "", false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mssql) BindVar(i int) string {
 | 
			
		||||
	return "$$$" // ?
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								main_test.go
									
									
									
									
									
								
							@ -727,6 +727,7 @@ func TestHstore(t *testing.T) {
 | 
			
		||||
		"bankAccountId": &bankAccountId,
 | 
			
		||||
		"phoneNumber":   &phoneNumber,
 | 
			
		||||
		"opinion":       &opinion,
 | 
			
		||||
		"nil":           nil,
 | 
			
		||||
	}
 | 
			
		||||
	d := Details{Bulk: bulk}
 | 
			
		||||
	DB.Save(&d)
 | 
			
		||||
@ -738,8 +739,15 @@ func TestHstore(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	for k := range bulk {
 | 
			
		||||
		if r, ok := d2.Bulk[k]; ok {
 | 
			
		||||
			if res, _ := bulk[k]; *res != *r {
 | 
			
		||||
				t.Errorf("Details should be equal")
 | 
			
		||||
			res, _ := bulk[k]
 | 
			
		||||
			if res == nil || r == nil {
 | 
			
		||||
				if res != r {
 | 
			
		||||
					t.Errorf("Details should be equal")
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				if *res != *r {
 | 
			
		||||
					t.Errorf("Details should be equal")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Errorf("Details should be existed")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										7
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								scope.go
									
									
									
									
									
								
							@ -267,12 +267,17 @@ func (scope *Scope) AddToVars(value interface{}) string {
 | 
			
		||||
		return exp
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dialect := scope.Dialect()
 | 
			
		||||
	if str, ok := dialect.StringifyVar(value); ok {
 | 
			
		||||
		return str
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope.SQLVars = append(scope.SQLVars, value)
 | 
			
		||||
 | 
			
		||||
	if skipBindVar {
 | 
			
		||||
		return "?"
 | 
			
		||||
	}
 | 
			
		||||
	return scope.Dialect().BindVar(len(scope.SQLVars))
 | 
			
		||||
	return dialect.BindVar(len(scope.SQLVars))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SelectAttrs return selected attributes
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user