diff --git a/callback_create.go b/callback_create.go index 763a2dfd..87aba8ee 100644 --- a/callback_create.go +++ b/callback_create.go @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) { // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() + now := scope.db.nowFunc() if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField.IsBlank { @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( columns, placeholders []string diff --git a/callback_delete.go b/callback_delete.go index 73d90880..50242e48 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) { "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/callback_query.go b/callback_query.go index 7facc42b..e3b3d534 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) var ( isSlice, isPtr bool diff --git a/callback_update.go b/callback_update.go index c52162c8..56711d37 100644 --- a/callback_update.go +++ b/callback_update.go @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } diff --git a/create_test.go b/create_test.go index 450dd8a4..c80bdcbb 100644 --- a/create_test.go +++ b/create_test.go @@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) { } } +func TestCreateWithNowFuncOverride(t *testing.T) { + user1 := User{Name: "CreateUserTimestampOverride"} + + timeA := now.MustParse("2016-01-01") + + // do DB.New() because we don't want this test to affect other tests + db1 := DB.New() + // set the override to use static timeA + db1.SetNowFuncOverride(func() time.Time { + return timeA + }) + // call .New again to check the override is carried over as well during clone + db1 = db1.New() + + db1.Save(&user1) + + if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt be using the nowFuncOverride") + } + if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt be using the nowFuncOverride") + } + + // now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set + // to make sure that setting it only affected the above instance + + user2 := User{Name: "CreateUserTimestampOverrideNoMore"} + + db2 := DB.New() + + db2.Save(&user2) + + if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("CreatedAt no longer be using the nowFuncOverride") + } + if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) { + t.Errorf("UpdatedAt no longer be using the nowFuncOverride") + } +} + type AutoIncrementUser struct { User Sequence uint `gorm:"AUTO_INCREMENT"` diff --git a/interface.go b/interface.go index 55128f7f..fe649231 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,9 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. type SQLCommon interface { @@ -12,6 +15,7 @@ type SQLCommon interface { type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/main.go b/main.go index 079a380d..e24638a6 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -30,6 +31,9 @@ type DB struct { callbacks *Callback dialect Dialect singularTable bool + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time } type logModeValue int @@ -157,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() +} + // BlockGlobalUpdate if true, generates an error on update/delete without where clause. // This is to prevent eventual error with empty objects updates/deletions func (s *DB) BlockGlobalUpdate(enable bool) *DB { @@ -446,7 +466,7 @@ func (s *DB) Save(value interface{}) *DB { if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) + return s.New().Table(scope.TableName()).FirstOrCreate(value) } return newDB } @@ -503,11 +523,16 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTX begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) c.db = interface{}(tx).(SQLCommon) c.dialect.SetDB(c.db) @@ -533,7 +558,26 @@ func (s *DB) Commit() *DB { func (s *DB) Rollback() *DB { var emptySQLTx *sql.Tx if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Rollback()) + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + +// RollbackUnlessCommitted rollback a transaction if it has not yet been +// committed. +func (s *DB) RollbackUnlessCommitted() *DB { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } @@ -775,6 +819,7 @@ func (s *DB) clone() *DB { Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, } s.values.Range(func(k, v interface{}) bool { diff --git a/main_test.go b/main_test.go index 14bf34ac..35474cf3 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package gorm_test import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -43,13 +44,13 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mysql": fmt.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + dbDSN = "gorm:gorm@tcp(localhost:3306)/gorm?charset=utf8&parseTime=True" } db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" + dbDSN = "user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" } db, err = gorm.Open("postgres", dbDSN) case "mssql": @@ -60,7 +61,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm" } db, err = gorm.Open("mssql", dbDSN) default: @@ -177,6 +178,15 @@ func TestSetTable(t *testing.T) { t.Errorf("Query from specified table") } + var user User + DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser") + + user.Age = 20 + DB.Table("deleted_users").Save(&user) + if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() { + t.Errorf("Failed to found updated user") + } + DB.Save(getPreparedUser("normal_user", "reset_table")) DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) var user1, user2, user3 User @@ -419,6 +429,90 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { t.Errorf("Should be able to find committed record") } + + tx3 := DB.Begin() + u3 := User{Name: "transcation-3"} + if err := tx3.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx3.RollbackUnlessCommitted() + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + tx4 := DB.Begin() + u4 := User{Name: "transcation-4"} + if err := tx4.Save(&u4).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should find saved record") + } + + tx4.Commit() + + tx4.RollbackUnlessCommitted() + + if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil { + t.Errorf("Should be able to find committed record") + } +} + +func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err != nil { + t.Errorf("Rollback should not raise error") + } +} + +func TestTransactionReadonly(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect == "" { + dialect = "sqlite" + } + switch dialect { + case "mssql", "sqlite": + t.Skipf("%s does not support readonly transactions\n", dialect) + } + + tx := DB.Begin() + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + tx.Commit() + + tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + u = User{Name: "transcation-2"} + if err := tx.Save(&u).Error; err == nil { + t.Errorf("Error should have been raised in a readonly transaction") + } + + tx.Rollback() } func TestRow(t *testing.T) { diff --git a/model_struct.go b/model_struct.go index eb9421f9..d5de59b2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -203,7 +203,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettingsGet("DEFAULT"); ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } diff --git a/scope.go b/scope.go index ae9b477b..915ae6cf 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Begin start a transaction func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { + if tx, err := db.Begin(); scope.Err(err) == nil { scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } @@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) @@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) + defer scope.trace(scope.db.nowFunc()) result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result)