diff --git a/callback.go b/callback.go index 719b0a78..56b2064a 100644 --- a/callback.go +++ b/callback.go @@ -96,11 +96,17 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)) + cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) cp.before = "gorm:row_query" } } + if cp.logger != nil { + // note cp.logger will be nil during the default gorm callback registrations + // as they occur within init() blocks. However, any user-registered callbacks + // will happen after cp.logger exists (as the default logger or user-specified). + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) + } cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) @@ -110,7 +116,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())) + cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -123,7 +129,7 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) + cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -166,7 +172,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())) + cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) } allNames = append(allNames, cp.name) } diff --git a/callback_create.go b/callback_create.go index 87aba8ee..3527858b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -101,10 +101,11 @@ func createCallback(scope *Scope) { } lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT %v INTO %v %v%v%v", + "INSERT%v INTO %v %v%v%v", addExtraSpaceIfExist(insertModifier), quotedTableName, scope.Dialect().DefaultValueStr(), @@ -113,18 +114,19 @@ func createCallback(scope *Scope) { )) } else { scope.Raw(fmt.Sprintf( - "INSERT %v INTO %v (%v) VALUES (%v)%v%v", + "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", addExtraSpaceIfExist(insertModifier), scope.QuotedTableName(), strings.Join(columns, ","), + addExtraSpaceIfExist(lastInsertIDOutputInterstitial), strings.Join(placeholders, ","), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { + // execute create sql: no primaryField + if primaryField == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -136,16 +138,35 @@ func createCallback(scope *Scope) { } } } - } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } + return } + + // execute create sql: lastInsertID implemention for majority of dialects + if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) + } + } + } + return + } + + // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) + if primaryField.Field.CanAddr() { + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + primaryField.IsBlank = false + scope.db.RowsAffected = 1 + } + } else { + scope.Err(ErrUnaddressable) + } + return } } diff --git a/dialect.go b/dialect.go index 831c0a8e..b6f95df7 100644 --- a/dialect.go +++ b/dialect.go @@ -40,6 +40,8 @@ type Dialect interface { LimitAndOffsetSQL(limit, offset interface{}) string // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string + // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` + LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string // DefaultValueStr diff --git a/dialect_common.go b/dialect_common.go index e3a5b702..16da76dc 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string { return "" } +func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + return "" +} + func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } diff --git a/dialect_mysql.go b/dialect_mysql.go index 5a1ad708..ab6a8a91 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,6 +2,7 @@ package gorm import ( "crypto/sha1" + "database/sql" "fmt" "reflect" "regexp" @@ -102,10 +103,10 @@ func (s *mysql) DataTypeOf(field *StructField) string { precision = fmt.Sprintf("(%s)", p) } - if _, ok := field.TagSettingsGet("NOT NULL"); ok || field.IsPrimaryKey { - sqlType = fmt.Sprintf("timestamp%v", precision) + if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { + sqlType = fmt.Sprintf("DATETIME%v", precision) } else { - sqlType = fmt.Sprintf("timestamp%v NULL", precision) + sqlType = fmt.Sprintf("DATETIME%v NULL", precision) } } default: @@ -120,7 +121,7 @@ func (s *mysql) DataTypeOf(field *StructField) string { } if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) } if strings.TrimSpace(additionalType) == "" { @@ -161,6 +162,40 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { return count > 0 } +func (s mysql) HasTable(tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + var name string + // allow mysql database name with '-' character + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err == sql.ErrNoRows { + return false + } + panic(err) + } else { + return true + } +} + +func (s mysql) HasIndex(tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) HasColumn(tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return diff --git a/dialect_postgres.go b/dialect_postgres.go index 53d31388..d2df3131 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) { return } +func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { + return "" +} + func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 8c2360fc..eb79f7e7 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string { return "" } +func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + if len(columns) == 0 { + // No OUTPUT to query + return "" + } + return fmt.Sprintf("OUTPUT Inserted.%v", columnName) +} + func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } diff --git a/main.go b/main.go index 6ef031b6..367b22d8 100644 --- a/main.go +++ b/main.go @@ -212,8 +212,8 @@ func (s *DB) NewScope(value interface{}) *Scope { return scope } -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { +// QueryExpr returns the query as SqlExpr object +func (s *DB) QueryExpr() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() @@ -222,7 +222,7 @@ func (s *DB) QueryExpr() *expr { } // SubQuery returns the query as sub query -func (s *DB) SubQuery() *expr { +func (s *DB) SubQuery() *SqlExpr { scope := s.NewScope(s.Value) scope.InstanceSet("skip_bindvar", true) scope.prepareQuerySQL() @@ -437,6 +437,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { } // Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } @@ -483,6 +484,7 @@ func (s *DB) Create(value interface{}) *DB { } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time func (s *DB) Delete(value interface{}, where ...interface{}) *DB { return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } diff --git a/model_struct.go b/model_struct.go index 5234b287..d9e2e90f 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } +// lock for mutating global cached model metadata +var structsLock sync.Mutex + +// global cache of model metadata var modelStructsMap sync.Map // ModelStruct model definition @@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) @@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true - // source foreign keys + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) @@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) diff --git a/model_struct_test.go b/model_struct_test.go new file mode 100644 index 00000000..2ae419a0 --- /dev/null +++ b/model_struct_test.go @@ -0,0 +1,93 @@ +package gorm_test + +import ( + "sync" + "testing" + + "github.com/jinzhu/gorm" +) + +type ModelA struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherAID"` +} + +type ModelB struct { + gorm.Model + Name string + + ModelCs []ModelC `gorm:"foreignkey:OtherBID"` +} + +type ModelC struct { + gorm.Model + Name string + + OtherAID uint64 + OtherA *ModelA `gorm:"foreignkey:OtherAID"` + OtherBID uint64 + OtherB *ModelB `gorm:"foreignkey:OtherBID"` +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceSameModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + DB.NewScope(&ModelA{}).GetStructFields() + + done.Done() + }() + + start.Done() + } + + done.Wait() +} + +// This test will try to cause a race condition on the model's foreignkey metadata +func TestModelStructRaceDifferentModel(t *testing.T) { + // use a WaitGroup to execute as much in-sync as possible + // it's more likely to hit a race condition than without + n := 32 + start := sync.WaitGroup{} + start.Add(n) + + // use another WaitGroup to know when the test is done + done := sync.WaitGroup{} + done.Add(n) + + for i := 0; i < n; i++ { + i := i + go func() { + start.Wait() + + // call GetStructFields, this had a race condition before we fixed it + if i%2 == 0 { + DB.NewScope(&ModelA{}).GetStructFields() + } else { + DB.NewScope(&ModelB{}).GetStructFields() + } + + done.Done() + }() + + start.Done() + } + + done.Wait() +} diff --git a/scope.go b/scope.go index c962c165..eb7525b8 100644 --- a/scope.go +++ b/scope.go @@ -225,7 +225,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { updateAttrs[field.DBName] = value return field.Set(value) } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { mostMatchedField = field } } @@ -257,7 +257,7 @@ func (scope *Scope) CallMethod(methodName string) { func (scope *Scope) AddToVars(value interface{}) string { _, skipBindVar := scope.InstanceGet("skip_bindvar") - if expr, ok := value.(*expr); ok { + if expr, ok := value.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { if skipBindVar { @@ -785,7 +785,7 @@ func (scope *Scope) orderSQL() string { for _, order := range scope.Search.orders { if str, ok := order.(string); ok { orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { + } else if expr, ok := order.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) @@ -912,7 +912,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin for key, value := range convertInterfaceToMap(value, true, scope.db) { if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { + if _, ok := value.(*SqlExpr); ok { hasUpdate = true results[field.DBName] = value } else { diff --git a/search.go b/search.go index 90138595..7c4cc184 100644 --- a/search.go +++ b/search.go @@ -98,7 +98,7 @@ func (s *search) Group(query string) *search { } func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*expr); ok { + if val, ok := query.(*SqlExpr); ok { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) } else { s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) diff --git a/utils.go b/utils.go index e58e57a5..d2ae9465 100644 --- a/utils.go +++ b/utils.go @@ -58,15 +58,15 @@ func newSafeMap() *safeMap { } // SQL expression -type expr struct { +type SqlExpr struct { expr string args []interface{} } // Expr generate raw SQL expression, for example: // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} +func Expr(expression string, args ...interface{}) *SqlExpr { + return &SqlExpr{expr: expression, args: args} } func indirect(reflectValue reflect.Value) reflect.Value {