From f424f8aa2e110aeddc300f8e580f10e491134a95 Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Fri, 21 Feb 2020 14:03:47 -0300 Subject: [PATCH] go fmt --- callback.go | 82 +- callback_system_test.go | 138 +- callbacks_test.go | 352 ++--- customize_column_test.go | 490 +++---- dialect.go | 40 +- dialect_common.go | 36 +- dialects/mssql/mssql.go | 38 +- embedded_struct_test.go | 118 +- interface.go | 8 +- main.go | 1 - migration_test.go | 808 +++++------ preload_test.go | 2800 +++++++++++++++++++------------------- 12 files changed, 2455 insertions(+), 2456 deletions(-) diff --git a/callback.go b/callback.go index e0088917..2de5ae50 100644 --- a/callback.go +++ b/callback.go @@ -97,27 +97,27 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { // Register a new callback, refer `Callbacks.Create` func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - callbackContext := func(ctx context.Context, scope *Scope) { - callback(scope) - } + callbackContext := func(ctx context.Context, scope *Scope) { + callback(scope) + } - cp.RegisterContext(callbackName, callbackContext) + cp.RegisterContext(callbackName, callbackContext) } // RegisterContext same as Register func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) - cp.before = "gorm:row_query" - } - } + if cp.kind == "row_query" { + if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { + cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) + cp.before = "gorm:row_query" + } + } - cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) + cp.name = callbackName + cp.processor = &callback + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() } // Remove a registered callback @@ -136,47 +136,47 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - callbackContext := func(ctx context.Context, scope *Scope) { - callback(scope) - } + callbackContext := func(ctx context.Context, scope *Scope) { + callback(scope) + } - cp.ReplaceContext(callbackName, callbackContext) + cp.ReplaceContext(callbackName, callbackContext) } // ReplaceContext same as Replace func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) { - cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() + cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) + cp.name = callbackName + cp.processor = &callback + cp.replace = true + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() } // Get registered callback // db.Callback().Create().Get("gorm:create") func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - c := cp.GetContext(callbackName) + c := cp.GetContext(callbackName) - callback = func(scope *Scope) { - ctx := context.Background() - c(ctx, scope) - } - return + callback = func(scope *Scope) { + ctx := context.Background() + c(ctx, scope) + } + return } // GetContext same as Get func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind { - if p.remove { - callback = nil - } else { - callback = *p.processor - } - } - } - return + for _, p := range cp.parent.processors { + if p.name == callbackName && p.kind == cp.kind { + if p.remove { + callback = nil + } else { + callback = *p.processor + } + } + } + return } // getRIndex get right index from string slice diff --git a/callback_system_test.go b/callback_system_test.go index 4fb42a27..4dcee9f1 100644 --- a/callback_system_test.go +++ b/callback_system_test.go @@ -1,20 +1,20 @@ package gorm import ( - "context" - "reflect" - "runtime" - "strings" - "testing" + "context" + "reflect" + "runtime" + "strings" + "testing" ) func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) + var names []string + for _, f := range funcs { + fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") + names = append(names, fnames[len(fnames)-1]) + } + return reflect.DeepEqual(names, fnames) } func create(s *Scope) {} @@ -24,90 +24,90 @@ func afterCreate1(s *Scope) {} func afterCreate2(s *Scope) {} func TestRegisterCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} + var callback = &Callback{logger: defaultLogger} - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("before_create2", beforeCreate2) + callback.Create().Register("create", create) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Register("after_create2", afterCreate2) - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } + if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + t.Errorf("register callback") + } } func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } + var callback1 = &Callback{logger: defaultLogger} + callback1.Create().Register("before_create1", beforeCreate1) + callback1.Create().Register("create", create) + callback1.Create().Register("after_create1", afterCreate1) + callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) + if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + t.Errorf("register callback with order") + } - var callback2 = &Callback{logger: defaultLogger} + var callback2 = &Callback{logger: defaultLogger} - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) + callback2.Update().Register("create", create) + callback2.Update().Before("create").Register("before_create1", beforeCreate1) + callback2.Update().After("after_create2").Register("after_create1", afterCreate1) + callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) + callback2.Update().Register("after_create2", afterCreate2) - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } + if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + t.Errorf("register callback with order") + } } func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{logger: defaultLogger} + var callback1 = &Callback{logger: defaultLogger} - callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) + callback1.Query().Before("after_create1").After("before_create1").Register("create", create) + callback1.Query().Register("before_create1", beforeCreate1) + callback1.Query().Register("after_create1", afterCreate1) - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } + if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { + t.Errorf("register callback with order") + } - var callback2 = &Callback{logger: defaultLogger} + var callback2 = &Callback{logger: defaultLogger} - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) + callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) + callback2.Delete().Before("create").Register("before_create1", beforeCreate1) + callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) + callback2.Delete().Register("after_create1", afterCreate1) + callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } + if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + t.Errorf("register callback with order") + } } func replaceCreate(s *Scope) {} func TestReplaceCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} + var callback = &Callback{logger: defaultLogger} - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Replace("create", replaceCreate) - if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } + if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { + t.Errorf("replace callback") + } } func TestRemoveCallback(t *testing.T) { - var callback = &Callback{logger: defaultLogger} + var callback = &Callback{logger: defaultLogger} - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Remove("create") - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } + if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { + t.Errorf("remove callback") + } } diff --git a/callbacks_test.go b/callbacks_test.go index 61551e9a..bebd0e38 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -1,249 +1,249 @@ package gorm_test import ( - "errors" - "reflect" - "testing" + "errors" + "reflect" + "testing" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm" ) func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return } func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return } func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return } func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 } func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) + tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) } func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 } func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return } func (s *Product) BeforeDelete() (err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return } func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return } func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} } func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) + } - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } + p.Price = 200 + DB.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 2 { + t.Errorf("AfterFind callbacks should work with slice") + } - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Errorf("Can't find a deleted record") + } } func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Errorf("An error from before create callbacks happened when create with invalid value") + } - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Errorf("Should not save record that have errors") + } - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Errorf("An error from after create callbacks happened when create with invalid value") + } - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Errorf("An error from before update callbacks happened when update with invalid value") + } - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Errorf("Record Should not be updated due to errors happened in before update callback") + } - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Errorf("Record Should not be updated due to errors happened in before update callback") + } - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before save callbacks happened when update with invalid value") - } + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Errorf("An error from before save callbacks happened when update with invalid value") + } - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Errorf("An error from before delete callbacks happened when delete") + } - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks happened") - } + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Errorf("An error from before delete callbacks happened") + } - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Errorf("Record should be reverted if get an error in after save callback") + } - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record should be found") + } - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") + } } func TestGetCallback(t *testing.T) { - scope := DB.NewScope(nil) + scope := DB.NewScope(nil) - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) - callback := DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { - t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) - } + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) + callback := DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { + t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) + } - DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { - t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) - } + DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { + t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) + } - DB.Callback().Create().Remove("gorm:test_callback") - if DB.Callback().Create().Get("gorm:test_callback") != nil { - t.Errorf("`gorm:test_callback` should be nil") - } + DB.Callback().Create().Remove("gorm:test_callback") + if DB.Callback().Create().Get("gorm:test_callback") != nil { + t.Errorf("`gorm:test_callback` should be nil") + } - DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) - callback = DB.Callback().Create().Get("gorm:test_callback") - if callback == nil { - t.Errorf("`gorm:test_callback` should be non-nil") - } - callback(scope) - if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { - t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) - } + DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) + callback = DB.Callback().Create().Get("gorm:test_callback") + if callback == nil { + t.Errorf("`gorm:test_callback` should be non-nil") + } + callback(scope) + if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { + t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) + } } func TestUseDefaultCallback(t *testing.T) { - createCallbackName := "gorm:test_use_default_callback_for_create" - gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { - // nop - }) - if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { - t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) - } - gorm.DefaultCallback.Create().Remove(createCallbackName) - if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { - t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) - } + createCallbackName := "gorm:test_use_default_callback_for_create" + gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { + // nop + }) + if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { + t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) + } + gorm.DefaultCallback.Create().Remove(createCallbackName) + if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { + t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) + } - updateCallbackName := "gorm:test_use_default_callback_for_update" - scopeValueName := "gorm:test_use_default_callback_for_update_value" - gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 1) - }) - gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { - scope.Set(scopeValueName, 2) - }) + updateCallbackName := "gorm:test_use_default_callback_for_update" + scopeValueName := "gorm:test_use_default_callback_for_update_value" + gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 1) + }) + gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { + scope.Set(scopeValueName, 2) + }) - scope := DB.NewScope(nil) - callback := gorm.DefaultCallback.Update().Get(updateCallbackName) - callback(scope) - if v, ok := scope.Get(scopeValueName); !ok || v != 2 { - t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) - } + scope := DB.NewScope(nil) + callback := gorm.DefaultCallback.Update().Get(updateCallbackName) + callback(scope) + if v, ok := scope.Get(scopeValueName); !ok || v != 2 { + t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) + } } diff --git a/customize_column_test.go b/customize_column_test.go index a6cbe6b6..c236ac24 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -1,357 +1,357 @@ package gorm_test import ( - "testing" - "time" + "testing" + "time" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm" ) type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` } // Make sure an ignored field does not interfere with another field's custom // column name that matches the ignored field. type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` + Body string `sql:"-"` + RawBody string `gorm:"column:body"` } func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) + col := "mapped_name" + DB.DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } + scope := DB.NewScope(&CustomizeColumn{}) + if !scope.Dialect().HasColumn(scope.TableName(), col) { + t.Errorf("CustomizeColumn should have column %s", col) + } - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } + col = "mapped_id" + if scope.PrimaryKey() != col { + t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) + } - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } - var cc1 CustomizeColumn - DB.First(&cc1, 666) + var cc1 CustomizeColumn + DB.First(&cc1, 666) - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } - cc.Name = "bar" - DB.Save(&cc) + cc.Name = "bar" + DB.Save(&cc) - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } + var cc2 CustomizeColumn + DB.First(&cc2, 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } } func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } + DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { + t.Errorf("Should not raise error: %s", err) + } } type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` + IdPerson string `gorm:"column:idPerson;primary_key:true"` + Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` } type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string + IdAccount string `gorm:"column:idAccount;primary_key:true"` + Name string } func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) + DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") + DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } + account := CustomizeAccount{IdAccount: "account", Name: "id1"} + person := CustomizePerson{ + IdPerson: "person", + Accounts: []CustomizeAccount{account}, + } - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + if err := DB.Create(&account).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + if err := DB.Create(&person).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } + var person1 CustomizePerson + scope := DB.NewScope(nil) + if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { + t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) + } - if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } + if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { + t.Errorf("should preload correct accounts") + } } type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` + gorm.Model + Email string `sql:"column:email_address"` } type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` + gorm.Model + Address string `sql:"column:invitation"` + Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` } func TestOneToOneWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) + DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) + DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } + user := CustomizeUser{ + Email: "hello@example.com", + } + invitation := CustomizeInvitation{ + Address: "hello@example.com", + } - DB.Create(&user) - DB.Create(&invitation) + DB.Create(&user) + DB.Create(&invitation) - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + var invitation2 CustomizeInvitation + if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } + if invitation2.Person.Email != user.Email { + t.Errorf("Should preload one to one relation with customize foreign keys") + } } type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` + gorm.Model + Name string + Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` + Rule *PromotionRule `gorm:"ForeignKey:discount_id"` + Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` } type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` + gorm.Model + Name string + PromotionID uint + Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` } type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount + gorm.Model + Code string + DiscountID uint + Discount PromotionDiscount } type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount + gorm.Model + Name string + Begin *time.Time + End *time.Time + DiscountID uint + Discount *PromotionDiscount } func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) + DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } + discount := PromotionDiscount{ + Name: "Happy New Year", + Coupons: []*PromotionCoupon{ + {Code: "newyear1"}, + {Code: "newyear2"}, + }, + } - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var discount1 PromotionDiscount + if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } + if len(discount.Coupons) != 2 { + t.Errorf("should find two coupons") + } - var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var coupon PromotionCoupon + if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } + if coupon.Discount.Name != "Happy New Year" { + t.Errorf("should preload discount from coupon") + } } func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) + DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } + var begin = time.Now() + var end = time.Now().Add(24 * time.Hour) + discount := PromotionDiscount{ + Name: "Happy New Year 2", + Rule: &PromotionRule{ + Name: "time_limited", + Begin: &begin, + End: &end, + }, + } - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var discount1 PromotionDiscount + if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } + if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { + t.Errorf("Should be able to preload Rule") + } - var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var rule PromotionRule + if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } + if rule.Discount.Name != "Happy New Year 2" { + t.Errorf("should preload discount from rule") + } } func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) + DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) + DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } + discount := PromotionDiscount{ + Name: "Happy New Year 3", + Benefits: []PromotionBenefit{ + {Name: "free cod"}, + {Name: "free shipping"}, + }, + } - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + if err := DB.Create(&discount).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var discount1 PromotionDiscount + if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } + if len(discount.Benefits) != 2 { + t.Errorf("should find two benefits") + } - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } + var benefit PromotionBenefit + if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { + t.Errorf("no error should happen but got %v", err) + } - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } + if benefit.Discount.Name != "Happy New Year 3" { + t.Errorf("should preload discount from coupon") + } } type SelfReferencingUser struct { - gorm.Model - Name string - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` + gorm.Model + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` } func TestSelfReferencingMany2ManyColumn(t *testing.T) { - DB.DropTable(&SelfReferencingUser{}, "UserFriends") - DB.AutoMigrate(&SelfReferencingUser{}) - if !DB.HasTable("UserFriends") { - t.Errorf("auto migrate error, table UserFriends should be created") - } + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + if !DB.HasTable("UserFriends") { + t.Errorf("auto migrate error, table UserFriends should be created") + } - friend1 := SelfReferencingUser{Name: "friend1_m2m"} - if err := DB.Create(&friend1).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - friend2 := SelfReferencingUser{Name: "friend2_m2m"} - if err := DB.Create(&friend2).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - user := SelfReferencingUser{ - Name: "self_m2m", - Friends: []*SelfReferencingUser{&friend1, &friend2}, - } + user := SelfReferencingUser{ + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, + } - if err := DB.Create(&user).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - if DB.Model(&user).Association("Friends").Count() != 2 { - t.Errorf("Should find created friends correctly") - } + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } - var count int - if err := DB.Table("UserFriends").Count(&count).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - if count == 0 { - t.Errorf("table UserFriends should have records") - } + var count int + if err := DB.Table("UserFriends").Count(&count).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + if count == 0 { + t.Errorf("table UserFriends should have records") + } - var newUser = SelfReferencingUser{} + var newUser = SelfReferencingUser{} - if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } - if len(newUser.Friends) != 2 { - t.Errorf("Should preload created frineds for self reference m2m") - } + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } - DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 3 { - t.Errorf("Should find created friends correctly") - } + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } - DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) - if DB.Model(&user).Association("Friends").Count() != 1 { - t.Errorf("Should find created friends correctly") - } + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } - friend := SelfReferencingUser{} - DB.Model(&newUser).Association("Friends").Find(&friend) - if friend.Name != "friend4_m2m" { - t.Errorf("Should find created friends correctly") - } + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } - DB.Model(&newUser).Association("Friends").Delete(friend) - if DB.Model(&user).Association("Friends").Count() != 0 { - t.Errorf("All friends should be deleted") - } + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } } diff --git a/dialect.go b/dialect.go index b047546c..bcc0f4fc 100644 --- a/dialect.go +++ b/dialect.go @@ -25,28 +25,28 @@ type Dialect interface { DataTypeOf(field *StructField) string // HasIndex check has index or not HasIndex(tableName string, indexName string) bool - // HasIndexContext same as HasIndex - HasIndexContext(ctx context.Context, tableName string, indexName string) bool - // HasForeignKey check has foreign key or not + // HasIndexContext same as HasIndex + HasIndexContext(ctx context.Context, tableName string, indexName string) bool + // HasForeignKey check has foreign key or not HasForeignKey(tableName string, foreignKeyName string) bool - // HasForeignKeyContext same as HasForeignKey - HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool - // RemoveIndex remove index + // HasForeignKeyContext same as HasForeignKey + HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool + // RemoveIndex remove index RemoveIndex(tableName string, indexName string) error // RemoveIndexContext same as RemoveIndex - RemoveIndexContext(ctx context.Context, tableName string, indexName string) error - // HasTable check has table or not + RemoveIndexContext(ctx context.Context, tableName string, indexName string) error + // HasTable check has table or not HasTable(tableName string) bool - // HasTableContext same as HasTable - HasTableContext(ctx context.Context, tableName string) bool + // HasTableContext same as HasTable + HasTableContext(ctx context.Context, tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool // HasColumnContext same as HasColumn - HasColumnContext(ctx context.Context, tableName string, columnName string) bool - // ModifyColumn modify column's type + HasColumnContext(ctx context.Context, tableName string, columnName string) bool + // ModifyColumn modify column's type ModifyColumn(tableName string, columnName string, typ string) error - // ModifyColumnContext same as ModifyColumn - ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error + // ModifyColumnContext same as ModifyColumn + ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) (string, error) @@ -55,12 +55,12 @@ type Dialect interface { // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string // LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial - LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` + LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string + // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string // LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix - LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string - // DefaultValueStr + LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string + // DefaultValueStr DefaultValueStr() string // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference @@ -71,8 +71,8 @@ type Dialect interface { // CurrentDatabase return current database name CurrentDatabase() string - // CurrentDatabaseContext same as CurrentDatabase - CurrentDatabaseContext(ctx context.Context) string + // CurrentDatabaseContext same as CurrentDatabase + CurrentDatabaseContext(ctx context.Context) string } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index c365ff10..4af3187c 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -101,8 +101,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { } func (s commonDialect) HasIndex(tableName string, indexName string) bool { - ctx := context.Background() - return s.HasIndexContext(ctx, tableName, indexName) + ctx := context.Background() + return s.HasIndexContext(ctx, tableName, indexName) } func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool { @@ -113,8 +113,8 @@ func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, in } func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - ctx := context.Background() - return s.RemoveIndexContext(ctx, tableName, indexName) + ctx := context.Background() + return s.RemoveIndexContext(ctx, tableName, indexName) } func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error { @@ -123,8 +123,8 @@ func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, } func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - ctx := context.Background() - return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) + ctx := context.Background() + return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) } func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool { @@ -132,8 +132,8 @@ func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName stri } func (s commonDialect) HasTable(tableName string) bool { - ctx := context.Background() - return s.HasTableContext(ctx, tableName) + ctx := context.Background() + return s.HasTableContext(ctx, tableName) } func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool { @@ -144,8 +144,8 @@ func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bo } func (s commonDialect) HasColumn(tableName string, columnName string) bool { - ctx := context.Background() - return s.HasColumnContext(ctx, tableName, columnName) + ctx := context.Background() + return s.HasColumnContext(ctx, tableName, columnName) } func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool { @@ -156,8 +156,8 @@ func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, c } func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - ctx := context.Background() - return s.ModifyColumnContext(ctx, tableName, columnName, typ) + ctx := context.Background() + return s.ModifyColumnContext(ctx, tableName, columnName, typ) } func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error { @@ -166,8 +166,8 @@ func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string } func (s commonDialect) CurrentDatabase() (name string) { - ctx := context.Background() - return s.CurrentDatabaseContext(ctx) + ctx := context.Background() + return s.CurrentDatabaseContext(ctx) } func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) { @@ -199,8 +199,8 @@ func (commonDialect) SelectFromDummyTable() string { } func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - ctx := context.Background() - return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) + ctx := context.Background() + return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) } func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string { @@ -208,8 +208,8 @@ func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, } func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - ctx := context.Background() - return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) + ctx := context.Background() + return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) } func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 451c3ad0..12414741 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -124,8 +124,8 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { } func (s mssql) HasIndex(tableName string, indexName string) bool { - ctx := context.Background() - return s.HasIndexContext(ctx, tableName, indexName) + ctx := context.Background() + return s.HasIndexContext(ctx, tableName, indexName) } func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool { @@ -135,8 +135,8 @@ func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName } func (s mssql) RemoveIndex(tableName string, indexName string) error { - ctx := context.Background() - return s.RemoveIndexContext(ctx, tableName, indexName) + ctx := context.Background() + return s.RemoveIndexContext(ctx, tableName, indexName) } func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error { @@ -145,8 +145,8 @@ func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexNa } func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - ctx := context.Background() - return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) + ctx := context.Background() + return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) } func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool { @@ -161,8 +161,8 @@ func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, forei } func (s mssql) HasTable(tableName string) bool { - ctx := context.Background() - return s.HasTableContext(ctx, tableName) + ctx := context.Background() + return s.HasTableContext(ctx, tableName) } func (s mssql) HasTableContext(ctx context.Context, tableName string) bool { @@ -173,8 +173,8 @@ func (s mssql) HasTableContext(ctx context.Context, tableName string) bool { } func (s mssql) HasColumn(tableName string, columnName string) bool { - ctx := context.Background() - return s.HasColumnContext(ctx, tableName, columnName) + ctx := context.Background() + return s.HasColumnContext(ctx, tableName, columnName) } func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool { @@ -185,8 +185,8 @@ func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnNam } func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { - ctx := context.Background() - return s.ModifyColumnContext(ctx, tableName, columnName, typ) + ctx := context.Background() + return s.ModifyColumnContext(ctx, tableName, columnName, typ) } func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error { @@ -195,9 +195,9 @@ func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, column } func (s mssql) CurrentDatabase() (name string) { - ctx := context.Background() - s.CurrentDatabaseContext(ctx) - return + ctx := context.Background() + s.CurrentDatabaseContext(ctx) + return } func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) { @@ -236,8 +236,8 @@ func (mssql) SelectFromDummyTable() string { } func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { - ctx := context.Background() - return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) + ctx := context.Background() + return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) } func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string { @@ -249,8 +249,8 @@ func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableNa } func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - ctx := context.Background() - return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) + ctx := context.Background() + return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) } func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string { diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 380b3e5c..5f8ece57 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -3,89 +3,89 @@ package gorm_test import "testing" type BasePost struct { - Id int64 - Title string - URL string + Id int64 + Title string + URL string } type Author struct { - ID string - Name string - Email string + ID string + Name string + Email string } type HNPost struct { - BasePost - Author `gorm:"embedded_prefix:user_"` // Embedded struct - Upvotes int32 + BasePost + Author `gorm:"embedded_prefix:user_"` // Embedded struct + Upvotes int32 } type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct - ImageUrl string + BasePost BasePost `gorm:"embedded"` + Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct + ImageUrl string } func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { - dialect := DB.NewScope(&EngadgetPost{}).Dialect() - engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { - t.Errorf("should has prefix for embedded columns") - } + dialect := DB.NewScope(&EngadgetPost{}).Dialect() + engadgetPostScope := DB.NewScope(&EngadgetPost{}) + if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { + t.Errorf("should has prefix for embedded columns") + } - if len(engadgetPostScope.PrimaryFields()) != 1 { - t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) - } + if len(engadgetPostScope.PrimaryFields()) != 1 { + t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) + } - hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { - t.Errorf("should has prefix for embedded columns") - } + hnScope := DB.NewScope(&HNPost{}) + if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { + t.Errorf("should has prefix for embedded columns") + } } func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } + DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) + var news HNPost + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if news.Title != "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + var egNews EngadgetPost + if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if egNews.BasePost.Title != "engadget_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } + if DB.NewScope(&HNPost{}).PrimaryField() == nil { + t.Errorf("primary key with embedded struct should works") + } - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } + for _, field := range DB.NewScope(&HNPost{}).Fields() { + if field.Name == "BasePost" { + t.Errorf("scope Fields should not contain embedded struct") + } + } } func TestEmbeddedPointerTypeStruct(t *testing.T) { - type HNPost struct { - *BasePost - Upvotes int32 - } + type HNPost struct { + *BasePost + Upvotes int32 + } - DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) - var hnPost HNPost - if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { - t.Errorf("No error should happen when find embedded pointer type, but got %v", err) - } + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } - if hnPost.Title != "embedded_pointer_type" { - t.Errorf("Should find correct value for embedded pointer type") - } + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } } diff --git a/interface.go b/interface.go index fc551247..b7473fb8 100644 --- a/interface.go +++ b/interface.go @@ -11,10 +11,10 @@ type SQLCommon interface { Prepare(query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } type sqlDb interface { diff --git a/main.go b/main.go index 754b953f..c17d6f72 100644 --- a/main.go +++ b/main.go @@ -675,7 +675,6 @@ func (s *DB) Begin() *DB { return s.BeginTx(context.Background(), &sql.TxOptions{}) } - // BeginTx begins a transaction with options func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() diff --git a/migration_test.go b/migration_test.go index 36021fa6..d94ec9ec 100644 --- a/migration_test.go +++ b/migration_test.go @@ -1,579 +1,579 @@ package gorm_test import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "reflect" - "strconv" - "testing" - "time" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "os" + "reflect" + "strconv" + "testing" + "time" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm" ) type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday *time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role Role - Password EncryptedData - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` + Id int64 + Age int64 + UserNum Num + Name string `sql:"size:255"` + Email string + Birthday *time.Time // Time + CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically + UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically + Emails []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressID sql.NullInt64 // Embedded struct's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct's foreign key + CreditCard CreditCard + Latitude float64 + Languages []Language `gorm:"many2many:user_languages;"` + CompanyID *int + Company Company + Role Role + Password EncryptedData + PasswordHash []byte + IgnoreMe int64 `sql:"-"` + IgnoreStringSlice []string `sql:"-"` + Ignored struct{ Name string } `sql:"-"` + IgnoredPointer *User `sql:"-"` } type NotSoLongTableName struct { - Id int64 - ReallyLongThingID int64 - ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit + Id int64 + ReallyLongThingID int64 + ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit } type ReallyLongTableNameToTestMySQLNameLengthLimit struct { - Id int64 + Id int64 } type ReallyLongThingThatReferencesShort struct { - Id int64 - ShortID int64 - Short Short + Id int64 + ShortID int64 + Short Short } type Short struct { - Id int64 + Id int64 } type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time `sql:"column:deleted_time"` + ID int8 + Number string + UserId sql.NullInt64 + CreatedAt time.Time `sql:"not null"` + UpdatedAt time.Time + DeletedAt *time.Time `sql:"column:deleted_time"` } type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time + Id int16 + UserId int + Email string `sql:"type:varchar(100);"` + CreatedAt time.Time + UpdatedAt time.Time } type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time + ID int + Address1 string + Address2 string + Post string + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time } type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` + gorm.Model + Name string + Users []User `gorm:"many2many:user_languages;"` } type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 + Id int64 + Code string + Price int64 + CreatedAt time.Time + UpdatedAt time.Time + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 } type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` + Id int64 + Name string + Owner *User `sql:"-"` } type Place struct { - Id int64 - PlaceAddressID int - PlaceAddress *Address `gorm:"save_associations:false"` - OwnerAddressID int - OwnerAddress *Address `gorm:"save_associations:true"` + Id int64 + PlaceAddressID int + PlaceAddress *Address `gorm:"save_associations:false"` + OwnerAddressID int + OwnerAddress *Address `gorm:"save_associations:true"` } type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { - if b, ok := value.([]byte); ok { - if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { - return errors.New("Too short") - } + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } - *data = b[3:] - return nil - } + *data = b[3:] + return nil + } - return errors.New("Bytes expected") + return errors.New("Bytes expected") } func (data EncryptedData) Value() (driver.Value, error) { - if len(data) > 0 && data[0] == 'x' { - //needed to test failures - return nil, errors.New("Should not start with 'x'") - } + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } - //prepend asterisks - return append([]byte("***"), data...), nil + //prepend asterisks + return append([]byte("***"), data...), nil } type Role struct { - Name string `gorm:"size:256"` + Name string `gorm:"size:256"` } func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil } func (role Role) Value() (driver.Value, error) { - return role.Name, nil + return role.Name, nil } func (role Role) IsAdmin() bool { - return role.Name == "admin" + return role.Name == "admin" } type Num int64 func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - n, _ := strconv.Atoi(string(s)) - *i = Num(n) - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil + switch s := src.(type) { + case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil } type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time + Counter uint64 `gorm:"primary_key:yes"` + Name string `sql:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age time.Time `sql:"DEFAULT:current_timestamp"` + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time } type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` + From uint64 + To uint64 + Time time.Time `sql:"default: null"` } type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category + Id int64 + CategoryId sql.NullInt64 + MainCategoryId int64 + Title string + Body string + Comments []*Comment + Category Category + MainCategory Category } type Category struct { - gorm.Model - Name string + gorm.Model + Name string - Categories []Category - CategoryID *uint + Categories []Category + CategoryID *uint } type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post + gorm.Model + PostId int64 + Content string + Post Post } // Scanner type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime + Id int64 + Name sql.NullString `sql:"not null"` + Gender *sql.NullString `sql:"not null"` + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + AddedAt NullTime } type NullTime struct { - Time time.Time - Valid bool + Time time.Time + Valid bool } func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil + if value == nil { + nt.Valid = false + return nil + } + nt.Time, nt.Valid = value.(time.Time), true + return nil } func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil + if !nt.Valid { + return nil, nil + } + return nt.Time, nil } func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) + var company Company + DB.Where(Company{Name: role}).FirstOrCreate(&company) - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } + return &User{ + Name: name, + Age: 20, + Role: Role{role}, + BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, + ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, + CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, + Emails: []Email{ + {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, + }, + Company: company, + Languages: []Language{ + {Name: fmt.Sprintf("lang_1_%v", name)}, + {Name: fmt.Sprintf("lang_2_%v", name)}, + }, + } } func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } + if err := DB.DropTableIfExists(&User{}).Error; err != nil { + fmt.Printf("Got error when try to delete table users, %+v\n", err) + } - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } + for _, table := range []string{"animals", "user_languages"} { + DB.Exec(fmt.Sprintf("drop table %v;", table)) + } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} - for _, value := range values { - DB.DropTable(value) - } - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} + for _, value := range values { + DB.DropTable(value) + } + if err := DB.AutoMigrate(values...).Error; err != nil { + panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) + } } func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } + if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } + scope := DB.NewScope(&Email{}) + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + t.Errorf("Email should have index idx_email_email") + } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { + t.Errorf("Email's index idx_email_email should be deleted") + } - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } + if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } + if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + t.Errorf("Got error when tried to create index: %+v", err) + } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email should have index idx_email_email_and_user_id") + } - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { + t.Errorf("Should get to create duplicate record when having unique index") + } - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } + var user = User{Name: "sample_user"} + DB.Save(&user) + if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { + t.Errorf("Should get no error when append two emails for user") + } - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } + if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { + t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") + } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + t.Errorf("Got error when tried to remove index: %+v", err) + } - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { + t.Errorf("Email's index idx_email_email_and_user_id should be deleted") + } - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { + t.Errorf("Should be able to create duplicated emails after remove unique index") + } } type EmailWithIdx struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt *time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time + Id int64 + UserId int64 + Email string `sql:"index:idx_email_agent"` + UserAgent string `sql:"index:idx_email_agent"` + RegisteredAt *time.Time `sql:"unique_index"` + CreatedAt time.Time + UpdatedAt time.Time } func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - DB.DropTable(&EmailWithIdx{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } + DB.AutoMigrate(&Address{}) + DB.DropTable(&EmailWithIdx{}) + if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { + t.Errorf("Auto Migrate should not raise any error") + } - now := time.Now() - DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) + now := time.Now() + DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) - scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } + scope := DB.NewScope(&EmailWithIdx{}) + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { + t.Errorf("Failed to create index") + } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { - t.Errorf("Failed to create index") - } + if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { + t.Errorf("Failed to create index") + } - var bigemail EmailWithIdx - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } + var bigemail EmailWithIdx + DB.First(&bigemail, "user_agent = ?", "pc") + if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { + t.Error("Big Emails should be saved and fetched correctly") + } } func TestCreateAndAutomigrateTransaction(t *testing.T) { - tx := DB.Begin() + tx := DB.Begin() - func() { - type Bar struct { - ID uint - } - DB.DropTableIfExists(&Bar{}) + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) - if ok := DB.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } - if ok := tx.HasTable("bars"); ok { - t.Errorf("Table should not exist, but does") - } - }() + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() - func() { - type Bar struct { - Name string - } - err := tx.CreateTable(&Bar{}).Error + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error - if err != nil { - t.Errorf("Should have been able to create the table, but couldn't: %s", err) - } + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } - if ok := tx.HasTable(&Bar{}); !ok { - t.Errorf("The transaction should be able to see the table") - } - }() + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() - func() { - type Bar struct { - Stuff string - } + func() { + type Bar struct { + Stuff string + } - err := tx.AutoMigrate(&Bar{}).Error - if err != nil { - t.Errorf("Should have been able to alter the table, but couldn't") - } - }() + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() - tx.Rollback() + tx.Rollback() } type MultipleIndexes struct { - ID int64 - UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` - Name string `sql:"unique_index:uix_multipleindexes_user_name"` - Email string `sql:"unique_index:,uix_multipleindexes_user_email"` - Other string `sql:"index:,idx_multipleindexes_user_other"` + ID int64 + UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` + Name string `sql:"unique_index:uix_multipleindexes_user_name"` + Email string `sql:"unique_index:,uix_multipleindexes_user_email"` + Other string `sql:"index:,idx_multipleindexes_user_other"` } func TestMultipleIndexes(t *testing.T) { - if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { - fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) - } + if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { + fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) + } - DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } + DB.AutoMigrate(&MultipleIndexes{}) + if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { + t.Errorf("Auto Migrate should not raise any error") + } - DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) + DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) - scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { - t.Errorf("Failed to create index") - } + scope := DB.NewScope(&MultipleIndexes{}) + if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { + t.Errorf("Failed to create index") + } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { - t.Errorf("Failed to create index") - } + if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { + t.Errorf("Failed to create index") + } - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { - t.Errorf("Failed to create index") - } + if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { + t.Errorf("Failed to create index") + } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { - t.Errorf("Failed to create index") - } + if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { + t.Errorf("Failed to create index") + } - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { - t.Errorf("Failed to create index") - } + if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { + t.Errorf("Failed to create index") + } - var mutipleIndexes MultipleIndexes - DB.First(&mutipleIndexes, "name = ?", "jinzhu") - if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { - t.Error("MutipleIndexes should be saved and fetched correctly") - } + var mutipleIndexes MultipleIndexes + DB.First(&mutipleIndexes, "name = ?", "jinzhu") + if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { + t.Error("MutipleIndexes should be saved and fetched correctly") + } - // Check unique constraints - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } + // Check unique constraints + if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { + t.Error("MultipleIndexes unique index failed") + } - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } + if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { + t.Error("MultipleIndexes unique index failed") + } - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } + if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { + t.Error("MultipleIndexes unique index failed") + } - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } + if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { + t.Error("MultipleIndexes unique index failed") + } } func TestModifyColumnType(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { - t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") - } + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { + t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") + } - type ModifyColumnType struct { - gorm.Model - Name1 string `gorm:"length:100"` - Name2 string `gorm:"length:200"` - } - DB.DropTable(&ModifyColumnType{}) - DB.CreateTable(&ModifyColumnType{}) + type ModifyColumnType struct { + gorm.Model + Name1 string `gorm:"length:100"` + Name2 string `gorm:"length:200"` + } + DB.DropTable(&ModifyColumnType{}) + DB.CreateTable(&ModifyColumnType{}) - name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") - name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) + name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") + name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) - if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { - t.Errorf("No error should happen when ModifyColumn, but got %v", err) - } + if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { + t.Errorf("No error should happen when ModifyColumn, but got %v", err) + } } func TestIndexWithPrefixLength(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { - t.Skip("Skipping this because only mysql support setting an index prefix length") - } + if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { + t.Skip("Skipping this because only mysql support setting an index prefix length") + } - type IndexWithPrefix struct { - gorm.Model - Name string - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefix struct { - gorm.Model - Name string - Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - type IndexesWithPrefixAndWithoutPrefix struct { - gorm.Model - Name string `gorm:"index:idx_index_with_prefixes_length"` - Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` - } - tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} - for _, table := range tables { - scope := DB.NewScope(table) - tableName := scope.TableName() - t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { - if err := DB.DropTableIfExists(table).Error; err != nil { - t.Errorf("Failed to drop %s table: %v", tableName, err) - } - if err := DB.CreateTable(table).Error; err != nil { - t.Errorf("Failed to create %s table: %v", tableName, err) - } - if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { - t.Errorf("Failed to create %s table index:", tableName) - } - }) - } + type IndexWithPrefix struct { + gorm.Model + Name string + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefix struct { + gorm.Model + Name string + Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + type IndexesWithPrefixAndWithoutPrefix struct { + gorm.Model + Name string `gorm:"index:idx_index_with_prefixes_length"` + Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` + } + tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} + for _, table := range tables { + scope := DB.NewScope(table) + tableName := scope.TableName() + t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { + if err := DB.DropTableIfExists(table).Error; err != nil { + t.Errorf("Failed to drop %s table: %v", tableName, err) + } + if err := DB.CreateTable(table).Error; err != nil { + t.Errorf("Failed to create %s table: %v", tableName, err) + } + if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { + t.Errorf("Failed to create %s table index:", tableName) + } + }) + } } diff --git a/preload_test.go b/preload_test.go index dd64ebf1..dd29fb5e 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,1701 +1,1701 @@ package gorm_test import ( - "database/sql" - "encoding/json" - "os" - "reflect" - "testing" + "database/sql" + "encoding/json" + "os" + "reflect" + "testing" - "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm" ) func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") + return getPreparedUser(name, "Preload") } func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } + u := getPreloadUser(user.Name) + if user.BillingAddress.Address1 != u.BillingAddress.Address1 { + t.Error("Failed to preload user's BillingAddress") + } - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } + if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { + t.Error("Failed to preload user's ShippingAddress") + } - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } + if user.CreditCard.Number != u.CreditCard.Number { + t.Error("Failed to preload user's CreditCard") + } - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } + if user.Company.Name != u.Company.Name { + t.Error("Failed to preload user's Company") + } - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } + if len(user.Emails) != len(u.Emails) { + t.Error("Failed to preload user's Emails") + } else { + var found int + for _, e1 := range u.Emails { + for _, e2 := range user.Emails { + if e1.Email == e2.Email { + found++ + break + } + } + } + if found != len(u.Emails) { + t.Error("Failed to preload user's email details") + } + } } func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) + user1 := getPreloadUser("user1") + DB.Save(user1) - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) + preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). + Preload("CreditCard").Preload("Emails").Preload("Company") + var user User + preloadDB.Find(&user) + checkUserHasPreloadData(user, t) - user2 := getPreloadUser("user2") - DB.Save(user2) + user2 := getPreloadUser("user2") + DB.Save(user2) - user3 := getPreloadUser("user3") - DB.Save(user3) + user3 := getPreloadUser("user3") + DB.Save(user3) - var users []User - preloadDB.Find(&users) + var users []User + preloadDB.Find(&users) - for _, user := range users { - checkUserHasPreloadData(user, t) - } + for _, user := range users { + checkUserHasPreloadData(user, t) + } - var users2 []*User - preloadDB.Find(&users2) + var users2 []*User + preloadDB.Find(&users2) - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } + for _, user := range users2 { + checkUserHasPreloadData(*user, t) + } - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) + var users3 []*User + preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } else if user.Emails == nil { - t.Errorf("should return an empty slice to indicate zero results") - } - } + for _, user := range users3 { + if user.Name == user3.Name { + if len(user.Emails) != 1 { + t.Errorf("should only preload one emails for user3 when with condition") + } + } else if len(user.Emails) != 0 { + t.Errorf("should not preload any emails for other users when with condition") + } else if user.Emails == nil { + t.Errorf("should return an empty slice to indicate zero results") + } + } } func TestAutoPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) + user1 := getPreloadUser("auto_user1") + DB.Save(user1) - preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) + preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + checkUserHasPreloadData(user, t) - user2 := getPreloadUser("auto_user2") - DB.Save(user2) + user2 := getPreloadUser("auto_user2") + DB.Save(user2) - var users []User - preloadDB.Find(&users) + var users []User + preloadDB.Find(&users) - for _, user := range users { - checkUserHasPreloadData(user, t) - } + for _, user := range users { + checkUserHasPreloadData(user, t) + } - var users2 []*User - preloadDB.Find(&users2) + var users2 []*User + preloadDB.Find(&users2) - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } + for _, user := range users2 { + checkUserHasPreloadData(*user, t) + } } func TestAutoPreloadFalseDoesntPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) + user1 := getPreloadUser("auto_user1") + DB.Save(user1) - preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } - user2 := getPreloadUser("auto_user2") - DB.Save(user2) + user2 := getPreloadUser("auto_user2") + DB.Save(user2) - var users []User - preloadDB.Find(&users) + var users []User + preloadDB.Find(&users) - for _, user := range users { - if user.BillingAddress.Address1 != "" { - t.Error("AutoPreload was set to fasle, but still fetched data") - } - } + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } } func TestNestedPreload1(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } } func TestNestedPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []*Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Level2s: []Level2{ - { - Level1s: []*Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []*Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []*Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - Name string - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + Name string + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload4(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want).Error; err != nil { - t.Error(err) - } + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } // Slice: []Level3 func TestNestedPreload5(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload6(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload7(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload8(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestNestedPreload9(t *testing.T) { - type ( - Level0 struct { - ID uint - Value string - Level1ID uint - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level2_1 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level2 Level2 - Level2_1 Level2_1 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { - t.Error(err) - } + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level2_1{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + t.Error(err) + } - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value1-1", - Level0s: []Level0{{Value: "Level0-1"}}, - }, - { - Value: "value2-2", - Level0s: []Level0{{Value: "Level0-2"}}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - t.Error(err) - } - want[1] = Level3{ - Level2: Level2{ - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - Level2_1: Level2_1{ - Level1s: []Level1{ - { - Value: "value3-3", - Level0s: []Level0{}, - }, - { - Value: "value4-4", - Level0s: []Level0{}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - t.Error(err) - } + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + { + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + { + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + t.Error(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + { + Value: "value3-3", + Level0s: []Level0{}, + }, + { + Value: "value4-4", + Level0s: []Level0{}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + t.Error(err) + } - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } type LevelA1 struct { - ID uint - Value string + ID uint + Value string } type LevelA2 struct { - ID uint - Value string - LevelA3s []*LevelA3 + ID uint + Value string + LevelA3s []*LevelA3 } type LevelA3 struct { - ID uint - Value string - LevelA1ID sql.NullInt64 - LevelA1 *LevelA1 - LevelA2ID sql.NullInt64 - LevelA2 *LevelA2 + ID uint + Value string + LevelA1ID sql.NullInt64 + LevelA1 *LevelA1 + LevelA2ID sql.NullInt64 + LevelA2 *LevelA2 } func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) + DB.DropTableIfExists(&LevelA3{}) + DB.DropTableIfExists(&LevelA2{}) + DB.DropTableIfExists(&LevelA1{}) - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { + t.Error(err) + } - levelA1 := &LevelA1{Value: "foo"} - if err := DB.Save(levelA1).Error; err != nil { - t.Error(err) - } + levelA1 := &LevelA1{Value: "foo"} + if err := DB.Save(levelA1).Error; err != nil { + t.Error(err) + } - want := []*LevelA2{ - { - Value: "bar", - LevelA3s: []*LevelA3{ - { - Value: "qux", - LevelA1: levelA1, - }, - }, - }, - { - Value: "bar 2", - LevelA3s: []*LevelA3{}, - }, - } - for _, levelA2 := range want { - if err := DB.Save(levelA2).Error; err != nil { - t.Error(err) - } - } + want := []*LevelA2{ + { + Value: "bar", + LevelA3s: []*LevelA3{ + { + Value: "qux", + LevelA1: levelA1, + }, + }, + }, + { + Value: "bar 2", + LevelA3s: []*LevelA3{}, + }, + } + for _, levelA2 := range want { + if err := DB.Save(levelA2).Error; err != nil { + t.Error(err) + } + } - var got []*LevelA2 - if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { - t.Error(err) - } + var got []*LevelA2 + if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } type LevelB1 struct { - ID uint - Value string - LevelB3s []*LevelB3 + ID uint + Value string + LevelB3s []*LevelB3 } type LevelB2 struct { - ID uint - Value string + ID uint + Value string } type LevelB3 struct { - ID uint - Value string - LevelB1ID sql.NullInt64 - LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` + ID uint + Value string + LevelB1ID sql.NullInt64 + LevelB1 *LevelB1 + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` } func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { - t.Error(err) - } + DB.DropTableIfExists(&LevelB2{}) + DB.DropTableIfExists(&LevelB3{}) + DB.DropTableIfExists(&LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { + t.Error(err) + } - levelB1 := &LevelB1{Value: "foo"} - if err := DB.Create(levelB1).Error; err != nil { - t.Error(err) - } + levelB1 := &LevelB1{Value: "foo"} + if err := DB.Create(levelB1).Error; err != nil { + t.Error(err) + } - levelB3 := &LevelB3{ - Value: "bar", - LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, - LevelB2s: []*LevelB2{}, - } - if err := DB.Create(levelB3).Error; err != nil { - t.Error(err) - } - levelB1.LevelB3s = []*LevelB3{levelB3} + levelB3 := &LevelB3{ + Value: "bar", + LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, + } + if err := DB.Create(levelB3).Error; err != nil { + t.Error(err) + } + levelB1.LevelB3s = []*LevelB3{levelB3} - want := []*LevelB1{levelB1} - var got []*LevelB1 - if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { - t.Error(err) - } + want := []*LevelB1{levelB1} + var got []*LevelB1 + if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } type LevelC1 struct { - ID uint - Value string - LevelC2ID uint + ID uint + Value string + LevelC2ID uint } type LevelC2 struct { - ID uint - Value string - LevelC1 LevelC1 + ID uint + Value string + LevelC1 LevelC1 } type LevelC3 struct { - ID uint - Value string - LevelC2ID uint - LevelC2 LevelC2 + ID uint + Value string + LevelC2ID uint + LevelC2 LevelC2 } func TestNestedPreload12(t *testing.T) { - DB.DropTableIfExists(&LevelC2{}) - DB.DropTableIfExists(&LevelC3{}) - DB.DropTableIfExists(&LevelC1{}) - if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { - t.Error(err) - } + DB.DropTableIfExists(&LevelC2{}) + DB.DropTableIfExists(&LevelC3{}) + DB.DropTableIfExists(&LevelC1{}) + if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { + t.Error(err) + } - level2 := LevelC2{ - Value: "c2", - LevelC1: LevelC1{ - Value: "c1", - }, - } - DB.Create(&level2) + level2 := LevelC2{ + Value: "c2", + LevelC1: LevelC1{ + Value: "c1", + }, + } + DB.Create(&level2) - want := []LevelC3{ - { - Value: "c3-1", - LevelC2: level2, - }, { - Value: "c3-2", - LevelC2: level2, - }, - } + want := []LevelC3{ + { + Value: "c3-1", + LevelC2: level2, + }, { + Value: "c3-2", + LevelC2: level2, + }, + } - for i := range want { - if err := DB.Create(&want[i]).Error; err != nil { - t.Error(err) - } - } + for i := range want { + if err := DB.Create(&want[i]).Error; err != nil { + t.Error(err) + } + } - var got []LevelC3 - if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { - t.Error(err) - } + var got []LevelC3 + if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { - return - } + if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { + return + } - type ( - Level1 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - } - Level2 struct { - ID uint `gorm:"primary_key;"` - LanguageCode string `gorm:"primary_key"` - Value string - Level1s []Level1 `gorm:"many2many:levels;"` - } - ) + type ( + Level1 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + LanguageCode string `gorm:"primary_key"` + Value string + Level1s []Level1 `gorm:"many2many:levels;"` + } + ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ - {Value: "ru", LanguageCode: "ru"}, - {Value: "en", LanguageCode: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ + {Value: "ru", LanguageCode: "ru"}, + {Value: "en", LanguageCode: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ - {Value: "zh", LanguageCode: "zh"}, - {Value: "de", LanguageCode: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } + want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ + {Value: "zh", LanguageCode: "zh"}, + {Value: "de", LanguageCode: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") - got.Level1s = []Level1{ruLevel1} - got2.Level1s = []Level1{zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } + got.Level1s = []Level1{ruLevel1} + got2.Level1s = []Level1{zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } - if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { - t.Error(err) - } + if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { + t.Error(err) + } } func TestManyToManyPreloadForNestedPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Value: "Bob", - Level2: &Level2{ - Value: "Foo", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + want := Level3{ + Value: "Bob", + Level2: &Level2{ + Value: "Foo", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - want2 := Level3{ - Value: "Tom", - Level2: &Level2{ - Value: "Bar", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } + want2 := Level3{ + Value: "Tom", + Level2: &Level2{ + Value: "Bar", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - var got2 Level3 - if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } + var got2 Level3 + if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } - var got3 []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got3 []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got3, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) - } + if !reflect.DeepEqual(got3, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) + } - var got4 []Level3 - if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got4 []Level3 + if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - var got5 Level3 - DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") + var got5 Level3 + DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") - got.Level2.Level1s = []*Level1{&ruLevel1} - got2.Level2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level3{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) - } + got.Level2.Level1s = []*Level1{&ruLevel1} + got2.Level2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) + } } func TestNestedManyToManyPreload(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2s []Level2 `gorm:"many2many:level2_level3;"` - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2s []Level2 `gorm:"many2many:level2_level3;"` + } + ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists("level2_level3") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Value: "Level3", - Level2s: []Level2{ - { - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, { - Value: "Tom", - Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }, - }, - }, - } + want := Level3{ + Value: "Level3", + Level2s: []Level2{ + { + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, { + Value: "Tom", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + }, + } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } + if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } } func TestNestedManyToManyPreload2(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level3{ - Value: "Level3", - Level2: &Level2{ - Value: "Bob", - Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }, - }, - } + want := Level3{ + Value: "Level3", + Level2: &Level2{ + Value: "Bob", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { - t.Error(err) - } + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { - t.Error(err) - } + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + t.Error(err) + } } func TestNestedManyToManyPreload3(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 *Level2 - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - level1Zh := &Level1{Value: "zh"} - level1Ru := &Level1{Value: "ru"} - level1En := &Level1{Value: "en"} + level1Zh := &Level1{Value: "zh"} + level1Ru := &Level1{Value: "ru"} + level1En := &Level1{Value: "en"} - level21 := &Level2{ - Value: "Level2-1", - Level1s: []*Level1{level1Zh, level1Ru}, - } + level21 := &Level2{ + Value: "Level2-1", + Level1s: []*Level1{level1Zh, level1Ru}, + } - level22 := &Level2{ - Value: "Level2-2", - Level1s: []*Level1{level1Zh, level1En}, - } + level22 := &Level2{ + Value: "Level2-2", + Level1s: []*Level1{level1Zh, level1En}, + } - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } } func TestNestedManyToManyPreload3ForStruct(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []Level1 `gorm:"many2many:level1_level2;"` - } - Level3 struct { - ID uint - Value string - Level2ID sql.NullInt64 - Level2 Level2 - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 Level2 + } + ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - level1Zh := Level1{Value: "zh"} - level1Ru := Level1{Value: "ru"} - level1En := Level1{Value: "en"} + level1Zh := Level1{Value: "zh"} + level1Ru := Level1{Value: "ru"} + level1En := Level1{Value: "en"} - level21 := Level2{ - Value: "Level2-1", - Level1s: []Level1{level1Zh, level1Ru}, - } + level21 := Level2{ + Value: "Level2-1", + Level1s: []Level1{level1Zh, level1Ru}, + } - level22 := Level2{ - Value: "Level2-2", - Level1s: []Level1{level1Zh, level1En}, - } + level22 := Level2{ + Value: "Level2-2", + Level1s: []Level1{level1Zh, level1En}, + } - wants := []*Level3{ - { - Value: "Level3-1", - Level2: level21, - }, - { - Value: "Level3-2", - Level2: level22, - }, - { - Value: "Level3-3", - Level2: level21, - }, - } + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } - for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - } + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } - var gots []*Level3 - if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { - return db.Order("level1.id ASC") - }).Find(&gots).Error; err != nil { - t.Error(err) - } + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(gots, wants) { - t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) - } + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } } func TestNestedManyToManyPreload4(t *testing.T) { - type ( - Level4 struct { - ID uint - Value string - Level3ID uint - } - Level3 struct { - ID uint - Value string - Level4s []*Level4 - } - Level2 struct { - ID uint - Value string - Level3s []*Level3 `gorm:"many2many:level2_level3;"` - } - Level1 struct { - ID uint - Value string - Level2s []*Level2 `gorm:"many2many:level1_level2;"` - } - ) + type ( + Level4 struct { + ID uint + Value string + Level3ID uint + } + Level3 struct { + ID uint + Value string + Level4s []*Level4 + } + Level2 struct { + ID uint + Value string + Level3s []*Level3 `gorm:"many2many:level2_level3;"` + } + Level1 struct { + ID uint + Value string + Level2s []*Level2 `gorm:"many2many:level1_level2;"` + } + ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level4{}) + DB.DropTableIfExists("level1_level2") + DB.DropTableIfExists("level2_level3") - dummy := Level1{ - Value: "Level1", - Level2s: []*Level2{{ - Value: "Level2", - Level3s: []*Level3{{ - Value: "Level3", - Level4s: []*Level4{{ - Value: "Level4", - }}, - }}, - }}, - } + dummy := Level1{ + Value: "Level1", + Level2s: []*Level2{{ + Value: "Level2", + Level3s: []*Level3{{ + Value: "Level3", + Level4s: []*Level4{{ + Value: "Level4", + }}, + }}, + }}, + } - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - if err := DB.Save(&dummy).Error; err != nil { - t.Error(err) - } + if err := DB.Save(&dummy).Error; err != nil { + t.Error(err) + } - var level1 Level1 - if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { - t.Error(err) - } + var level1 Level1 + if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; err != nil { + t.Error(err) + } } func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } } func TestNilPointerSlice(t *testing.T) { - type ( - Level3 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level3ID uint - Level3 *Level3 - } - Level1 struct { - ID uint - Value string - Level2ID uint - Level2 *Level2 - } - ) + type ( + Level3 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level3ID uint + Level3 *Level3 + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2 *Level2 + } + ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } - want := Level1{ - Value: "Bob", - Level2: &Level2{ - Value: "en", - Level3: &Level3{ - Value: "native", - }, - }, - } - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } + want := Level1{ + Value: "Bob", + Level2: &Level2{ + Value: "en", + Level3: &Level3{ + Value: "native", + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } - want2 := Level1{ - Value: "Tom", - Level2: nil, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } + want2 := Level1{ + Value: "Tom", + Level2: nil, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } - var got []Level1 - if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { - t.Error(err) - } + var got []Level1 + if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { + t.Error(err) + } - if len(got) != 2 { - t.Errorf("got %v items, expected 2", len(got)) - } + if len(got) != 2 { + t.Errorf("got %v items, expected 2", len(got)) + } - if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + } - if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) - } + if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) + } } func TestNilPointerSlice2(t *testing.T) { - type ( - Level4 struct { - ID uint - } - Level3 struct { - ID uint - Level4ID sql.NullInt64 `sql:"index"` - Level4 *Level4 - } - Level2 struct { - ID uint - Level3s []*Level3 `gorm:"many2many:level2_level3s"` - } - Level1 struct { - ID uint - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) + type ( + Level4 struct { + ID uint + } + Level3 struct { + ID uint + Level4ID sql.NullInt64 `sql:"index"` + Level4 *Level4 + } + Level2 struct { + ID uint + Level3s []*Level3 `gorm:"many2many:level2_level3s"` + } + Level1 struct { + ID uint + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level4)) + DB.DropTableIfExists(new(Level3)) + DB.DropTableIfExists(new(Level2)) + DB.DropTableIfExists(new(Level1)) - if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { + t.Error(err) + } - want := new(Level1) - if err := DB.Save(want).Error; err != nil { - t.Error(err) - } + want := new(Level1) + if err := DB.Save(want).Error; err != nil { + t.Error(err) + } - got := new(Level1) - err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error - if err != nil { - t.Error(err) - } + got := new(Level1) + err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error + if err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestPrefixedPreloadDuplication(t *testing.T) { - type ( - Level4 struct { - ID uint - Name string - Level3ID uint - } - Level3 struct { - ID uint - Name string - Level4s []*Level4 - } - Level2 struct { - ID uint - Name string - Level3ID sql.NullInt64 `sql:"index"` - Level3 *Level3 - } - Level1 struct { - ID uint - Name string - Level2ID sql.NullInt64 `sql:"index"` - Level2 *Level2 - } - ) + type ( + Level4 struct { + ID uint + Name string + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level4s []*Level4 + } + Level2 struct { + ID uint + Name string + Level3ID sql.NullInt64 `sql:"index"` + Level3 *Level3 + } + Level1 struct { + ID uint + Name string + Level2ID sql.NullInt64 `sql:"index"` + Level2 *Level2 + } + ) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level3)) + DB.DropTableIfExists(new(Level4)) + DB.DropTableIfExists(new(Level2)) + DB.DropTableIfExists(new(Level1)) - if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { + t.Error(err) + } - lvl := &Level3{} - if err := DB.Save(lvl).Error; err != nil { - t.Error(err) - } + lvl := &Level3{} + if err := DB.Save(lvl).Error; err != nil { + t.Error(err) + } - sublvl1 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl1).Error; err != nil { - t.Error(err) - } - sublvl2 := &Level4{Level3ID: lvl.ID} - if err := DB.Save(sublvl2).Error; err != nil { - t.Error(err) - } + sublvl1 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl1).Error; err != nil { + t.Error(err) + } + sublvl2 := &Level4{Level3ID: lvl.ID} + if err := DB.Save(sublvl2).Error; err != nil { + t.Error(err) + } - lvl.Level4s = []*Level4{sublvl1, sublvl2} + lvl.Level4s = []*Level4{sublvl1, sublvl2} - want1 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want1).Error; err != nil { - t.Error(err) - } + want1 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want1).Error; err != nil { + t.Error(err) + } - want2 := Level1{ - Level2: &Level2{ - Level3: lvl, - }, - } - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } + want2 := Level1{ + Level2: &Level2{ + Level3: lvl, + }, + } + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } - want := []Level1{want1, want2} + want := []Level1{want1, want2} - var got []Level1 - err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error - if err != nil { - t.Error(err) - } + var got []Level1 + err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error + if err != nil { + t.Error(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } } func TestPreloadManyToManyCallbacks(t *testing.T) { - type ( - Level2 struct { - ID uint - Name string - } - Level1 struct { - ID uint - Name string - Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` - } - ) + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) - DB.DropTableIfExists("level1_level2s") - DB.DropTableIfExists(new(Level1)) - DB.DropTableIfExists(new(Level2)) + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) - if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { - t.Error(err) - } + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } - lvl := Level1{ - Name: "l1", - Level2s: []Level2{ - {Name: "l2-1"}, {Name: "l2-2"}, - }, - } - DB.Save(&lvl) + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + {Name: "l2-1"}, {Name: "l2-2"}, + }, + } + DB.Save(&lvl) - called := 0 + called := 0 - DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { - called = called + 1 - }) + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) - DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) - if called != 3 { - t.Errorf("Wanted callback to be called 3 times but got %d", called) - } + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } } func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r + r, _ := json.MarshalIndent(v, "", " ") + return r }