This commit is contained in:
Daniel Gatis 2020-02-21 14:03:47 -03:00
parent 75d6dc912c
commit f424f8aa2e
12 changed files with 2455 additions and 2456 deletions

View File

@ -97,27 +97,27 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
// Register a new callback, refer `Callbacks.Create` // Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
callbackContext := func(ctx context.Context, scope *Scope) { callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope) callback(scope)
} }
cp.RegisterContext(callbackName, callbackContext) cp.RegisterContext(callbackName, callbackContext)
} }
// RegisterContext same as Register // RegisterContext same as Register
func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) { func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
if cp.kind == "row_query" { if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
cp.before = "gorm:row_query" cp.before = "gorm:row_query"
} }
} }
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName cp.name = callbackName
cp.processor = &callback cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp) cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder() cp.parent.reorder()
} }
// Remove a registered callback // Remove a registered callback
@ -136,47 +136,47 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("UpdatedAt", now) // scope.SetColumn("UpdatedAt", now)
// }) // })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
callbackContext := func(ctx context.Context, scope *Scope) { callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope) callback(scope)
} }
cp.ReplaceContext(callbackName, callbackContext) cp.ReplaceContext(callbackName, callbackContext)
} }
// ReplaceContext same as Replace // ReplaceContext same as Replace
func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) { func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName cp.name = callbackName
cp.processor = &callback cp.processor = &callback
cp.replace = true cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp) cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder() cp.parent.reorder()
} }
// Get registered callback // Get registered callback
// db.Callback().Create().Get("gorm:create") // db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
c := cp.GetContext(callbackName) c := cp.GetContext(callbackName)
callback = func(scope *Scope) { callback = func(scope *Scope) {
ctx := context.Background() ctx := context.Background()
c(ctx, scope) c(ctx, scope)
} }
return return
} }
// GetContext same as Get // GetContext same as Get
func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) { func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
for _, p := range cp.parent.processors { for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind { if p.name == callbackName && p.kind == cp.kind {
if p.remove { if p.remove {
callback = nil callback = nil
} else { } else {
callback = *p.processor callback = *p.processor
} }
} }
} }
return return
} }
// getRIndex get right index from string slice // getRIndex get right index from string slice

View File

@ -1,20 +1,20 @@
package gorm package gorm
import ( import (
"context" "context"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
) )
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool { func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
var names []string var names []string
for _, f := range funcs { for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1]) names = append(names, fnames[len(fnames)-1])
} }
return reflect.DeepEqual(names, fnames) return reflect.DeepEqual(names, fnames)
} }
func create(s *Scope) {} func create(s *Scope) {}
@ -24,90 +24,90 @@ func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {} func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) { func TestRegisterCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger} var callback = &Callback{logger: defaultLogger}
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2) callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create) callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2) callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback") t.Errorf("register callback")
} }
} }
func TestRegisterCallbackWithOrder(t *testing.T) { func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &Callback{logger: defaultLogger} var callback1 = &Callback{logger: defaultLogger}
callback1.Create().Register("before_create1", beforeCreate1) callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create) callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1) callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &Callback{logger: defaultLogger} var callback2 = &Callback{logger: defaultLogger}
callback2.Update().Register("create", create) callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1) callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1) callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2) callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
} }
func TestRegisterCallbackWithComplexOrder(t *testing.T) { func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &Callback{logger: defaultLogger} var callback1 = &Callback{logger: defaultLogger}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1) callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1) callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &Callback{logger: defaultLogger} var callback2 = &Callback{logger: defaultLogger}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1) callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1) callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
} }
func replaceCreate(s *Scope) {} func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) { func TestReplaceCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger} var callback = &Callback{logger: defaultLogger}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate) callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback") t.Errorf("replace callback")
} }
} }
func TestRemoveCallback(t *testing.T) { func TestRemoveCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger} var callback = &Callback{logger: defaultLogger}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create") callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback") t.Errorf("remove callback")
} }
} }

View File

