Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-10-17 23:17:39 +08:00 committed by GitHub
commit 74805f9cd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 122 additions and 70 deletions

View File

@ -96,7 +96,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
cp.before = "gorm:row_query"
}
}
@ -110,7 +110,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
@ -119,11 +119,11 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// scope.SetColumn("CreatedAt", now)
// scope.SetColumn("UpdatedAt", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.replace = true
@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
return nil
}
return
}
// getRIndex get right index from string slice
@ -162,7 +166,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
}
allNames = append(allNames, cp.name)
}

View File

@ -101,6 +101,7 @@ func createCallback(scope *Scope) {
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(),
strings.Join(columns, ","),
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
@ -136,7 +138,26 @@ func createCallback(scope *Scope) {
}
}
}
} else {
return
}
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
return
}
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
@ -145,7 +166,7 @@ func createCallback(scope *Scope) {
} else {
scope.Err(ErrUnaddressable)
}
}
return
}
}

View File

@ -17,7 +17,7 @@ func init() {
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
scope.Err(errors.New("missing WHERE clause while deleting"))
return
}
if !scope.HasError() {

View File

@ -34,7 +34,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
scope.Err(errors.New("missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {

View File

@ -2,11 +2,10 @@ package gorm_test
import (
"errors"
"github.com/jinzhu/gorm"
"reflect"
"testing"
"github.com/jinzhu/gorm"
)
func (s *Product) BeforeCreate() (err error) {
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil)
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}

View File

@ -40,6 +40,8 @@ type Dialect interface {
LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr

View File

@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -120,7 +120,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
}
if strings.TrimSpace(additionalType) == "" {

View File

@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
return
}
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
return ""
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}

View File

@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
if len(columns) == 0 {
// No OUTPUT to query
return ""
}
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}

View File

@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
)

2
go.mod
View File

@ -9,5 +9,5 @@ require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.0.1
github.com/lib/pq v1.1.1
github.com/mattn/go-sqlite3 v1.10.0
github.com/mattn/go-sqlite3 v1.11.0
)

4
go.sum
View File

@ -52,8 +52,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

View File

@ -49,7 +49,11 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
if t.IsZero() {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
} else {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
}
} else if b, ok := value.([]byte); ok {
if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))

View File

@ -133,23 +133,6 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
}
}
func TestStringAgainstIncompleteParentheses(t *testing.T) {
type AddressByZipCode struct {
ZipCode string `gorm:"primary_key"`
Address string
}
DB.AutoMigrate(&AddressByZipCode{})
DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"})
var address AddressByZipCode
var addresses []AddressByZipCode
_ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors()
if len(addresses) > 0 {
t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode)
}
}
func TestFindAsSliceOfPointers(t *testing.T) {
DB.Save(&User{Name: "user"})

View File

@ -277,23 +277,6 @@ func (scope *Scope) AddToVars(value interface{}) string {
return scope.Dialect().BindVar(len(scope.SQLVars))
}
// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection
func (scope *Scope) IsCompleteParentheses(value string) bool {
count := 0
for i, _ := range value {
if value[i] == 40 { // (
count++
} else if value[i] == 41 { // )
count--
}
if count < 0 {
break
}
i++
}
return count == 0
}
// SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil {
@ -573,10 +556,6 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
}
if value != "" {
if !scope.IsCompleteParentheses(value) {
scope.Err(fmt.Errorf("incomplete parentheses found: %v", value))
return
}
if !include {
if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf("NOT (%v)", value)