Fix tests with mysql, postgres
This commit is contained in:
		
							parent
							
								
									af080e6773
								
							
						
					
					
						commit
						f7f633590f
					
				| @ -4,7 +4,12 @@ import ( | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func RegisterDefaultCallbacks(db *gorm.DB) { | ||||
| type Config struct { | ||||
| 	LastInsertIDReversed bool | ||||
| 	WithReturning        bool | ||||
| } | ||||
| 
 | ||||
| func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 	enableTransaction := func(db *gorm.DB) bool { | ||||
| 		return !db.SkipDefaultTransaction | ||||
| 	} | ||||
| @ -13,7 +18,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) { | ||||
| 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||
| 	createCallback.Register("gorm:before_create", BeforeCreate) | ||||
| 	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) | ||||
| 	createCallback.Register("gorm:create", Create) | ||||
| 	createCallback.Register("gorm:create", Create(config)) | ||||
| 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations) | ||||
| 	createCallback.Register("gorm:after_create", AfterCreate) | ||||
| 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| func BeforeCreate(db *gorm.DB) { | ||||
| @ -43,32 +44,113 @@ func BeforeCreate(db *gorm.DB) { | ||||
| func SaveBeforeAssociations(db *gorm.DB) { | ||||
| } | ||||
| 
 | ||||
| func Create(db *gorm.DB) { | ||||
| func Create(config *Config) func(db *gorm.DB) { | ||||
| 	if config.WithReturning { | ||||
| 		return CreateWithReturning | ||||
| 	} else { | ||||
| 		return func(db *gorm.DB) { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||
| 				Table: clause.Table{Name: db.Statement.Table}, | ||||
| 			}) | ||||
| 			db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||
| 
 | ||||
| 			db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") | ||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 			if err == nil { | ||||
| 				if db.Statement.Schema != nil { | ||||
| 					if insertID, err := result.LastInsertId(); err == nil { | ||||
| 						switch db.Statement.ReflectValue.Kind() { | ||||
| 						case reflect.Slice, reflect.Array: | ||||
| 							if config.LastInsertIDReversed { | ||||
| 								for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 									db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 									insertID-- | ||||
| 								} | ||||
| 							} else { | ||||
| 								for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 									db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 									insertID++ | ||||
| 								} | ||||
| 							} | ||||
| 						case reflect.Struct: | ||||
| 							db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 						} | ||||
| 					} else { | ||||
| 						db.AddError(err) | ||||
| 					} | ||||
| 				} | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			} else { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func CreateWithReturning(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Insert{ | ||||
| 		Table: clause.Table{Name: db.Statement.Table}, | ||||
| 	}) | ||||
| 	db.Statement.AddClause(ConvertToCreateValues(db.Statement)) | ||||
| 
 | ||||
| 	db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT") | ||||
| 	result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 	if err == nil { | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if insertID, err := result.LastInsertId(); err == nil { | ||||
| 				switch db.Statement.ReflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { | ||||
| 						db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) | ||||
| 						insertID-- | ||||
| 	if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { | ||||
| 		db.Statement.WriteString(" RETURNING ") | ||||
| 
 | ||||
| 		var ( | ||||
| 			idx    int | ||||
| 			fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) | ||||
| 			values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) | ||||
| 		) | ||||
| 
 | ||||
| 		for dbName, field := range sch.FieldsWithDefaultDBValue { | ||||
| 			if idx != 0 { | ||||
| 				db.Statement.WriteByte(',') | ||||
| 			} | ||||
| 
 | ||||
| 			fields[idx] = field | ||||
| 			db.Statement.WriteQuoted(dbName) | ||||
| 			idx++ | ||||
| 		} | ||||
| 
 | ||||
| 		rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 		if err == nil { | ||||
| 			defer rows.Close() | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				for rows.Next() { | ||||
| 					for idx, field := range fields { | ||||
| 						values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) | ||||
| 					if err := rows.Scan(values...); err != nil { | ||||
| 						db.AddError(err) | ||||
| 					} | ||||
| 					db.RowsAffected++ | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				for idx, field := range fields { | ||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 				} | ||||
| 
 | ||||
| 				if rows.Next() { | ||||
| 					err = rows.Scan(values...) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			db.AddError(err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		db.AddError(err) | ||||
| 		if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 			db.RowsAffected, _ = result.RowsAffected() | ||||
| 		} else { | ||||
| 			db.AddError(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -25,7 +25,7 @@ func Open(dsn string) gorm.Dialector { | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) | ||||
| 	db.ConnPool, err = sql.Open("sqlserver", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -24,7 +24,7 @@ func Open(dsn string) gorm.Dialector { | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) | ||||
| 	db.ConnPool, err = sql.Open("mysql", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -25,7 +25,9 @@ func Open(dsn string) gorm.Dialector { | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ | ||||
| 		WithReturning: true, | ||||
| 	}) | ||||
| 	db.ConnPool, err = sql.Open("postgres", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -16,7 +16,7 @@ var ( | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" | ||||
| 	dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" | ||||
| 	if os.Getenv("GORM_DSN") != "" { | ||||
| 		dsn = os.Getenv("GORM_DSN") | ||||
| 	} | ||||
|  | ||||
| @ -22,7 +22,9 @@ func Open(dsn string) gorm.Dialector { | ||||
| 
 | ||||
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 	// register callbacks
 | ||||
| 	callbacks.RegisterDefaultCallbacks(db) | ||||
| 	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ | ||||
| 		LastInsertIDReversed: true, | ||||
| 	}) | ||||
| 	db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -32,15 +32,15 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { | ||||
| 
 | ||||
| 	// check fields
 | ||||
| 	fields := []schema.Field{ | ||||
| 		{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, | ||||
| 		{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, | ||||
| 		{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, | ||||
| 		{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, | ||||
| 		{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, | ||||
| 		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, | ||||
| 		{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, | ||||
| 		{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64}, | ||||
| 		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, | ||||
| 		{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, | ||||
| 		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, | ||||
| 		{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64}, | ||||
| 		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64}, | ||||
| 		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, | ||||
| 	} | ||||
| 
 | ||||
| @ -83,7 +83,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { | ||||
| 			JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ | ||||
| 				{ | ||||
| 					Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, | ||||
| 				}, | ||||
| 				{ | ||||
| 					Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, | ||||
| @ -97,11 +97,11 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { | ||||
| 			JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ | ||||
| 				{ | ||||
| 					Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, | ||||
| 				}, | ||||
| 				{ | ||||
| 					Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, | ||||
| 					Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, PrimaryKey: true, Size: 64, | ||||
| 				}, | ||||
| 			}}, | ||||
| 			References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, | ||||
| @ -124,7 +124,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { | ||||
| 
 | ||||
| 	// check fields
 | ||||
| 	fields := []schema.Field{ | ||||
| 		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, | ||||
| 		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, | ||||
| 		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, | ||||
| 		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, | ||||
| 		{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, | ||||
|  | ||||
| @ -91,6 +91,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | ||||
| 			writer.WriteString(" AS ") | ||||
| 			stmt.DB.Dialector.QuoteTo(writer, v.Alias) | ||||
| 		} | ||||
| 	case string: | ||||
| 		stmt.DB.Dialector.QuoteTo(writer, v) | ||||
| 	default: | ||||
| 		stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) | ||||
| 	} | ||||
|  | ||||
| @ -15,6 +15,7 @@ services: | ||||
|     ports: | ||||
|       - 9920:5432 | ||||
|     environment: | ||||
|       - TZ=Asia/Shanghai | ||||
|       - POSTGRES_DB=gorm | ||||
|       - POSTGRES_USER=gorm | ||||
|       - POSTGRES_PASSWORD=gorm | ||||
|  | ||||
| @ -37,7 +37,7 @@ func TestCreate(t *testing.T, db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Create(&user).Error; err != nil { | ||||
| 			t.Errorf("errors happened when create: %v", err) | ||||
| 			t.Fatalf("errors happened when create: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if user.ID == 0 { | ||||
| @ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) { | ||||
| 		}} | ||||
| 
 | ||||
| 		if err := db.Create(&users).Error; err != nil { | ||||
| 			t.Fatal("errors happened when create users: %v", err) | ||||
| 			t.Fatalf("errors happened when create users: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		t.Run("First", func(t *testing.T) { | ||||
| @ -195,11 +195,11 @@ func TestUpdate(t *testing.T, db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		if err := db.Create(&users).Error; err != nil { | ||||
| 			t.Errorf("errors happened when create: %v", err) | ||||
| 			t.Fatalf("errors happened when create: %v", err) | ||||
| 		} else if user.ID == 0 { | ||||
| 			t.Errorf("user's primary value should not zero, %v", user.ID) | ||||
| 			t.Fatalf("user's primary value should not zero, %v", user.ID) | ||||
| 		} else if user.UpdatedAt.IsZero() { | ||||
| 			t.Errorf("user's updated at should not zero, %v", user.UpdatedAt) | ||||
| 			t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) | ||||
| 		} | ||||
| 		lastUpdatedAt = user.UpdatedAt | ||||
| 
 | ||||
| @ -297,7 +297,7 @@ func TestDelete(t *testing.T, db *gorm.DB) { | ||||
| 
 | ||||
| 		for _, user := range users { | ||||
| 			if user.ID == 0 { | ||||
| 				t.Errorf("user's primary key should has value after create, got : %v", user.ID) | ||||
| 				t.Fatalf("user's primary key should has value after create, got : %v", user.ID) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu