Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-06-11 17:35:16 +08:00 committed by GitHub
commit a6c2b1d17d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 201 additions and 18 deletions

View File

@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) {
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) { func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
now := NowFunc() now := scope.db.nowFunc()
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank { if createdAtField.IsBlank {
@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database // createCallback the callback used to insert data into database
func createCallback(scope *Scope) { func createCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
defer scope.trace(NowFunc()) defer scope.trace(scope.db.nowFunc())
var ( var (
columns, placeholders []string columns, placeholders []string

View File

@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) {
"UPDATE %v SET %v=%v%v%v", "UPDATE %v SET %v=%v%v%v",
scope.QuotedTableName(), scope.QuotedTableName(),
scope.Quote(deletedAtField.DBName), scope.Quote(deletedAtField.DBName),
scope.AddToVars(NowFunc()), scope.AddToVars(scope.db.nowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
)).Exec() )).Exec()

View File

@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
return return
} }
defer scope.trace(NowFunc()) defer scope.trace(scope.db.nowFunc())
var ( var (
isSlice, isPtr bool isSlice, isPtr bool

View File

@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) {
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) { func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc()) scope.SetColumn("UpdatedAt", scope.db.nowFunc())
} }
} }

View File

@ -101,6 +101,46 @@ func TestCreateWithExistingTimestamp(t *testing.T) {
} }
} }
func TestCreateWithNowFuncOverride(t *testing.T) {
user1 := User{Name: "CreateUserTimestampOverride"}
timeA := now.MustParse("2016-01-01")
// do DB.New() because we don't want this test to affect other tests
db1 := DB.New()
// set the override to use static timeA
db1.SetNowFuncOverride(func() time.Time {
return timeA
})
// call .New again to check the override is carried over as well during clone
db1 = db1.New()
db1.Save(&user1)
if user1.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt be using the nowFuncOverride")
}
if user1.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt be using the nowFuncOverride")
}
// now create another user with a fresh DB.Now() that doesn't have the nowFuncOverride set
// to make sure that setting it only affected the above instance
user2 := User{Name: "CreateUserTimestampOverrideNoMore"}
db2 := DB.New()
db2.Save(&user2)
if user2.CreatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
t.Errorf("CreatedAt no longer be using the nowFuncOverride")
}
if user2.UpdatedAt.UTC().Format(time.RFC3339) == timeA.UTC().Format(time.RFC3339) {
t.Errorf("UpdatedAt no longer be using the nowFuncOverride")
}
}
type AutoIncrementUser struct { type AutoIncrementUser struct {
User User
Sequence uint `gorm:"AUTO_INCREMENT"` Sequence uint `gorm:"AUTO_INCREMENT"`

View File

@ -1,6 +1,9 @@
package gorm package gorm
import "database/sql" import (
"context"
"database/sql"
)
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. // SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface { type SQLCommon interface {
@ -12,6 +15,7 @@ type SQLCommon interface {
type sqlDb interface { type sqlDb interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
type sqlTx interface { type sqlTx interface {

53
main.go
View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -30,6 +31,9 @@ type DB struct {
callbacks *Callback callbacks *Callback
dialect Dialect dialect Dialect
singularTable bool singularTable bool
// function to be used to override the creating of a new timestamp
nowFuncOverride func() time.Time
} }
type logModeValue int type logModeValue int
@ -157,6 +161,22 @@ func (s *DB) LogMode(enable bool) *DB {
return s return s
} }
// SetNowFuncOverride set the function to be used when creating a new timestamp
func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB {
s.nowFuncOverride = nowFuncOverride
return s
}
// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set,
// otherwise defaults to the global NowFunc()
func (s *DB) nowFunc() time.Time {
if s.nowFuncOverride != nil {
return s.nowFuncOverride()
}
return NowFunc()
}
// BlockGlobalUpdate if true, generates an error on update/delete without where clause. // BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions // This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB { func (s *DB) BlockGlobalUpdate(enable bool) *DB {
@ -446,7 +466,7 @@ func (s *DB) Save(value interface{}) *DB {
if !scope.PrimaryKeyZero() { if !scope.PrimaryKeyZero() {
newDB := scope.callCallbacks(s.parent.callbacks.updates).db newDB := scope.callCallbacks(s.parent.callbacks.updates).db
if newDB.Error == nil && newDB.RowsAffected == 0 { if newDB.Error == nil && newDB.RowsAffected == 0 {
return s.New().FirstOrCreate(value) return s.New().Table(scope.TableName()).FirstOrCreate(value)
} }
return newDB return newDB
} }
@ -503,11 +523,16 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
// Begin begin a transaction // Begin begins a transaction
func (s *DB) Begin() *DB { func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{})
}
// BeginTX begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil { if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin() tx, err := db.BeginTx(ctx, opts)
c.db = interface{}(tx).(SQLCommon) c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db) c.dialect.SetDB(c.db)
@ -533,7 +558,26 @@ func (s *DB) Commit() *DB {
func (s *DB) Rollback() *DB { func (s *DB) Rollback() *DB {
var emptySQLTx *sql.Tx var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Rollback()) if err := db.Rollback(); err != nil && err != sql.ErrTxDone {
s.AddError(err)
}
} else {
s.AddError(ErrInvalidTransaction)
}
return s
}
// RollbackUnlessCommitted rollback a transaction if it has not yet been
// committed.
func (s *DB) RollbackUnlessCommitted() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
err := db.Rollback()
// Ignore the error indicating that the transaction has already
// been committed.
if err != sql.ErrTxDone {
s.AddError(err)
}
} else { } else {
s.AddError(ErrInvalidTransaction) s.AddError(ErrInvalidTransaction)
} }
@ -775,6 +819,7 @@ func (s *DB) clone() *DB {
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db), dialect: newDialect(s.dialect.GetName(), s.db),
nowFuncOverride: s.nowFuncOverride,
} }
s.values.Range(func(k, v interface{}) bool { s.values.Range(func(k, v interface{}) bool {

View File

@ -1,6 +1,7 @@
package gorm_test package gorm_test
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
@ -43,13 +44,13 @@ func OpenTestConnection() (db *gorm.DB, err error) {
case "mysql": case "mysql":
fmt.Println("testing mysql...") fmt.Println("testing mysql...")
if dbDSN == "" { if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" dbDSN = "gorm:gorm@tcp(localhost:3306)/gorm?charset=utf8&parseTime=True"
} }
db, err = gorm.Open("mysql", dbDSN) db, err = gorm.Open("mysql", dbDSN)
case "postgres": case "postgres":
fmt.Println("testing postgres...") fmt.Println("testing postgres...")
if dbDSN == "" { if dbDSN == "" {
dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" dbDSN = "user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable"
} }
db, err = gorm.Open("postgres", dbDSN) db, err = gorm.Open("postgres", dbDSN)
case "mssql": case "mssql":
@ -60,7 +61,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
// sp_changedbowner 'gorm'; // sp_changedbowner 'gorm';
fmt.Println("testing mssql...") fmt.Println("testing mssql...")
if dbDSN == "" { if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm"
} }
db, err = gorm.Open("mssql", dbDSN) db, err = gorm.Open("mssql", dbDSN)
default: default:
@ -177,6 +178,15 @@ func TestSetTable(t *testing.T) {
t.Errorf("Query from specified table") t.Errorf("Query from specified table")
} }
var user User
DB.Table("deleted_users").First(&user, "name = ?", "DeletedUser")
user.Age = 20
DB.Table("deleted_users").Save(&user)
if DB.Table("deleted_users").First(&user, "name = ? AND age = ?", "DeletedUser", 20).RecordNotFound() {
t.Errorf("Failed to found updated user")
}
DB.Save(getPreparedUser("normal_user", "reset_table")) DB.Save(getPreparedUser("normal_user", "reset_table"))
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
var user1, user2, user3 User var user1, user2, user3 User
@ -419,6 +429,90 @@ func TestTransaction(t *testing.T) {
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record") t.Errorf("Should be able to find committed record")
} }
tx3 := DB.Begin()
u3 := User{Name: "transcation-3"}
if err := tx3.Save(&u3).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx3.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
t.Errorf("Should find saved record")
}
tx3.RollbackUnlessCommitted()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
tx4 := DB.Begin()
u4 := User{Name: "transcation-4"}
if err := tx4.Save(&u4).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx4.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
t.Errorf("Should find saved record")
}
tx4.Commit()
tx4.RollbackUnlessCommitted()
if err := DB.First(&User{}, "name = ?", "transcation-4").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
}
func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.Commit().Error; err != nil {
t.Errorf("Commit should not raise error")
}
if err := tx.Rollback().Error; err != nil {
t.Errorf("Rollback should not raise error")
}
}
func TestTransactionReadonly(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect == "" {
dialect = "sqlite"
}
switch dialect {
case "mssql", "sqlite":
t.Skipf("%s does not support readonly transactions\n", dialect)
}
tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
tx.Commit()
tx = DB.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx")
}
u = User{Name: "transcation-2"}
if err := tx.Save(&u).Error; err == nil {
t.Errorf("Error should have been raised in a readonly transaction")
}
tx.Rollback()
} }
func TestRow(t *testing.T) { func TestRow(t *testing.T) {

View File

@ -203,7 +203,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
} }
if _, ok := field.TagSettingsGet("DEFAULT"); ok { if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
field.HasDefaultValue = true field.HasDefaultValue = true
} }

View File

@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {
// Exec perform generated SQL // Exec perform generated SQL
func (scope *Scope) Exec() *Scope { func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc()) defer scope.trace(scope.db.nowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
@ -402,7 +402,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
// Begin start a transaction // Begin start a transaction
func (scope *Scope) Begin() *Scope { func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok { if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil { if tx, err := db.Begin(); scope.Err(err) == nil {
scope.db.db = interface{}(tx).(SQLCommon) scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true) scope.InstanceSet("gorm:started_transaction", true)
} }
@ -932,7 +932,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
} }
func (scope *Scope) row() *sql.Row { func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc()) defer scope.trace(scope.db.nowFunc())
result := &RowQueryResult{} result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)
@ -942,7 +942,7 @@ func (scope *Scope) row() *sql.Row {
} }
func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc()) defer scope.trace(scope.db.nowFunc())
result := &RowsQueryResult{} result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)