@ -1,249 +1,249 @@
package gorm_test package gorm_test
import ( import (
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
func (s *Product) BeforeCreate() (err error) { func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" { if s.Code == "Invalid" {
err = errors.New("invalid product") err = errors.New("invalid product")
} }
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return return
} }
func (s *Product) BeforeUpdate() (err error) { func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" { if s.Code == "dont_update" {
err = errors.New("can't update") err = errors.New("can't update")
} }
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return return
} }
func (s *Product) BeforeSave() (err error) { func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" { if s.Code == "dont_save" {
err = errors.New("can't save") err = errors.New("can't save")
} }
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return return
} }
func (s *Product) AfterFind() { func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1 s.AfterFindCallTimes = s.AfterFindCallTimes + 1
} }
func (s *Product) AfterCreate(tx *gorm.DB) { func (s *Product) AfterCreate(tx *gorm.DB) {
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
} }
func (s *Product) AfterUpdate() { func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
} }
func (s *Product) AfterSave() (err error) { func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" { if s.Code == "after_save_error" {
err = errors.New("can't save") err = errors.New("can't save")
} }
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return return
} }
func (s *Product) BeforeDelete() (err error) { func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" { if s.Code == "dont_delete" {
err = errors.New("can't delete") err = errors.New("can't delete")
} }
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return return
} }
func (s *Product) AfterDelete() (err error) { func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" { if s.Code == "after_delete_error" {
err = errors.New("can't delete") err = errors.New("can't delete")
} }
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return return
} }
func (s *Product) GetCallTimes() []int64 { func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
} }
func TestRunCallbacks(t *testing.T) { func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100} p := Product{Code: "unique_code", Price: 100}
DB.Save(&p) DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
DB.Where("Code = ?", "unique_code").First(&p) DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
} }
p.Price = 200 p.Price = 200
DB.Save(&p) DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
var products []Product var products []Product
DB.Find(&products, "code = ?", "unique_code") DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 { if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice") t.Errorf("AfterFind callbacks should work with slice")
} }
DB.Where("Code = ?", "unique_code").First(&p) DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
} }
DB.Delete(&p) DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record") t.Errorf("Can't find a deleted record")
} }
} }
func TestCallbacksWithErrors(t *testing.T) { func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100} p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil { if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value") t.Errorf("An error from before create callbacks happened when create with invalid value")
} }
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors") t.Errorf("Should not save record that have errors")
} }
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value") t.Errorf("An error from after create callbacks happened when create with invalid value")
} }
p2 := Product{Code: "update_callback", Price: 100} p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2) DB.Save(&p2)
p2.Code = "dont_update" p2.Code = "dont_update"
if DB.Save(&p2).Error == nil { if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value") t.Errorf("An error from before update callbacks happened when update with invalid value")
} }
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback") t.Errorf("Record Should not be updated due to errors happened in before update callback")
} }
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback") t.Errorf("Record Should not be updated due to errors happened in before update callback")
} }
p2.Code = "dont_save" p2.Code = "dont_save"
if DB.Save(&p2).Error == nil { if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value") t.Errorf("An error from before save callbacks happened when update with invalid value")
} }
p3 := Product{Code: "dont_delete", Price: 100} p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3) DB.Save(&p3)
if DB.Delete(&p3).Error == nil { if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete") t.Errorf("An error from before delete callbacks happened when delete")
} }
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened") t.Errorf("An error from before delete callbacks happened")
} }
p4 := Product{Code: "after_save_error", Price: 100} p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4) DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback") t.Errorf("Record should be reverted if get an error in after save callback")
} }
p5 := Product{Code: "after_delete_error", Price: 100} p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5) DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found") t.Errorf("Record should be found")
} }
DB.Delete(&p5) DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
} }
} }
func TestGetCallback(t *testing.T) { func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil) scope := DB.NewScope(nil)
if DB.Callback().Create().Get("gorm:test_callback") != nil { if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil") t.Errorf("`gorm:test_callback` should be nil")
} }
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) }) DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback") callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
} }
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) }) DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback") callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
} }
DB.Callback().Create().Remove("gorm:test_callback") DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil { if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil") t.Errorf("`gorm:test_callback` should be nil")
} }
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) }) DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback") callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil { if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil") t.Errorf("`gorm:test_callback` should be non-nil")
} }
callback(scope) callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 { if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
} }
} }
func TestUseDefaultCallback(t *testing.T) { func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create" createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) { gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop // nop
}) })
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil { if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName) t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
} }
gorm.DefaultCallback.Create().Remove(createCallbackName) gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil { if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName) t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
} }
updateCallbackName := "gorm:test_use_default_callback_for_update" updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value" scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) { gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1) scope.Set(scopeValueName, 1)
}) })
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) { gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2) scope.Set(scopeValueName, 2)
}) })
scope := DB.NewScope(nil) scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName) callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope) callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 { if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok) t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
} }
} }

View File

