This commit is contained in:
Daniel Gatis 2020-02-21 14:03:47 -03:00
parent 75d6dc912c
commit f424f8aa2e
12 changed files with 2455 additions and 2456 deletions

View File

@ -97,27 +97,27 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope)
}
callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope)
}
cp.RegisterContext(callbackName, callbackContext)
cp.RegisterContext(callbackName, callbackContext)
}
// RegisterContext same as Register
func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
cp.before = "gorm:row_query"
}
}
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
cp.before = "gorm:row_query"
}
}
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Remove a registered callback
@ -136,47 +136,47 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// scope.SetColumn("UpdatedAt", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope)
}
callbackContext := func(ctx context.Context, scope *Scope) {
callback(scope)
}
cp.ReplaceContext(callbackName, callbackContext)
cp.ReplaceContext(callbackName, callbackContext)
}
// ReplaceContext same as Replace
func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
c := cp.GetContext(callbackName)
c := cp.GetContext(callbackName)
callback = func(scope *Scope) {
ctx := context.Background()
c(ctx, scope)
}
return
callback = func(scope *Scope) {
ctx := context.Background()
c(ctx, scope)
}
return
}
// GetContext same as Get
func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
}
return
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
}
return
}
// getRIndex get right index from string slice

View File

@ -1,20 +1,20 @@
package gorm
import (
"context"
"reflect"
"runtime"
"strings"
"testing"
"context"
"reflect"
"runtime"
"strings"
"testing"
)
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1])
}
return reflect.DeepEqual(names, fnames)
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1])
}
return reflect.DeepEqual(names, fnames)
}
func create(s *Scope) {}
@ -24,90 +24,90 @@ func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger}
var callback = &Callback{logger: defaultLogger}
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
}
func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &Callback{logger: defaultLogger}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback1 = &Callback{logger: defaultLogger}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &Callback{logger: defaultLogger}
var callback2 = &Callback{logger: defaultLogger}
callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
}
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &Callback{logger: defaultLogger}
var callback1 = &Callback{logger: defaultLogger}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &Callback{logger: defaultLogger}
var callback2 = &Callback{logger: defaultLogger}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
}
func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger}
var callback = &Callback{logger: defaultLogger}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
}
func TestRemoveCallback(t *testing.T) {
var callback = &Callback{logger: defaultLogger}
var callback = &Callback{logger: defaultLogger}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
}

View File

@ -1,249 +1,249 @@
package gorm_test
import (
"errors"
"reflect"
"testing"
"errors"
"reflect"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm"
)
func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}
func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}
func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}
func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}
func (s *Product) AfterCreate(tx *gorm.DB) {
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}
func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil)
scope := DB.NewScope(nil)
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}
func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}

View File

