Merge branch 'master' into master
This commit is contained in:
commit
c429ef5f11
28
callback.go
28
callback.go
@ -96,11 +96,17 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||
if cp.kind == "row_query" {
|
||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||
cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
|
||||
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||
cp.before = "gorm:row_query"
|
||||
}
|
||||
}
|
||||
|
||||
if cp.logger != nil {
|
||||
// note cp.logger will be nil during the default gorm callback registrations
|
||||
// as they occur within init() blocks. However, any user-registered callbacks
|
||||
// will happen after cp.logger exists (as the default logger or user-specified).
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
}
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
@ -110,7 +116,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
||||
// Remove a registered callback
|
||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.remove = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
@ -119,11 +125,11 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
|
||||
// Replace a registered callback with new callback
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||
// scope.SetColumn("Created", now)
|
||||
// scope.SetColumn("Updated", now)
|
||||
// scope.SetColumn("CreatedAt", now)
|
||||
// scope.SetColumn("UpdatedAt", now)
|
||||
// })
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.replace = true
|
||||
@ -135,11 +141,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
||||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||
return *p.processor
|
||||
if p.name == callbackName && p.kind == cp.kind {
|
||||
if p.remove {
|
||||
callback = nil
|
||||
} else {
|
||||
callback = *p.processor
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
@ -162,7 +172,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
for _, cp := range cps {
|
||||
// show warning message the callback name already exists
|
||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||
cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
|
||||
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
|
||||
}
|
||||
allNames = append(allNames, cp.name)
|
||||
}
|
||||
|
@ -101,10 +101,11 @@ func createCallback(scope *Scope) {
|
||||
}
|
||||
|
||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT %v INTO %v %v%v%v",
|
||||
"INSERT%v INTO %v %v%v%v",
|
||||
addExtraSpaceIfExist(insertModifier),
|
||||
quotedTableName,
|
||||
scope.Dialect().DefaultValueStr(),
|
||||
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
|
||||
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
|
||||
addExtraSpaceIfExist(insertModifier),
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||
// execute create sql: no primaryField
|
||||
if primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
@ -136,7 +138,26 @@ func createCallback(scope *Scope) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
// execute create sql: lastInsertID implemention for majority of dialects
|
||||
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||
if primaryField.Field.CanAddr() {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
primaryField.IsBlank = false
|
||||
@ -145,7 +166,7 @@ func createCallback(scope *Scope) {
|
||||
} else {
|
||||
scope.Err(ErrUnaddressable)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ func init() {
|
||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||
func beforeDeleteCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while deleting"))
|
||||
scope.Err(errors.New("missing WHERE clause while deleting"))
|
||||
return
|
||||
}
|
||||
if !scope.HasError() {
|
||||
|
@ -34,7 +34,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
func beforeUpdateCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while updating"))
|
||||
scope.Err(errors.New("missing WHERE clause while updating"))
|
||||
return
|
||||
}
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
|
@ -2,11 +2,10 @@ package gorm_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func (s *Product) BeforeCreate() (err error) {
|
||||
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
|
||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCallback(t *testing.T) {
|
||||
scope := DB.NewScope(nil)
|
||||
|
||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||
t.Errorf("`gorm:test_callback` should be nil")
|
||||
}
|
||||
|
||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
|
||||
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
|
||||
DB.Callback().Create().Remove("gorm:test_callback")
|
||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||
t.Errorf("`gorm:test_callback` should be nil")
|
||||
}
|
||||
|
||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,8 @@ type Dialect interface {
|
||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||
// DefaultValueStr
|
||||
|
@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package gorm
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
@ -102,10 +103,10 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
||||
precision = fmt.Sprintf("(%s)", p)
|
||||
}
|
||||
|
||||
if _, ok := field.TagSettingsGet("NOT NULL"); ok || field.IsPrimaryKey {
|
||||
sqlType = fmt.Sprintf("timestamp%v", precision)
|
||||
if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey {
|
||||
sqlType = fmt.Sprintf("DATETIME%v", precision)
|
||||
} else {
|
||||
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
|
||||
sqlType = fmt.Sprintf("DATETIME%v NULL", precision)
|
||||
}
|
||||
}
|
||||
default:
|
||||
@ -120,7 +121,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
@ -161,6 +162,40 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mysql) HasTable(tableName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
var name string
|
||||
// allow mysql database name with '-' character
|
||||
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false
|
||||
}
|
||||
panic(err)
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) HasIndex(tableName string, indexName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
return rows.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) HasColumn(tableName string, columnName string) bool {
|
||||
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
|
||||
if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
return rows.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (s mysql) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
|
@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||
if len(columns) == 0 {
|
||||
// No OUTPUT to query
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
||||
}
|
||||
|
||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/lib/pq/hstore"
|
||||
)
|
||||
|
2
go.mod
2
go.mod
@ -9,5 +9,5 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.0.1
|
||||
github.com/lib/pq v1.1.1
|
||||
github.com/mattn/go-sqlite3 v1.10.0
|
||||
github.com/mattn/go-sqlite3 v1.11.0
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -52,8 +52,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
|
||||
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
|
||||
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
|
||||
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
|
@ -49,7 +49,11 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
||||
if indirectValue.IsValid() {
|
||||
value = indirectValue.Interface()
|
||||
if t, ok := value.(time.Time); ok {
|
||||
if t.IsZero() {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00"))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
} else if b, ok := value.([]byte); ok {
|
||||
if str := string(b); isPrintable(str) {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||
|
8
main.go
8
main.go
@ -210,8 +210,8 @@ func (s *DB) NewScope(value interface{}) *Scope {
|
||||
return scope
|
||||
}
|
||||
|
||||
// QueryExpr returns the query as expr object
|
||||
func (s *DB) QueryExpr() *expr {
|
||||
// QueryExpr returns the query as SqlExpr object
|
||||
func (s *DB) QueryExpr() *SqlExpr {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.InstanceSet("skip_bindvar", true)
|
||||
scope.prepareQuerySQL()
|
||||
@ -220,7 +220,7 @@ func (s *DB) QueryExpr() *expr {
|
||||
}
|
||||
|
||||
// SubQuery returns the query as sub query
|
||||
func (s *DB) SubQuery() *expr {
|
||||
func (s *DB) SubQuery() *SqlExpr {
|
||||
scope := s.NewScope(s.Value)
|
||||
scope.InstanceSet("skip_bindvar", true)
|
||||
scope.prepareQuerySQL()
|
||||
@ -435,6 +435,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
|
||||
// WARNING when update with struct, GORM will not update fields that with zero value
|
||||
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||
return s.Updates(toSearchableMap(attrs...), true)
|
||||
}
|
||||
@ -481,6 +482,7 @@ func (s *DB) Create(value interface{}) *DB {
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
|
||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||
}
|
||||
|
@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
||||
// lock for mutating global cached model metadata
|
||||
var structsLock sync.Mutex
|
||||
|
||||
// global cache of model metadata
|
||||
var modelStructsMap sync.Map
|
||||
|
||||
// ModelStruct model definition
|
||||
@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
|
||||
// source foreign keys
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||
|
||||
@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
// source foreign keys
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
||||
|
||||
@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
for idx, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
||||
// mark field as foreignkey, use global lock to avoid race
|
||||
structsLock.Lock()
|
||||
foreignField.IsForeignKey = true
|
||||
structsLock.Unlock()
|
||||
|
||||
// association foreign keys
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||
|
93
model_struct_test.go
Normal file
93
model_struct_test.go
Normal file
@ -0,0 +1,93 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type ModelA struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
|
||||
ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
|
||||
}
|
||||
|
||||
type ModelB struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
|
||||
ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
|
||||
}
|
||||
|
||||
type ModelC struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
|
||||
OtherAID uint64
|
||||
OtherA *ModelA `gorm:"foreignkey:OtherAID"`
|
||||
OtherBID uint64
|
||||
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
|
||||
}
|
||||
|
||||
// This test will try to cause a race condition on the model's foreignkey metadata
|
||||
func TestModelStructRaceSameModel(t *testing.T) {
|
||||
// use a WaitGroup to execute as much in-sync as possible
|
||||
// it's more likely to hit a race condition than without
|
||||
n := 32
|
||||
start := sync.WaitGroup{}
|
||||
start.Add(n)
|
||||
|
||||
// use another WaitGroup to know when the test is done
|
||||
done := sync.WaitGroup{}
|
||||
done.Add(n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
start.Wait()
|
||||
|
||||
// call GetStructFields, this had a race condition before we fixed it
|
||||
DB.NewScope(&ModelA{}).GetStructFields()
|
||||
|
||||
done.Done()
|
||||
}()
|
||||
|
||||
start.Done()
|
||||
}
|
||||
|
||||
done.Wait()
|
||||
}
|
||||
|
||||
// This test will try to cause a race condition on the model's foreignkey metadata
|
||||
func TestModelStructRaceDifferentModel(t *testing.T) {
|
||||
// use a WaitGroup to execute as much in-sync as possible
|
||||
// it's more likely to hit a race condition than without
|
||||
n := 32
|
||||
start := sync.WaitGroup{}
|
||||
start.Add(n)
|
||||
|
||||
// use another WaitGroup to know when the test is done
|
||||
done := sync.WaitGroup{}
|
||||
done.Add(n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
start.Wait()
|
||||
|
||||
// call GetStructFields, this had a race condition before we fixed it
|
||||
if i%2 == 0 {
|
||||
DB.NewScope(&ModelA{}).GetStructFields()
|
||||
} else {
|
||||
DB.NewScope(&ModelB{}).GetStructFields()
|
||||
}
|
||||
|
||||
done.Done()
|
||||
}()
|
||||
|
||||
start.Done()
|
||||
}
|
||||
|
||||
done.Wait()
|
||||
}
|
@ -133,23 +133,6 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
|
||||
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
|
||||
}
|
||||
}
|
||||
func TestStringAgainstIncompleteParentheses(t *testing.T) {
|
||||
type AddressByZipCode struct {
|
||||
ZipCode string `gorm:"primary_key"`
|
||||
Address string
|
||||
}
|
||||
|
||||
DB.AutoMigrate(&AddressByZipCode{})
|
||||
DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"})
|
||||
|
||||
var address AddressByZipCode
|
||||
var addresses []AddressByZipCode
|
||||
_ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors()
|
||||
if len(addresses) > 0 {
|
||||
t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFindAsSliceOfPointers(t *testing.T) {
|
||||
DB.Save(&User{Name: "user"})
|
||||
|
29
scope.go
29
scope.go
@ -225,7 +225,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
updateAttrs[field.DBName] = value
|
||||
return field.Set(value)
|
||||
}
|
||||
if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) {
|
||||
if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) {
|
||||
mostMatchedField = field
|
||||
}
|
||||
}
|
||||
@ -257,7 +257,7 @@ func (scope *Scope) CallMethod(methodName string) {
|
||||
func (scope *Scope) AddToVars(value interface{}) string {
|
||||
_, skipBindVar := scope.InstanceGet("skip_bindvar")
|
||||
|
||||
if expr, ok := value.(*expr); ok {
|
||||
if expr, ok := value.(*SqlExpr); ok {
|
||||
exp := expr.expr
|
||||
for _, arg := range expr.args {
|
||||
if skipBindVar {
|
||||
@ -277,23 +277,6 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||
}
|
||||
|
||||
// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection
|
||||
func (scope *Scope) IsCompleteParentheses(value string) bool {
|
||||
count := 0
|
||||
for i, _ := range value {
|
||||
if value[i] == 40 { // (
|
||||
count++
|
||||
} else if value[i] == 41 { // )
|
||||
count--
|
||||
}
|
||||
if count < 0 {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
return count == 0
|
||||
}
|
||||
|
||||
// SelectAttrs return selected attributes
|
||||
func (scope *Scope) SelectAttrs() []string {
|
||||
if scope.selectAttrs == nil {
|
||||
@ -573,10 +556,6 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
||||
}
|
||||
|
||||
if value != "" {
|
||||
if !scope.IsCompleteParentheses(value) {
|
||||
scope.Err(fmt.Errorf("incomplete parentheses found: %v", value))
|
||||
return
|
||||
}
|
||||
if !include {
|
||||
if comparisonRegexp.MatchString(value) {
|
||||
str = fmt.Sprintf("NOT (%v)", value)
|
||||
@ -806,7 +785,7 @@ func (scope *Scope) orderSQL() string {
|
||||
for _, order := range scope.Search.orders {
|
||||
if str, ok := order.(string); ok {
|
||||
orders = append(orders, scope.quoteIfPossible(str))
|
||||
} else if expr, ok := order.(*expr); ok {
|
||||
} else if expr, ok := order.(*SqlExpr); ok {
|
||||
exp := expr.expr
|
||||
for _, arg := range expr.args {
|
||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
||||
@ -933,7 +912,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
||||
|
||||
for key, value := range convertInterfaceToMap(value, true, scope.db) {
|
||||
if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
if _, ok := value.(*SqlExpr); ok {
|
||||
hasUpdate = true
|
||||
results[field.DBName] = value
|
||||
} else {
|
||||
|
@ -98,7 +98,7 @@ func (s *search) Group(query string) *search {
|
||||
}
|
||||
|
||||
func (s *search) Having(query interface{}, values ...interface{}) *search {
|
||||
if val, ok := query.(*expr); ok {
|
||||
if val, ok := query.(*SqlExpr); ok {
|
||||
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
|
||||
} else {
|
||||
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
|
||||
|
6
utils.go
6
utils.go
@ -61,15 +61,15 @@ func newSafeMap() *safeMap {
|
||||
}
|
||||
|
||||
// SQL expression
|
||||
type expr struct {
|
||||
type SqlExpr struct {
|
||||
expr string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// Expr generate raw SQL expression, for example:
|
||||
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
|
||||
func Expr(expression string, args ...interface{}) *expr {
|
||||
return &expr{expr: expression, args: args}
|
||||
func Expr(expression string, args ...interface{}) *SqlExpr {
|
||||
return &SqlExpr{expr: expression, args: args}
|
||||
}
|
||||
|
||||
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||
|
Loading…
x
Reference in New Issue
Block a user