@ -1,357 +1,357 @@
package gorm_test package gorm_test
import ( import (
"testing" "testing"
"time" "time"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
type CustomizeColumn struct { type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"` ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"` Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"` Date *time.Time `gorm:"column:mapped_time"`
} }
// Make sure an ignored field does not interfere with another field's custom // Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field. // column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct { type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"` Body string `sql:"-"`
RawBody string `gorm:"column:body"` RawBody string `gorm:"column:body"`
} }
func TestCustomizeColumn(t *testing.T) { func TestCustomizeColumn(t *testing.T) {
col := "mapped_name" col := "mapped_name"
DB.DropTable(&CustomizeColumn{}) DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) { if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col) t.Errorf("CustomizeColumn should have column %s", col)
} }
col = "mapped_id" col = "mapped_id"
if scope.PrimaryKey() != col { if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
} }
expected := "foo" expected := "foo"
now := time.Now() now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
if count := DB.Create(&cc).RowsAffected; count != 1 { if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record") t.Error("There should be one record be affected when create record")
} }
var cc1 CustomizeColumn var cc1 CustomizeColumn
DB.First(&cc1, 666) DB.First(&cc1, 666)
if cc1.Name != expected { if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn") t.Errorf("Failed to query CustomizeColumn")
} }
cc.Name = "bar" cc.Name = "bar"
DB.Save(&cc) DB.Save(&cc)
var cc2 CustomizeColumn var cc2 CustomizeColumn
DB.First(&cc2, 666) DB.First(&cc2, 666)
if cc2.Name != "bar" { if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn") t.Errorf("Failed to query CustomizeColumn")
} }
} }
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err) t.Errorf("Should not raise error: %s", err)
} }
} }
type CustomizePerson struct { type CustomizePerson struct {
IdPerson string `gorm:"column:idPerson;primary_key:true"` IdPerson string `gorm:"column:idPerson;primary_key:true"`
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
} }
type CustomizeAccount struct { type CustomizeAccount struct {
IdAccount string `gorm:"column:idAccount;primary_key:true"` IdAccount string `gorm:"column:idAccount;primary_key:true"`
Name string Name string
} }
func TestManyToManyWithCustomizedColumn(t *testing.T) { func TestManyToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
account := CustomizeAccount{IdAccount: "account", Name: "id1"} account := CustomizeAccount{IdAccount: "account", Name: "id1"}
person := CustomizePerson{ person := CustomizePerson{
IdPerson: "person", IdPerson: "person",
Accounts: []CustomizeAccount{account}, Accounts: []CustomizeAccount{account},
} }
if err := DB.Create(&account).Error; err != nil { if err := DB.Create(&account).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if err := DB.Create(&person).Error; err != nil { if err := DB.Create(&person).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
var person1 CustomizePerson var person1 CustomizePerson
scope := DB.NewScope(nil) scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
} }
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
t.Errorf("should preload correct accounts") t.Errorf("should preload correct accounts")
} }
} }
type CustomizeUser struct { type CustomizeUser struct {
gorm.Model gorm.Model
Email string `sql:"column:email_address"` Email string `sql:"column:email_address"`
} }
type CustomizeInvitation struct { type CustomizeInvitation struct {
gorm.Model gorm.Model
Address string `sql:"column:invitation"` Address string `sql:"column:invitation"`
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
} }
func TestOneToOneWithCustomizedColumn(t *testing.T) { func TestOneToOneWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
user := CustomizeUser{ user := CustomizeUser{
Email: "hello@example.com", Email: "hello@example.com",
} }
invitation := CustomizeInvitation{ invitation := CustomizeInvitation{
Address: "hello@example.com", Address: "hello@example.com",
} }
DB.Create(&user) DB.Create(&user)
DB.Create(&invitation) DB.Create(&invitation)
var invitation2 CustomizeInvitation var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if invitation2.Person.Email != user.Email { if invitation2.Person.Email != user.Email {
t.Errorf("Should preload one to one relation with customize foreign keys") t.Errorf("Should preload one to one relation with customize foreign keys")
} }
} }
type PromotionDiscount struct { type PromotionDiscount struct {
gorm.Model gorm.Model
Name string Name string
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
Rule *PromotionRule `gorm:"ForeignKey:discount_id"` Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
} }
type PromotionBenefit struct { type PromotionBenefit struct {
gorm.Model gorm.Model
Name string Name string
PromotionID uint PromotionID uint
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
} }
type PromotionCoupon struct { type PromotionCoupon struct {
gorm.Model gorm.Model
Code string Code string
DiscountID uint DiscountID uint
Discount PromotionDiscount Discount PromotionDiscount
} }
type PromotionRule struct { type PromotionRule struct {
gorm.Model gorm.Model
Name string Name string
Begin *time.Time Begin *time.Time
End *time.Time End *time.Time
DiscountID uint DiscountID uint
Discount *PromotionDiscount Discount *PromotionDiscount
} }
func TestOneToManyWithCustomizedColumn(t *testing.T) { func TestOneToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
discount := PromotionDiscount{ discount := PromotionDiscount{
Name: "Happy New Year", Name: "Happy New Year",
Coupons: []*PromotionCoupon{ Coupons: []*PromotionCoupon{
{Code: "newyear1"}, {Code: "newyear1"},
{Code: "newyear2"}, {Code: "newyear2"},
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if len(discount.Coupons) != 2 { if len(discount.Coupons) != 2 {
t.Errorf("should find two coupons") t.Errorf("should find two coupons")
} }
var coupon PromotionCoupon var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if coupon.Discount.Name != "Happy New Year" { if coupon.Discount.Name != "Happy New Year" {
t.Errorf("should preload discount from coupon") t.Errorf("should preload discount from coupon")
} }
} }
func TestHasOneWithPartialCustomizedColumn(t *testing.T) { func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
var begin = time.Now() var begin = time.Now()
var end = time.Now().Add(24 * time.Hour) var end = time.Now().Add(24 * time.Hour)
discount := PromotionDiscount{ discount := PromotionDiscount{
Name: "Happy New Year 2", Name: "Happy New Year 2",
Rule: &PromotionRule{ Rule: &PromotionRule{
Name: "time_limited", Name: "time_limited",
Begin: &begin, Begin: &begin,
End: &end, End: &end,
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
t.Errorf("Should be able to preload Rule") t.Errorf("Should be able to preload Rule")
} }
var rule PromotionRule var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if rule.Discount.Name != "Happy New Year 2" { if rule.Discount.Name != "Happy New Year 2" {
t.Errorf("should preload discount from rule") t.Errorf("should preload discount from rule")
} }
} }
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
discount := PromotionDiscount{ discount := PromotionDiscount{
Name: "Happy New Year 3", Name: "Happy New Year 3",
Benefits: []PromotionBenefit{ Benefits: []PromotionBenefit{
{Name: "free cod"}, {Name: "free cod"},
{Name: "free shipping"}, {Name: "free shipping"},
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if len(discount.Benefits) != 2 { if len(discount.Benefits) != 2 {
t.Errorf("should find two benefits") t.Errorf("should find two benefits")
} }
var benefit PromotionBenefit var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
if benefit.Discount.Name != "Happy New Year 3" { if benefit.Discount.Name != "Happy New Year 3" {
t.Errorf("should preload discount from coupon") t.Errorf("should preload discount from coupon")
} }
} }
type SelfReferencingUser struct { type SelfReferencingUser struct {
gorm.Model gorm.Model
Name string Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
} }
func TestSelfReferencingMany2ManyColumn(t *testing.T) { func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends") DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{}) DB.AutoMigrate(&SelfReferencingUser{})
if !DB.HasTable("UserFriends") { if !DB.HasTable("UserFriends") {
t.Errorf("auto migrate error, table UserFriends should be created") t.Errorf("auto migrate error, table UserFriends should be created")
} }
friend1 := SelfReferencingUser{Name: "friend1_m2m"} friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend1).Error; err != nil { if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
friend2 := SelfReferencingUser{Name: "friend2_m2m"} friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil { if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
user := SelfReferencingUser{ user := SelfReferencingUser{
Name: "self_m2m", Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2}, Friends: []*SelfReferencingUser{&friend1, &friend2},
} }
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if DB.Model(&user).Association("Friends").Count() != 2 { if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly") t.Errorf("Should find created friends correctly")
} }
var count int var count int
if err := DB.Table("UserFriends").Count(&count).Error; err != nil { if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if count == 0 { if count == 0 {
t.Errorf("table UserFriends should have records") t.Errorf("table UserFriends should have records")
} }
var newUser = SelfReferencingUser{} var newUser = SelfReferencingUser{}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if len(newUser.Friends) != 2 { if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m") t.Errorf("Should preload created frineds for self reference m2m")
} }
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
if DB.Model(&user).Association("Friends").Count() != 3 { if DB.Model(&user).Association("Friends").Count() != 3 {
t.Errorf("Should find created friends correctly") t.Errorf("Should find created friends correctly")
} }
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
if DB.Model(&user).Association("Friends").Count() != 1 { if DB.Model(&user).Association("Friends").Count() != 1 {
t.Errorf("Should find created friends correctly") t.Errorf("Should find created friends correctly")
} }
friend := SelfReferencingUser{} friend := SelfReferencingUser{}
DB.Model(&newUser).Association("Friends").Find(&friend) DB.Model(&newUser).Association("Friends").Find(&friend)
if friend.Name != "friend4_m2m" { if friend.Name != "friend4_m2m" {
t.Errorf("Should find created friends correctly") t.Errorf("Should find created friends correctly")
} }
DB.Model(&newUser).Association("Friends").Delete(friend) DB.Model(&newUser).Association("Friends").Delete(friend)
if DB.Model(&user).Association("Friends").Count() != 0 { if DB.Model(&user).Association("Friends").Count() != 0 {
t.Errorf("All friends should be deleted") t.Errorf("All friends should be deleted")
} }
} }

View File

@ -25,28 +25,28 @@ type Dialect interface {
DataTypeOf(field *StructField) string DataTypeOf(field *StructField) string
// HasIndex check has index or not // HasIndex check has index or not
HasIndex(tableName string, indexName string) bool HasIndex(tableName string, indexName string) bool
// HasIndexContext same as HasIndex // HasIndexContext same as HasIndex
HasIndexContext(ctx context.Context, tableName string, indexName string) bool HasIndexContext(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not // HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool HasForeignKey(tableName string, foreignKeyName string) bool
// HasForeignKeyContext same as HasForeignKey // HasForeignKeyContext same as HasForeignKey
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index // RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error RemoveIndex(tableName string, indexName string) error
// RemoveIndexContext same as RemoveIndex // RemoveIndexContext same as RemoveIndex
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not // HasTable check has table or not
HasTable(tableName string) bool HasTable(tableName string) bool
// HasTableContext same as HasTable // HasTableContext same as HasTable
HasTableContext(ctx context.Context, tableName string) bool HasTableContext(ctx context.Context, tableName string) bool
// HasColumn check has column or not // HasColumn check has column or not
HasColumn(tableName string, columnName string) bool HasColumn(tableName string, columnName string) bool
// HasColumnContext same as HasColumn // HasColumnContext same as HasColumn
HasColumnContext(ctx context.Context, tableName string, columnName string) bool HasColumnContext(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type // ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error ModifyColumn(tableName string, columnName string, typ string) error
// ModifyColumnContext same as ModifyColumn // ModifyColumnContext same as ModifyColumn
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error) LimitAndOffsetSQL(limit, offset interface{}) (string, error)
@ -55,12 +55,12 @@ type Dialect interface {
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial // LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(tableName, columnName string) string
// LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix // LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
// DefaultValueStr // DefaultValueStr
DefaultValueStr() string DefaultValueStr() string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
@ -71,8 +71,8 @@ type Dialect interface {
// CurrentDatabase return current database name // CurrentDatabase return current database name
CurrentDatabase() string CurrentDatabase() string
// CurrentDatabaseContext same as CurrentDatabase // CurrentDatabaseContext same as CurrentDatabase
CurrentDatabaseContext(ctx context.Context) string CurrentDatabaseContext(ctx context.Context) string
} }
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}

View File

@ -101,8 +101,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
} }
func (s commonDialect) HasIndex(tableName string, indexName string) bool { func (s commonDialect) HasIndex(tableName string, indexName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName) return s.HasIndexContext(ctx, tableName, indexName)
} }
func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool { func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
@ -113,8 +113,8 @@ func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, in
} }
func (s commonDialect) RemoveIndex(tableName string, indexName string) error { func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
ctx := context.Background() ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName) return s.RemoveIndexContext(ctx, tableName, indexName)
} }
func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error { func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
@ -123,8 +123,8 @@ func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string,
} }
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
} }
func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool { func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool {
@ -132,8 +132,8 @@ func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName stri
} }
func (s commonDialect) HasTable(tableName string) bool { func (s commonDialect) HasTable(tableName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasTableContext(ctx, tableName) return s.HasTableContext(ctx, tableName)
} }
func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool { func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool {
@ -144,8 +144,8 @@ func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bo
} }
func (s commonDialect) HasColumn(tableName string, columnName string) bool { func (s commonDialect) HasColumn(tableName string, columnName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName) return s.HasColumnContext(ctx, tableName, columnName)
} }
func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool { func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
@ -156,8 +156,8 @@ func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, c
} }
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
ctx := context.Background() ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ) return s.ModifyColumnContext(ctx, tableName, columnName, typ)
} }
func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error { func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
@ -166,8 +166,8 @@ func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string
} }
func (s commonDialect) CurrentDatabase() (name string) { func (s commonDialect) CurrentDatabase() (name string) {
ctx := context.Background() ctx := context.Background()
return s.CurrentDatabaseContext(ctx) return s.CurrentDatabaseContext(ctx)
} }
func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) { func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) {
@ -199,8 +199,8 @@ func (commonDialect) SelectFromDummyTable() string {
} }
func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
ctx := context.Background() ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
} }
func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string { func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
@ -208,8 +208,8 @@ func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context,
} }
func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
ctx := context.Background() ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
} }
func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string { func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {

View File

@ -124,8 +124,8 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
} }
func (s mssql) HasIndex(tableName string, indexName string) bool { func (s mssql) HasIndex(tableName string, indexName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName) return s.HasIndexContext(ctx, tableName, indexName)
} }
func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool { func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
@ -135,8 +135,8 @@ func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName
} }
func (s mssql) RemoveIndex(tableName string, indexName string) error { func (s mssql) RemoveIndex(tableName string, indexName string) error {
ctx := context.Background() ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName) return s.RemoveIndexContext(ctx, tableName, indexName)
} }
func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error { func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
@ -145,8 +145,8 @@ func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexNa
} }
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName) return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
} }
func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool { func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
@ -161,8 +161,8 @@ func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, forei
} }
func (s mssql) HasTable(tableName string) bool { func (s mssql) HasTable(tableName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasTableContext(ctx, tableName) return s.HasTableContext(ctx, tableName)
} }
func (s mssql) HasTableContext(ctx context.Context, tableName string) bool { func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
@ -173,8 +173,8 @@ func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
} }
func (s mssql) HasColumn(tableName string, columnName string) bool { func (s mssql) HasColumn(tableName string, columnName string) bool {
ctx := context.Background() ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName) return s.HasColumnContext(ctx, tableName, columnName)
} }
func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool { func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
@ -185,8 +185,8 @@ func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnNam
} }
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
ctx := context.Background() ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ) return s.ModifyColumnContext(ctx, tableName, columnName, typ)
} }
func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error { func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
@ -195,9 +195,9 @@ func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, column
} }
func (s mssql) CurrentDatabase() (name string) { func (s mssql) CurrentDatabase() (name string) {
ctx := context.Background() ctx := context.Background()
s.CurrentDatabaseContext(ctx) s.CurrentDatabaseContext(ctx)
return return
} }
func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) { func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) {
@ -236,8 +236,8 @@ func (mssql) SelectFromDummyTable() string {
} }
func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
ctx := context.Background() ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns) return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
} }
func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string { func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
@ -249,8 +249,8 @@ func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableNa
} }
func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
ctx := context.Background() ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName) return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
} }
func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string { func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {

View File

@ -3,89 +3,89 @@ package gorm_test
import "testing" import "testing"
type BasePost struct { type BasePost struct {
Id int64 Id int64
Title string Title string
URL string URL string
} }
type Author struct { type Author struct {
ID string ID string
Name string Name string
Email string Email string
} }
type HNPost struct { type HNPost struct {
BasePost BasePost
Author `gorm:"embedded_prefix:user_"` // Embedded struct Author `gorm:"embedded_prefix:user_"` // Embedded struct
Upvotes int32 Upvotes int32
} }
type EngadgetPost struct { type EngadgetPost struct {
BasePost BasePost `gorm:"embedded"` BasePost BasePost `gorm:"embedded"`
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
ImageUrl string ImageUrl string
} }
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
dialect := DB.NewScope(&EngadgetPost{}).Dialect() dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{}) engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns") t.Errorf("should has prefix for embedded columns")
} }
if len(engadgetPostScope.PrimaryFields()) != 1 { if len(engadgetPostScope.PrimaryFields()) != 1 {
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
} }
hnScope := DB.NewScope(&HNPost{}) hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns") t.Errorf("should has prefix for embedded columns")
} }
} }
func TestSaveAndQueryEmbeddedStruct(t *testing.T) { func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" { } else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" { } else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
} }
if DB.NewScope(&HNPost{}).PrimaryField() == nil { if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works") t.Errorf("primary key with embedded struct should works")
} }
for _, field := range DB.NewScope(&HNPost{}).Fields() { for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" { if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct") t.Errorf("scope Fields should not contain embedded struct")
} }
} }
} }
func TestEmbeddedPointerTypeStruct(t *testing.T) { func TestEmbeddedPointerTypeStruct(t *testing.T) {
type HNPost struct { type HNPost struct {
*BasePost *BasePost
Upvotes int32 Upvotes int32
} }
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
var hnPost HNPost var hnPost HNPost
if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err) t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
} }
if hnPost.Title != "embedded_pointer_type" { if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type") t.Errorf("Should find correct value for embedded pointer type")
} }
} }