@ -1,357 +1,357 @@
package gorm_test
import (
"testing"
"time"
"testing"
"time"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm"
)
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"`
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date *time.Time `gorm:"column:mapped_time"`
}
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
}
func TestCustomizeColumn(t *testing.T) {
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
col = "mapped_id"
if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
}
col = "mapped_id"
if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
}
expected := "foo"
now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
expected := "foo"
now := time.Now()
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, 666)
var cc1 CustomizeColumn
DB.First(&cc1, 666)
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err)
}
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err)
}
}
type CustomizePerson struct {
IdPerson string `gorm:"column:idPerson;primary_key:true"`
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
IdPerson string `gorm:"column:idPerson;primary_key:true"`
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
}
type CustomizeAccount struct {
IdAccount string `gorm:"column:idAccount;primary_key:true"`
Name string
IdAccount string `gorm:"column:idAccount;primary_key:true"`
Name string
}
func TestManyToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
person := CustomizePerson{
IdPerson: "person",
Accounts: []CustomizeAccount{account},
}
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
person := CustomizePerson{
IdPerson: "person",
Accounts: []CustomizeAccount{account},
}
if err := DB.Create(&account).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&account).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&person).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&person).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
var person1 CustomizePerson
scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
}
var person1 CustomizePerson
scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
}
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
t.Errorf("should preload correct accounts")
}
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
t.Errorf("should preload correct accounts")
}
}
type CustomizeUser struct {
gorm.Model
Email string `sql:"column:email_address"`
gorm.Model
Email string `sql:"column:email_address"`
}
type CustomizeInvitation struct {
gorm.Model
Address string `sql:"column:invitation"`
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
gorm.Model
Address string `sql:"column:invitation"`
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
}
func TestOneToOneWithCustomizedColumn(t *testing.T) {
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
user := CustomizeUser{
Email: "hello@example.com",
}
invitation := CustomizeInvitation{
Address: "hello@example.com",
}
user := CustomizeUser{
Email: "hello@example.com",
}
invitation := CustomizeInvitation{
Address: "hello@example.com",
}
DB.Create(&user)
DB.Create(&invitation)
DB.Create(&user)
DB.Create(&invitation)
var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if invitation2.Person.Email != user.Email {
t.Errorf("Should preload one to one relation with customize foreign keys")
}
if invitation2.Person.Email != user.Email {
t.Errorf("Should preload one to one relation with customize foreign keys")
}
}
type PromotionDiscount struct {
gorm.Model
Name string
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
gorm.Model
Name string
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
}
type PromotionBenefit struct {
gorm.Model
Name string
PromotionID uint
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
gorm.Model
Name string
PromotionID uint
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
}
type PromotionCoupon struct {
gorm.Model
Code string
DiscountID uint
Discount PromotionDiscount
gorm.Model
Code string
DiscountID uint
Discount PromotionDiscount
}
type PromotionRule struct {
gorm.Model
Name string
Begin *time.Time
End *time.Time
DiscountID uint
Discount *PromotionDiscount
gorm.Model
Name string
Begin *time.Time
End *time.Time
DiscountID uint
Discount *PromotionDiscount
}
func TestOneToManyWithCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
discount := PromotionDiscount{
Name: "Happy New Year",
Coupons: []*PromotionCoupon{
{Code: "newyear1"},
{Code: "newyear2"},
},
}
discount := PromotionDiscount{
Name: "Happy New Year",
Coupons: []*PromotionCoupon{
{Code: "newyear1"},
{Code: "newyear2"},
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if len(discount.Coupons) != 2 {
t.Errorf("should find two coupons")
}
if len(discount.Coupons) != 2 {
t.Errorf("should find two coupons")
}
var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if coupon.Discount.Name != "Happy New Year" {
t.Errorf("should preload discount from coupon")
}
if coupon.Discount.Name != "Happy New Year" {
t.Errorf("should preload discount from coupon")
}
}
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
var begin = time.Now()
var end = time.Now().Add(24 * time.Hour)
discount := PromotionDiscount{
Name: "Happy New Year 2",
Rule: &PromotionRule{
Name: "time_limited",
Begin: &begin,
End: &end,
},
}
var begin = time.Now()
var end = time.Now().Add(24 * time.Hour)
discount := PromotionDiscount{
Name: "Happy New Year 2",
Rule: &PromotionRule{
Name: "time_limited",
Begin: &begin,
End: &end,
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
t.Errorf("Should be able to preload Rule")
}
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
t.Errorf("Should be able to preload Rule")
}
var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if rule.Discount.Name != "Happy New Year 2" {
t.Errorf("should preload discount from rule")
}
if rule.Discount.Name != "Happy New Year 2" {
t.Errorf("should preload discount from rule")
}
}
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
discount := PromotionDiscount{
Name: "Happy New Year 3",
Benefits: []PromotionBenefit{
{Name: "free cod"},
{Name: "free shipping"},
},
}
discount := PromotionDiscount{
Name: "Happy New Year 3",
Benefits: []PromotionBenefit{
{Name: "free cod"},
{Name: "free shipping"},
},
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if err := DB.Create(&discount).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if len(discount.Benefits) != 2 {
t.Errorf("should find two benefits")
}
if len(discount.Benefits) != 2 {
t.Errorf("should find two benefits")
}
var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
t.Errorf("no error should happen but got %v", err)
}
if benefit.Discount.Name != "Happy New Year 3" {
t.Errorf("should preload discount from coupon")
}
if benefit.Discount.Name != "Happy New Year 3" {
t.Errorf("should preload discount from coupon")
}
}
type SelfReferencingUser struct {
gorm.Model
Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
gorm.Model
Name string
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
}
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{})
if !DB.HasTable("UserFriends") {
t.Errorf("auto migrate error, table UserFriends should be created")
}
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
DB.AutoMigrate(&SelfReferencingUser{})
if !DB.HasTable("UserFriends") {
t.Errorf("auto migrate error, table UserFriends should be created")
}
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
if err := DB.Create(&friend1).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
if err := DB.Create(&friend2).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
user := SelfReferencingUser{
Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2},
}
user := SelfReferencingUser{
Name: "self_m2m",
Friends: []*SelfReferencingUser{&friend1, &friend2},
}
if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&user).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly")
}
if DB.Model(&user).Association("Friends").Count() != 2 {
t.Errorf("Should find created friends correctly")
}
var count int
if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if count == 0 {
t.Errorf("table UserFriends should have records")
}
var count int
if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if count == 0 {
t.Errorf("table UserFriends should have records")
}
var newUser = SelfReferencingUser{}
var newUser = SelfReferencingUser{}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m")
}
if len(newUser.Friends) != 2 {
t.Errorf("Should preload created frineds for self reference m2m")
}
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
if DB.Model(&user).Association("Friends").Count() != 3 {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
if DB.Model(&user).Association("Friends").Count() != 3 {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
if DB.Model(&user).Association("Friends").Count() != 1 {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
if DB.Model(&user).Association("Friends").Count() != 1 {
t.Errorf("Should find created friends correctly")
}
friend := SelfReferencingUser{}
DB.Model(&newUser).Association("Friends").Find(&friend)
if friend.Name != "friend4_m2m" {
t.Errorf("Should find created friends correctly")
}
friend := SelfReferencingUser{}
DB.Model(&newUser).Association("Friends").Find(&friend)
if friend.Name != "friend4_m2m" {
t.Errorf("Should find created friends correctly")
}
DB.Model(&newUser).Association("Friends").Delete(friend)
if DB.Model(&user).Association("Friends").Count() != 0 {
t.Errorf("All friends should be deleted")
}
DB.Model(&newUser).Association("Friends").Delete(friend)
if DB.Model(&user).Association("Friends").Count() != 0 {
t.Errorf("All friends should be deleted")
}
}

View File

@ -25,28 +25,28 @@ type Dialect interface {
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
// HasIndexContext same as HasIndex
HasIndexContext(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
// HasIndexContext same as HasIndex
HasIndexContext(ctx context.Context, tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
// HasForeignKeyContext same as HasForeignKey
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index
// HasForeignKeyContext same as HasForeignKey
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
// RemoveIndexContext same as RemoveIndex
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
// HasTableContext same as HasTable
HasTableContext(ctx context.Context, tableName string) bool
// HasTableContext same as HasTable
HasTableContext(ctx context.Context, tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// HasColumnContext same as HasColumn
HasColumnContext(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type
HasColumnContext(ctx context.Context, tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
// ModifyColumnContext same as ModifyColumn
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
// ModifyColumnContext same as ModifyColumn
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
@ -55,12 +55,12 @@ type Dialect interface {
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
// DefaultValueStr
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
// DefaultValueStr
DefaultValueStr() string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
@ -71,8 +71,8 @@ type Dialect interface {
// CurrentDatabase return current database name
CurrentDatabase() string
// CurrentDatabaseContext same as CurrentDatabase
CurrentDatabaseContext(ctx context.Context) string
// CurrentDatabaseContext same as CurrentDatabase
CurrentDatabaseContext(ctx context.Context) string
}
var dialectsMap = map[string]Dialect{}

View File

@ -101,8 +101,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName)
ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName)
}
func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
@ -113,8 +113,8 @@ func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, in
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName)
ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName)
}
func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
@ -123,8 +123,8 @@ func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string,
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
}
func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool {
@ -132,8 +132,8 @@ func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName stri
}
func (s commonDialect) HasTable(tableName string) bool {
ctx := context.Background()
return s.HasTableContext(ctx, tableName)
ctx := context.Background()
return s.HasTableContext(ctx, tableName)
}
func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool {
@ -144,8 +144,8 @@ func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bo
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName)
ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName)
}
func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
@ -156,8 +156,8 @@ func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, c
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
}
func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
@ -166,8 +166,8 @@ func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string
}
func (s commonDialect) CurrentDatabase() (name string) {
ctx := context.Background()
return s.CurrentDatabaseContext(ctx)
ctx := context.Background()
return s.CurrentDatabaseContext(ctx)
}
func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) {
@ -199,8 +199,8 @@ func (commonDialect) SelectFromDummyTable() string {
}
func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
}
func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
@ -208,8 +208,8 @@ func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context,
}
func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
}
func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {

View File

@ -124,8 +124,8 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName)
ctx := context.Background()
return s.HasIndexContext(ctx, tableName, indexName)
}
func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
@ -135,8 +135,8 @@ func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName)
ctx := context.Background()
return s.RemoveIndexContext(ctx, tableName, indexName)
}
func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
@ -145,8 +145,8 @@ func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexNa
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
ctx := context.Background()
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
}
func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
@ -161,8 +161,8 @@ func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, forei
}
func (s mssql) HasTable(tableName string) bool {
ctx := context.Background()
return s.HasTableContext(ctx, tableName)
ctx := context.Background()
return s.HasTableContext(ctx, tableName)
}
func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
@ -173,8 +173,8 @@ func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName)
ctx := context.Background()
return s.HasColumnContext(ctx, tableName, columnName)
}
func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
@ -185,8 +185,8 @@ func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnNam
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
ctx := context.Background()
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
}
func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
@ -195,9 +195,9 @@ func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, column
}
func (s mssql) CurrentDatabase() (name string) {
ctx := context.Background()
s.CurrentDatabaseContext(ctx)
return
ctx := context.Background()
s.CurrentDatabaseContext(ctx)
return
}
func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) {
@ -236,8 +236,8 @@ func (mssql) SelectFromDummyTable() string {
}
func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
ctx := context.Background()
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
}
func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
@ -249,8 +249,8 @@ func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableNa
}
func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
ctx := context.Background()
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
}
func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {

View File

@ -3,89 +3,89 @@ package gorm_test
import "testing"
type BasePost struct {
Id int64
Title string
URL string
Id int64
Title string
URL string
}
type Author struct {
ID string
Name string
Email string
ID string
Name string
Email string
}
type HNPost struct {
BasePost
Author `gorm:"embedded_prefix:user_"` // Embedded struct
Upvotes int32
BasePost
Author `gorm:"embedded_prefix:user_"` // Embedded struct
Upvotes int32
}
type EngadgetPost struct {
BasePost BasePost `gorm:"embedded"`
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
ImageUrl string
BasePost BasePost `gorm:"embedded"`
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
ImageUrl string
}
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns")
}
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
engadgetPostScope := DB.NewScope(&EngadgetPost{})
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
t.Errorf("should has prefix for embedded columns")
}
if len(engadgetPostScope.PrimaryFields()) != 1 {
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
}
if len(engadgetPostScope.PrimaryFields()) != 1 {
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
}
hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns")
}
hnScope := DB.NewScope(&HNPost{})
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
t.Errorf("should has prefix for embedded columns")
}
}
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works")
}
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works")
}
for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct")
}
}
for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct")
}
}
}
func TestEmbeddedPointerTypeStruct(t *testing.T) {
type HNPost struct {
*BasePost
Upvotes int32
}
type HNPost struct {
*BasePost
Upvotes int32
}
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
var hnPost HNPost
if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
}
var hnPost HNPost
if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
}
if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type")
}
if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type")
}
}

View File

@ -11,10 +11,10 @@ type SQLCommon interface {
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type sqlDb interface {

View File

@ -675,7 +675,6 @@ func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{})
}
// BeginTx begins a transaction with options
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
c := s.clone()

View File

@ -1,579 +1,579 @@
package gorm_test
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"testing"
"time"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"testing"
"time"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm"
)
type User struct {
Id int64
Age int64
UserNum Num
Name string `sql:"size:255"`
Email string
Birthday *time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
Latitude float64
Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int
Company Company
Role Role
Password EncryptedData
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
Id int64
Age int64
UserNum Num
Name string `sql:"size:255"`
Email string
Birthday *time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
Latitude float64
Languages []Language `gorm:"many2many:user_languages;"`
CompanyID *int
Company Company
Role Role
Password EncryptedData
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
}
type NotSoLongTableName struct {
Id int64
ReallyLongThingID int64
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
Id int64
ReallyLongThingID int64
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
}
type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
Id int64
Id int64
}
type ReallyLongThingThatReferencesShort struct {
Id int64
ShortID int64
Short Short
Id int64
ShortID int64
Short Short
}
type Short struct {
Id int64
Id int64
}
type CreditCard struct {
ID int8
Number string
UserId sql.NullInt64
CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time
DeletedAt *time.Time `sql:"column:deleted_time"`
ID int8
Number string
UserId sql.NullInt64
CreatedAt time.Time `sql:"not null"`
UpdatedAt time.Time
DeletedAt *time.Time `sql:"column:deleted_time"`
}
type Email struct {
Id int16
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time
UpdatedAt time.Time
Id int16
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Address struct {
ID int
Address1 string
Address2 string
Post string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
ID int
Address1 string
Address2 string
Post string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type Language struct {
gorm.Model
Name string
Users []User `gorm:"many2many:user_languages;"`
gorm.Model
Name string
Users []User `gorm:"many2many:user_languages;"`
}
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
}
type Company struct {
Id int64
Name string
Owner *User `sql:"-"`
Id int64
Name string
Owner *User `sql:"-"`
}
type Place struct {
Id int64
PlaceAddressID int
PlaceAddress *Address `gorm:"save_associations:false"`
OwnerAddressID int
OwnerAddress *Address `gorm:"save_associations:true"`
Id int64
PlaceAddressID int
PlaceAddress *Address `gorm:"save_associations:false"`
OwnerAddressID int
OwnerAddress *Address `gorm:"save_associations:true"`
}
type EncryptedData []byte
func (data *EncryptedData) Scan(value interface{}) error {
if b, ok := value.([]byte); ok {
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
return errors.New("Too short")
}
if b, ok := value.([]byte); ok {
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
return errors.New("Too short")
}
*data = b[3:]
return nil
}
*data = b[3:]
return nil
}
return errors.New("Bytes expected")
return errors.New("Bytes expected")
}
func (data EncryptedData) Value() (driver.Value, error) {
if len(data) > 0 && data[0] == 'x' {
//needed to test failures
return nil, errors.New("Should not start with 'x'")
}
if len(data) > 0 && data[0] == 'x' {
//needed to test failures
return nil, errors.New("Should not start with 'x'")
}
//prepend asterisks
return append([]byte("***"), data...), nil
//prepend asterisks
return append([]byte("***"), data...), nil
}
type Role struct {
Name string `gorm:"size:256"`
Name string `gorm:"size:256"`
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
return role.Name == "admin"
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
n, _ := strconv.Atoi(string(s))
*i = Num(n)
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
switch s := src.(type) {
case []byte:
n, _ := strconv.Atoi(string(s))
*i = Num(n)
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type Animal struct {
Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value
CreatedAt time.Time
UpdatedAt time.Time
Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value
CreatedAt time.Time
UpdatedAt time.Time
}
type JoinTable struct {
From uint64
To uint64
Time time.Time `sql:"default: null"`
From uint64
To uint64
Time time.Time `sql:"default: null"`
}
type Post struct {
Id int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
Comments []*Comment
Category Category
MainCategory Category
Id int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
Comments []*Comment
Category Category
MainCategory Category
}
type Category struct {
gorm.Model
Name string
gorm.Model
Name string
Categories []Category
CategoryID *uint
Categories []Category
CategoryID *uint
}
type Comment struct {
gorm.Model
PostId int64
Content string
Post Post
gorm.Model
PostId int64
Content string
Post Post
}
// Scanner
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
Id int64
Name sql.NullString `sql:"not null"`
Gender *sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
type NullTime struct {
Time time.Time
Valid bool
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}
func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err)
}
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err)
}
for _, table := range []string{"animals", "user_languages"} {
DB.Exec(fmt.Sprintf("drop table %v;", table))
}
for _, table := range []string{"animals", "user_languages"} {
DB.Exec(fmt.Sprintf("drop table %v;", table))
}
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
for _, value := range values {
DB.DropTable(value)
}
if err := DB.AutoMigrate(values...).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
for _, value := range values {
DB.DropTable(value)
}
if err := DB.AutoMigrate(values...).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
}
func TestIndexes(t *testing.T) {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
t.Errorf("Should get to create duplicate record when having unique index")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
t.Errorf("Should get to create duplicate record when having unique index")
}
var user = User{Name: "sample_user"}
DB.Save(&user)
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
t.Errorf("Should get no error when append two emails for user")
}
var user = User{Name: "sample_user"}
DB.Save(&user)
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
t.Errorf("Should get no error when append two emails for user")
}
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
}
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index")
}
}
type EmailWithIdx struct {
Id int64
UserId int64
Email string `sql:"index:idx_email_agent"`
UserAgent string `sql:"index:idx_email_agent"`
RegisteredAt *time.Time `sql:"unique_index"`
CreatedAt time.Time
UpdatedAt time.Time
Id int64
UserId int64
Email string `sql:"index:idx_email_agent"`
UserAgent string `sql:"index:idx_email_agent"`
RegisteredAt *time.Time `sql:"unique_index"`
CreatedAt time.Time
UpdatedAt time.Time
}
func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{})
DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error")
}
DB.AutoMigrate(&Address{})
DB.DropTable(&EmailWithIdx{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error")
}
now := time.Now()
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
now := time.Now()
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
scope := DB.NewScope(&EmailWithIdx{})
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
t.Errorf("Failed to create index")
}
var bigemail EmailWithIdx
DB.First(&bigemail, "user_agent = ?", "pc")
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly")
}
var bigemail EmailWithIdx
DB.First(&bigemail, "user_agent = ?", "pc")
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly")
}
}
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
tx.Rollback()
}
type MultipleIndexes struct {
ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
Name string `sql:"unique_index:uix_multipleindexes_user_name"`
Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
Other string `sql:"index:,idx_multipleindexes_user_other"`
ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
Name string `sql:"unique_index:uix_multipleindexes_user_name"`
Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
Other string `sql:"index:,idx_multipleindexes_user_other"`
}
func TestMultipleIndexes(t *testing.T) {
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
}
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
}
DB.AutoMigrate(&MultipleIndexes{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error")
}
DB.AutoMigrate(&MultipleIndexes{})
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error")
}
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
scope := DB.NewScope(&MultipleIndexes{})
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index")
}
scope := DB.NewScope(&MultipleIndexes{})
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
t.Errorf("Failed to create index")
}
var mutipleIndexes MultipleIndexes
DB.First(&mutipleIndexes, "name = ?", "jinzhu")
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
t.Error("MutipleIndexes should be saved and fetched correctly")
}
var mutipleIndexes MultipleIndexes
DB.First(&mutipleIndexes, "name = ?", "jinzhu")
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
t.Error("MutipleIndexes should be saved and fetched correctly")
}
// Check unique constraints
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed")
}
// Check unique constraints
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed")
}
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
t.Error("MultipleIndexes unique index failed")
}
}
func TestModifyColumnType(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
}
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
}
type ModifyColumnType struct {
gorm.Model
Name1 string `gorm:"length:100"`
Name2 string `gorm:"length:200"`
}
DB.DropTable(&ModifyColumnType{})
DB.CreateTable(&ModifyColumnType{})
type ModifyColumnType struct {
gorm.Model
Name1 string `gorm:"length:100"`
Name2 string `gorm:"length:200"`
}
DB.DropTable(&ModifyColumnType{})
DB.CreateTable(&ModifyColumnType{})
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
}
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
}
}
func TestIndexWithPrefixLength(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length")
}
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
t.Skip("Skipping this because only mysql support setting an index prefix length")
}
type IndexWithPrefix struct {
gorm.Model
Name string
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefix struct {
gorm.Model
Name string
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefixAndWithoutPrefix struct {
gorm.Model
Name string `gorm:"index:idx_index_with_prefixes_length"`
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
for _, table := range tables {
scope := DB.NewScope(table)
tableName := scope.TableName()
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
if err := DB.DropTableIfExists(table).Error; err != nil {
t.Errorf("Failed to drop %s table: %v", tableName, err)
}
if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err)
}
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName)
}
})
}
type IndexWithPrefix struct {
gorm.Model
Name string
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefix struct {
gorm.Model
Name string
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
type IndexesWithPrefixAndWithoutPrefix struct {
gorm.Model
Name string `gorm:"index:idx_index_with_prefixes_length"`
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
}
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
for _, table := range tables {
scope := DB.NewScope(table)
tableName := scope.TableName()
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
if err := DB.DropTableIfExists(table).Error; err != nil {
t.Errorf("Failed to drop %s table: %v", tableName, err)
}
if err := DB.CreateTable(table).Error; err != nil {
t.Errorf("Failed to create %s table: %v", tableName, err)
}
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
t.Errorf("Failed to create %s table index:", tableName)
}
})
}
}

File diff suppressed because it is too large Load Diff