View File

@ -11,10 +11,10 @@ type SQLCommon interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
type sqlDb interface { type sqlDb interface {

View File

@ -675,7 +675,6 @@ func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{}) return s.BeginTx(context.Background(), &sql.TxOptions{})
} }
// BeginTx begins a transaction with options // BeginTx begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c := s.clone() c := s.clone()

View File

@ -1,579 +1,579 @@
package gorm_test package gorm_test
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
type User struct { type User struct {
Id int64 Id int64
Age int64 Age int64
UserNum Num UserNum Num
Name string `sql:"size:255"` Name string `sql:"size:255"`
Email string Email string
Birthday *time.Time // Time Birthday *time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs Emails []Email // Embedded structs
BillingAddress Address // Embedded struct BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard CreditCard CreditCard
Latitude float64 Latitude float64
Languages []Language `gorm:"many2many:user_languages;"` Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int CompanyID *int
Company Company Company Company
Role Role Role Role
Password EncryptedData Password EncryptedData
PasswordHash []byte PasswordHash []byte
IgnoreMe int64 `sql:"-"` IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"` IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"` Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"` IgnoredPointer *User `sql:"-"`
} }
type NotSoLongTableName struct { type NotSoLongTableName struct {
Id int64 Id int64
ReallyLongThingID int64 ReallyLongThingID int64
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
} }
type ReallyLongTableNameToTestMySQLNameLengthLimit struct { type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
Id int64 Id int64
} }
type ReallyLongThingThatReferencesShort struct { type ReallyLongThingThatReferencesShort struct {
Id int64 Id int64
ShortID int64 ShortID int64
Short Short Short Short
} }
type Short struct { type Short struct {
Id int64 Id int64
} }
type CreditCard struct { type CreditCard struct {
ID int8 ID int8
Number string Number string
UserId sql.NullInt64 UserId sql.NullInt64
CreatedAt time.Time `sql:"not null"` CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt *time.Time `sql:"column:deleted_time"` DeletedAt *time.Time `sql:"column:deleted_time"`
} }
type Email struct { type Email struct {
Id int16 Id int16
UserId int UserId int
Email string `sql:"type:varchar(100);"` Email string `sql:"type:varchar(100);"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
type Address struct { type Address struct {
ID int ID int
Address1 string Address1 string
Address2 string Address2 string
Post string Post string
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt *time.Time DeletedAt *time.Time
} }
type Language struct { type Language struct {
gorm.Model gorm.Model
Name string Name string
Users []User `gorm:"many2many:user_languages;"` Users []User `gorm:"many2many:user_languages;"`
} }
type Product struct { type Product struct {
Id int64 Id int64
Code string Code string
Price int64 Price int64
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
AfterFindCallTimes int64 AfterFindCallTimes int64
BeforeCreateCallTimes int64 BeforeCreateCallTimes int64
AfterCreateCallTimes int64 AfterCreateCallTimes int64
BeforeUpdateCallTimes int64 BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64 AfterUpdateCallTimes int64
BeforeSaveCallTimes int64 BeforeSaveCallTimes int64
AfterSaveCallTimes int64 AfterSaveCallTimes int64
BeforeDeleteCallTimes int64 BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64 AfterDeleteCallTimes int64
} }
type Company struct { type Company struct {
Id int64 Id int64
Name string Name string
Owner *User `sql:"-"` Owner *User `sql:"-"`
} }
type Place struct { type Place struct {
Id int64 Id int64
PlaceAddressID int PlaceAddressID int
PlaceAddress *Address `gorm:"save_associations:false"` PlaceAddress *Address `gorm:"save_associations:false"`
OwnerAddressID int OwnerAddressID int
OwnerAddress *Address `gorm:"save_associations:true"` OwnerAddress *Address `gorm:"save_associations:true"`
} }
type EncryptedData []byte type EncryptedData []byte
func (data *EncryptedData) Scan(value interface{}) error { func (data *EncryptedData) Scan(value interface{}) error {
if b, ok := value.([]byte); ok { if b, ok := value.([]byte); ok {
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
return errors.New("Too short") return errors.New("Too short")
} }
*data = b[3:] *data = b[3:]
return nil return nil
} }
return errors.New("Bytes expected") return errors.New("Bytes expected")
} }
func (data EncryptedData) Value() (driver.Value, error) { func (data EncryptedData) Value() (driver.Value, error) {
if len(data) > 0 && data[0] == 'x' { if len(data) > 0 && data[0] == 'x' {
//needed to test failures //needed to test failures
return nil, errors.New("Should not start with 'x'") return nil, errors.New("Should not start with 'x'")
} }
//prepend asterisks //prepend asterisks
return append([]byte("***"), data...), nil return append([]byte("***"), data...), nil
} }
type Role struct { type Role struct {
Name string `gorm:"size:256"` Name string `gorm:"size:256"`
} }
func (role *Role) Scan(value interface{}) error { func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok { if b, ok := value.([]uint8); ok {
role.Name = string(b) role.Name = string(b)
} else { } else {
role.Name = value.(string) role.Name = value.(string)
} }
return nil return nil
} }
func (role Role) Value() (driver.Value, error) { func (role Role) Value() (driver.Value, error) {
return role.Name, nil return role.Name, nil
} }
func (role Role) IsAdmin() bool { func (role Role) IsAdmin() bool {
return role.Name == "admin" return role.Name == "admin"
} }
type Num int64 type Num int64
func (i *Num) Scan(src interface{}) error { func (i *Num) Scan(src interface{}) error {
switch s := src.(type) { switch s := src.(type) {
case []byte: case []byte:
n, _ := strconv.Atoi(string(s)) n, _ := strconv.Atoi(string(s))
*i = Num(n) *i = Num(n)
case int64: case int64:
*i = Num(s) *i = Num(s)
default: default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
} }
return nil return nil
} }
type Animal struct { type Animal struct {
Counter uint64 `gorm:"primary_key:yes"` Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"` Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"` Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value unexported string // unexported value
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
type JoinTable struct { type JoinTable struct {
From uint64 From uint64
To uint64 To uint64
Time time.Time `sql:"default: null"` Time time.Time `sql:"default: null"`
} }
type Post struct { type Post struct {
Id int64 Id int64
CategoryId sql.NullInt64 CategoryId sql.NullInt64
MainCategoryId int64 MainCategoryId int64
Title string Title string
Body string Body string
Comments []*Comment Comments []*Comment
Category Category Category Category
MainCategory Category MainCategory Category
} }
type Category struct { type Category struct {
gorm.Model gorm.Model
Name string Name string
Categories []Category Categories []Category
CategoryID *uint CategoryID *uint
} }
type Comment struct { type Comment struct {
gorm.Model gorm.Model
PostId int64 PostId int64
Content string Content string
Post Post Post Post
} }
// Scanner // Scanner
type NullValue struct { type NullValue struct {
Id int64 Id int64
Name sql.NullString `sql:"not null"` Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"` Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64 Age sql.NullInt64
Male sql.NullBool Male sql.NullBool
Height sql.NullFloat64 Height sql.NullFloat64
AddedAt NullTime AddedAt NullTime
} }
type NullTime struct { type NullTime struct {
Time time.Time Time time.Time
Valid bool Valid bool
} }
func (nt *NullTime) Scan(value interface{}) error { func (nt *NullTime) Scan(value interface{}) error {
if value == nil { if value == nil {
nt.Valid = false nt.Valid = false
return nil return nil
} }
nt.Time, nt.Valid = value.(time.Time), true nt.Time, nt.Valid = value.(time.Time), true
return nil return nil
} }
func (nt NullTime) Value() (driver.Value, error) { func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid { if !nt.Valid {
return nil, nil return nil, nil
} }
return nt.Time, nil return nt.Time, nil
} }
func getPreparedUser(name string, role string) *User { func getPreparedUser(name string, role string) *User {
var company Company var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company) DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{ return &User{
Name: name, Name: name,
Age: 20, Age: 20,
Role: Role{role}, Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{ Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
}, },
Company: company, Company: company,
Languages: []Language{ Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)}, {Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)}, {Name: fmt.Sprintf("lang_2_%v", name)},
}, },
} }
} }
func runMigration() { func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil { if err := DB.DropTableIfExists(&User{}).Error; err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err) fmt.Printf("Got error when try to delete table users, %+v\n", err)
} }
for _, table := range []string{"animals", "user_languages"} { for _, table := range []string{"animals", "user_languages"} {
DB.Exec(fmt.Sprintf("drop table %v;", table)) DB.Exec(fmt.Sprintf("drop table %v;", table))
} }
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}} values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
for _, value := range values { for _, value := range values {
DB.DropTable(value) DB.DropTable(value)
} }
if err := DB.AutoMigrate(values...).Error; err != nil { if err := DB.AutoMigrate(values...).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
} }
func TestIndexes(t *testing.T) { func TestIndexes(t *testing.T) {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
scope := DB.NewScope(&Email{}) scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email") t.Errorf("Email should have index idx_email_email")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted") t.Errorf("Email's index idx_email_email should be deleted")
} }
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
t.Errorf("Should get to create duplicate record when having unique index") t.Errorf("Should get to create duplicate record when having unique index")
} }
var user = User{Name: "sample_user"} var user = User{Name: "sample_user"}
DB.Save(&user) DB.Save(&user)
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
t.Errorf("Should get no error when append two emails for user") t.Errorf("Should get no error when append two emails for user")
} }
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index") t.Errorf("Should be able to create duplicated emails after remove unique index")
} }
} }
type EmailWithIdx struct { type EmailWithIdx struct {
Id int64 Id int64
UserId int64 UserId int64
Email string `sql:"index:idx_email_agent"` Email string `sql:"index:idx_email_agent"`
UserAgent string `sql:"index:idx_email_agent"` UserAgent string `sql:"index:idx_email_agent"`
RegisteredAt *time.Time `sql:"unique_index"` RegisteredAt *time.Time `sql:"unique_index"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
func TestAutoMigration(t *testing.T) { func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{}) DB.AutoMigrate(&Address{})
DB.DropTable(&EmailWithIdx{}) DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error") t.Errorf("Auto Migrate should not raise any error")
} }
now := time.Now() now := time.Now()
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&EmailWithIdx{}) scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
var bigemail EmailWithIdx var bigemail EmailWithIdx
DB.First(&bigemail, "user_agent = ?", "pc") DB.First(&bigemail, "user_agent = ?", "pc")
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly") t.Error("Big Emails should be saved and fetched correctly")
} }
} }
func TestCreateAndAutomigrateTransaction(t *testing.T) { func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin() tx := DB.Begin()
func() { func() {
type Bar struct { type Bar struct {
ID uint ID uint
} }
DB.DropTableIfExists(&Bar{}) DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok { if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does") t.Errorf("Table should not exist, but does")
} }
if ok := tx.HasTable("bars"); ok { if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does") t.Errorf("Table should not exist, but does")
} }
}() }()
func() { func() {
type Bar struct { type Bar struct {
Name string Name string
} }
err := tx.CreateTable(&Bar{}).Error err := tx.CreateTable(&Bar{}).Error
if err != nil { if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err) t.Errorf("Should have been able to create the table, but couldn't: %s", err)
} }
if ok := tx.HasTable(&Bar{}); !ok { if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table") t.Errorf("The transaction should be able to see the table")
} }
}() }()
func() { func() {
type Bar struct { type Bar struct {
Stuff string Stuff string
} }
err := tx.AutoMigrate(&Bar{}).Error err := tx.AutoMigrate(&Bar{}).Error
if err != nil { if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't") t.Errorf("Should have been able to alter the table, but couldn't")
} }
}() }()
tx.Rollback() tx.Rollback()
} }
type MultipleIndexes struct { type MultipleIndexes struct {
ID int64 ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
Name string `sql:"unique_index:uix_multipleindexes_user_name"` Name string `sql:"unique_index:uix_multipleindexes_user_name"`
Email string `sql:"unique_index:,uix_multipleindexes_user_email"` Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
Other string `sql:"index:,idx_multipleindexes_user_other"` Other string `sql:"index:,idx_multipleindexes_user_other"`
} }
func TestMultipleIndexes(t *testing.T) { func TestMultipleIndexes(t *testing.T) {
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
} }
DB.AutoMigrate(&MultipleIndexes{}) DB.AutoMigrate(&MultipleIndexes{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error") t.Errorf("Auto Migrate should not raise any error")
} }
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
scope := DB.NewScope(&MultipleIndexes{}) scope := DB.NewScope(&MultipleIndexes{})
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index") t.Errorf("Failed to create index")
} }
var mutipleIndexes MultipleIndexes var mutipleIndexes MultipleIndexes
DB.First(&mutipleIndexes, "name = ?", "jinzhu") DB.First(&mutipleIndexes, "name = ?", "jinzhu")
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
t.Error("MutipleIndexes should be saved and fetched correctly") t.Error("MutipleIndexes should be saved and fetched correctly")
} }
// Check unique constraints // Check unique constraints
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed") t.Error("MultipleIndexes unique index failed")
} }
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed") t.Error("MultipleIndexes unique index failed")
} }
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed") t.Error("MultipleIndexes unique index failed")
} }
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed") t.Error("MultipleIndexes unique index failed")
} }
} }
func TestModifyColumnType(t *testing.T) { func TestModifyColumnType(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
} }
type ModifyColumnType struct { type ModifyColumnType struct {
gorm.Model gorm.Model
Name1 string `gorm:"length:100"` Name1 string `gorm:"length:100"`
Name2 string `gorm:"length:200"` Name2 string `gorm:"length:200"`
} }
DB.DropTable(&ModifyColumnType{}) DB.DropTable(&ModifyColumnType{})
DB.CreateTable(&ModifyColumnType{}) DB.CreateTable(&ModifyColumnType{})
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
t.Errorf("No error should happen when ModifyColumn, but got %v", err) t.Errorf("No error should happen when ModifyColumn, but got %v", err)
} }
} }
func TestIndexWithPrefixLength(t *testing.T) { func TestIndexWithPrefixLength(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length") t.Skip("Skipping this because only mysql support setting an index prefix length")
} }
type IndexWithPrefix struct { type IndexWithPrefix struct {
gorm.Model gorm.Model
Name string Name string
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
} }
type IndexesWithPrefix struct { type IndexesWithPrefix struct {
gorm.Model gorm.Model
Name string Name string
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
} }
type IndexesWithPrefixAndWithoutPrefix struct { type IndexesWithPrefixAndWithoutPrefix struct {
gorm.Model gorm.Model
Name string `gorm:"index:idx_index_with_prefixes_length"` Name string `gorm:"index:idx_index_with_prefixes_length"`
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"` Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
} }
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}} tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
for _, table := range tables { for _, table := range tables {
scope := DB.NewScope(table) scope := DB.NewScope(table)
tableName := scope.TableName() tableName := scope.TableName()
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) { t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
if err := DB.DropTableIfExists(table).Error; err != nil { if err := DB.DropTableIfExists(table).Error; err != nil {
t.Errorf("Failed to drop %s table: %v", tableName, err) t.Errorf("Failed to drop %s table: %v", tableName, err)
} }
if err := DB.CreateTable(table).Error; err != nil { if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err) t.Errorf("Failed to create %s table: %v", tableName, err)
} }
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") { if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName) t.Errorf("Failed to create %s table index:", tableName)
} }
}) })
} }
} }

File diff suppressed because it is too large Load Diff