go fmt
This commit is contained in:
parent
75d6dc912c
commit
f424f8aa2e
82
callback.go
82
callback.go
@ -97,27 +97,27 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
|||||||
|
|
||||||
// Register a new callback, refer `Callbacks.Create`
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
callbackContext := func(ctx context.Context, scope *Scope) {
|
callbackContext := func(ctx context.Context, scope *Scope) {
|
||||||
callback(scope)
|
callback(scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.RegisterContext(callbackName, callbackContext)
|
cp.RegisterContext(callbackName, callbackContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterContext same as Register
|
// RegisterContext same as Register
|
||||||
func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) RegisterContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
if cp.kind == "row_query" {
|
if cp.kind == "row_query" {
|
||||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||||
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
|
||||||
cp.before = "gorm:row_query"
|
cp.before = "gorm:row_query"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
cp.parent.reorder()
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove a registered callback
|
// Remove a registered callback
|
||||||
@ -136,47 +136,47 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
// scope.SetColumn("UpdatedAt", now)
|
// scope.SetColumn("UpdatedAt", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
callbackContext := func(ctx context.Context, scope *Scope) {
|
callbackContext := func(ctx context.Context, scope *Scope) {
|
||||||
callback(scope)
|
callback(scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.ReplaceContext(callbackName, callbackContext)
|
cp.ReplaceContext(callbackName, callbackContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceContext same as Replace
|
// ReplaceContext same as Replace
|
||||||
func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) ReplaceContext(callbackName string, callback func(ctx context.Context, scope *Scope)) {
|
||||||
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||||
cp.name = callbackName
|
cp.name = callbackName
|
||||||
cp.processor = &callback
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
cp.parent.processors = append(cp.parent.processors, cp)
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
cp.parent.reorder()
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get registered callback
|
// Get registered callback
|
||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
c := cp.GetContext(callbackName)
|
c := cp.GetContext(callbackName)
|
||||||
|
|
||||||
callback = func(scope *Scope) {
|
callback = func(scope *Scope) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c(ctx, scope)
|
c(ctx, scope)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetContext same as Get
|
// GetContext same as Get
|
||||||
func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
|
func (cp *CallbackProcessor) GetContext(callbackName string) (callback func(ctx context.Context, scope *Scope)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind {
|
if p.name == callbackName && p.kind == cp.kind {
|
||||||
if p.remove {
|
if p.remove {
|
||||||
callback = nil
|
callback = nil
|
||||||
} else {
|
} else {
|
||||||
callback = *p.processor
|
callback = *p.processor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRIndex get right index from string slice
|
// getRIndex get right index from string slice
|
||||||
|
@ -1,20 +1,20 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
|
func equalFuncs(funcs []*func(ctx context.Context, s *Scope), fnames []string) bool {
|
||||||
var names []string
|
var names []string
|
||||||
for _, f := range funcs {
|
for _, f := range funcs {
|
||||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
||||||
names = append(names, fnames[len(fnames)-1])
|
names = append(names, fnames[len(fnames)-1])
|
||||||
}
|
}
|
||||||
return reflect.DeepEqual(names, fnames)
|
return reflect.DeepEqual(names, fnames)
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(s *Scope) {}
|
func create(s *Scope) {}
|
||||||
@ -24,90 +24,90 @@ func afterCreate1(s *Scope) {}
|
|||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("before_create2", beforeCreate2)
|
callback.Create().Register("before_create2", beforeCreate2)
|
||||||
callback.Create().Register("create", create)
|
callback.Create().Register("create", create)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Register("after_create2", afterCreate2)
|
callback.Create().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
t.Errorf("register callback")
|
t.Errorf("register callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{logger: defaultLogger}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
callback1.Create().Register("before_create1", beforeCreate1)
|
callback1.Create().Register("before_create1", beforeCreate1)
|
||||||
callback1.Create().Register("create", create)
|
callback1.Create().Register("create", create)
|
||||||
callback1.Create().Register("after_create1", afterCreate1)
|
callback1.Create().Register("after_create1", afterCreate1)
|
||||||
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
||||||
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{logger: defaultLogger}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Update().Register("create", create)
|
callback2.Update().Register("create", create)
|
||||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
||||||
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
||||||
callback2.Update().Register("after_create2", afterCreate2)
|
callback2.Update().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
var callback1 = &Callback{logger: defaultLogger}
|
var callback1 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback1.Query().Register("before_create1", beforeCreate1)
|
callback1.Query().Register("before_create1", beforeCreate1)
|
||||||
callback1.Query().Register("after_create1", afterCreate1)
|
callback1.Query().Register("after_create1", afterCreate1)
|
||||||
|
|
||||||
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &Callback{logger: defaultLogger}
|
var callback2 = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||||
callback2.Delete().Register("after_create1", afterCreate1)
|
callback2.Delete().Register("after_create1", afterCreate1)
|
||||||
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Replace("create", replaceCreate)
|
callback.Create().Replace("create", replaceCreate)
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
||||||
t.Errorf("replace callback")
|
t.Errorf("replace callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveCallback(t *testing.T) {
|
func TestRemoveCallback(t *testing.T) {
|
||||||
var callback = &Callback{logger: defaultLogger}
|
var callback = &Callback{logger: defaultLogger}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Remove("create")
|
callback.Create().Remove("create")
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
||||||
t.Errorf("remove callback")
|
t.Errorf("remove callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,249 +1,249 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Product) BeforeCreate() (err error) {
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
if s.Code == "Invalid" {
|
if s.Code == "Invalid" {
|
||||||
err = errors.New("invalid product")
|
err = errors.New("invalid product")
|
||||||
}
|
}
|
||||||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeUpdate() (err error) {
|
func (s *Product) BeforeUpdate() (err error) {
|
||||||
if s.Code == "dont_update" {
|
if s.Code == "dont_update" {
|
||||||
err = errors.New("can't update")
|
err = errors.New("can't update")
|
||||||
}
|
}
|
||||||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeSave() (err error) {
|
func (s *Product) BeforeSave() (err error) {
|
||||||
if s.Code == "dont_save" {
|
if s.Code == "dont_save" {
|
||||||
err = errors.New("can't save")
|
err = errors.New("can't save")
|
||||||
}
|
}
|
||||||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterFind() {
|
func (s *Product) AfterFind() {
|
||||||
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
|
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterCreate(tx *gorm.DB) {
|
func (s *Product) AfterCreate(tx *gorm.DB) {
|
||||||
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
|
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterUpdate() {
|
func (s *Product) AfterUpdate() {
|
||||||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterSave() (err error) {
|
func (s *Product) AfterSave() (err error) {
|
||||||
if s.Code == "after_save_error" {
|
if s.Code == "after_save_error" {
|
||||||
err = errors.New("can't save")
|
err = errors.New("can't save")
|
||||||
}
|
}
|
||||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) BeforeDelete() (err error) {
|
func (s *Product) BeforeDelete() (err error) {
|
||||||
if s.Code == "dont_delete" {
|
if s.Code == "dont_delete" {
|
||||||
err = errors.New("can't delete")
|
err = errors.New("can't delete")
|
||||||
}
|
}
|
||||||
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) AfterDelete() (err error) {
|
func (s *Product) AfterDelete() (err error) {
|
||||||
if s.Code == "after_delete_error" {
|
if s.Code == "after_delete_error" {
|
||||||
err = errors.New("can't delete")
|
err = errors.New("can't delete")
|
||||||
}
|
}
|
||||||
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Product) GetCallTimes() []int64 {
|
func (s *Product) GetCallTimes() []int64 {
|
||||||
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRunCallbacks(t *testing.T) {
|
func TestRunCallbacks(t *testing.T) {
|
||||||
p := Product{Code: "unique_code", Price: 100}
|
p := Product{Code: "unique_code", Price: 100}
|
||||||
DB.Save(&p)
|
DB.Save(&p)
|
||||||
|
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
||||||
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("Code = ?", "unique_code").First(&p)
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
||||||
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
|
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Price = 200
|
p.Price = 200
|
||||||
DB.Save(&p)
|
DB.Save(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
||||||
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
var products []Product
|
var products []Product
|
||||||
DB.Find(&products, "code = ?", "unique_code")
|
DB.Find(&products, "code = ?", "unique_code")
|
||||||
if products[0].AfterFindCallTimes != 2 {
|
if products[0].AfterFindCallTimes != 2 {
|
||||||
t.Errorf("AfterFind callbacks should work with slice")
|
t.Errorf("AfterFind callbacks should work with slice")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Where("Code = ?", "unique_code").First(&p)
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
||||||
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Delete(&p)
|
DB.Delete(&p)
|
||||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
||||||
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
||||||
t.Errorf("Can't find a deleted record")
|
t.Errorf("Can't find a deleted record")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCallbacksWithErrors(t *testing.T) {
|
func TestCallbacksWithErrors(t *testing.T) {
|
||||||
p := Product{Code: "Invalid", Price: 100}
|
p := Product{Code: "Invalid", Price: 100}
|
||||||
if DB.Save(&p).Error == nil {
|
if DB.Save(&p).Error == nil {
|
||||||
t.Errorf("An error from before create callbacks happened when create with invalid value")
|
t.Errorf("An error from before create callbacks happened when create with invalid value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
||||||
t.Errorf("Should not save record that have errors")
|
t.Errorf("Should not save record that have errors")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
||||||
t.Errorf("An error from after create callbacks happened when create with invalid value")
|
t.Errorf("An error from after create callbacks happened when create with invalid value")
|
||||||
}
|
}
|
||||||
|
|
||||||
p2 := Product{Code: "update_callback", Price: 100}
|
p2 := Product{Code: "update_callback", Price: 100}
|
||||||
DB.Save(&p2)
|
DB.Save(&p2)
|
||||||
|
|
||||||
p2.Code = "dont_update"
|
p2.Code = "dont_update"
|
||||||
if DB.Save(&p2).Error == nil {
|
if DB.Save(&p2).Error == nil {
|
||||||
t.Errorf("An error from before update callbacks happened when update with invalid value")
|
t.Errorf("An error from before update callbacks happened when update with invalid value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
||||||
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
||||||
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
p2.Code = "dont_save"
|
p2.Code = "dont_save"
|
||||||
if DB.Save(&p2).Error == nil {
|
if DB.Save(&p2).Error == nil {
|
||||||
t.Errorf("An error from before save callbacks happened when update with invalid value")
|
t.Errorf("An error from before save callbacks happened when update with invalid value")
|
||||||
}
|
}
|
||||||
|
|
||||||
p3 := Product{Code: "dont_delete", Price: 100}
|
p3 := Product{Code: "dont_delete", Price: 100}
|
||||||
DB.Save(&p3)
|
DB.Save(&p3)
|
||||||
if DB.Delete(&p3).Error == nil {
|
if DB.Delete(&p3).Error == nil {
|
||||||
t.Errorf("An error from before delete callbacks happened when delete")
|
t.Errorf("An error from before delete callbacks happened when delete")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||||
t.Errorf("An error from before delete callbacks happened")
|
t.Errorf("An error from before delete callbacks happened")
|
||||||
}
|
}
|
||||||
|
|
||||||
p4 := Product{Code: "after_save_error", Price: 100}
|
p4 := Product{Code: "after_save_error", Price: 100}
|
||||||
DB.Save(&p4)
|
DB.Save(&p4)
|
||||||
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
||||||
t.Errorf("Record should be reverted if get an error in after save callback")
|
t.Errorf("Record should be reverted if get an error in after save callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
p5 := Product{Code: "after_delete_error", Price: 100}
|
p5 := Product{Code: "after_delete_error", Price: 100}
|
||||||
DB.Save(&p5)
|
DB.Save(&p5)
|
||||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||||
t.Errorf("Record should be found")
|
t.Errorf("Record should be found")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Delete(&p5)
|
DB.Delete(&p5)
|
||||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetCallback(t *testing.T) {
|
func TestGetCallback(t *testing.T) {
|
||||||
scope := DB.NewScope(nil)
|
scope := DB.NewScope(nil)
|
||||||
|
|
||||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||||
callback := DB.Callback().Create().Get("gorm:test_callback")
|
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Remove("gorm:test_callback")
|
DB.Callback().Create().Remove("gorm:test_callback")
|
||||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
t.Errorf("`gorm:test_callback` should be nil")
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
if callback == nil {
|
if callback == nil {
|
||||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
}
|
}
|
||||||
callback(scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseDefaultCallback(t *testing.T) {
|
func TestUseDefaultCallback(t *testing.T) {
|
||||||
createCallbackName := "gorm:test_use_default_callback_for_create"
|
createCallbackName := "gorm:test_use_default_callback_for_create"
|
||||||
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
|
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
|
||||||
// nop
|
// nop
|
||||||
})
|
})
|
||||||
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
||||||
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
|
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
|
||||||
}
|
}
|
||||||
gorm.DefaultCallback.Create().Remove(createCallbackName)
|
gorm.DefaultCallback.Create().Remove(createCallbackName)
|
||||||
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
|
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
|
||||||
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
|
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
|
||||||
}
|
}
|
||||||
|
|
||||||
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
||||||
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
||||||
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 1)
|
scope.Set(scopeValueName, 1)
|
||||||
})
|
})
|
||||||
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
|
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
|
||||||
scope.Set(scopeValueName, 2)
|
scope.Set(scopeValueName, 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
scope := DB.NewScope(nil)
|
scope := DB.NewScope(nil)
|
||||||
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
||||||
callback(scope)
|
callback(scope)
|
||||||
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
||||||
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,357 +1,357 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CustomizeColumn struct {
|
type CustomizeColumn struct {
|
||||||
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
|
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
|
||||||
Name string `gorm:"column:mapped_name"`
|
Name string `gorm:"column:mapped_name"`
|
||||||
Date *time.Time `gorm:"column:mapped_time"`
|
Date *time.Time `gorm:"column:mapped_time"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure an ignored field does not interfere with another field's custom
|
// Make sure an ignored field does not interfere with another field's custom
|
||||||
// column name that matches the ignored field.
|
// column name that matches the ignored field.
|
||||||
type CustomColumnAndIgnoredFieldClash struct {
|
type CustomColumnAndIgnoredFieldClash struct {
|
||||||
Body string `sql:"-"`
|
Body string `sql:"-"`
|
||||||
RawBody string `gorm:"column:body"`
|
RawBody string `gorm:"column:body"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomizeColumn(t *testing.T) {
|
func TestCustomizeColumn(t *testing.T) {
|
||||||
col := "mapped_name"
|
col := "mapped_name"
|
||||||
DB.DropTable(&CustomizeColumn{})
|
DB.DropTable(&CustomizeColumn{})
|
||||||
DB.AutoMigrate(&CustomizeColumn{})
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
scope := DB.NewScope(&CustomizeColumn{})
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||||
t.Errorf("CustomizeColumn should have column %s", col)
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
}
|
}
|
||||||
|
|
||||||
col = "mapped_id"
|
col = "mapped_id"
|
||||||
if scope.PrimaryKey() != col {
|
if scope.PrimaryKey() != col {
|
||||||
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
|
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := "foo"
|
expected := "foo"
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
|
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
|
||||||
|
|
||||||
if count := DB.Create(&cc).RowsAffected; count != 1 {
|
if count := DB.Create(&cc).RowsAffected; count != 1 {
|
||||||
t.Error("There should be one record be affected when create record")
|
t.Error("There should be one record be affected when create record")
|
||||||
}
|
}
|
||||||
|
|
||||||
var cc1 CustomizeColumn
|
var cc1 CustomizeColumn
|
||||||
DB.First(&cc1, 666)
|
DB.First(&cc1, 666)
|
||||||
|
|
||||||
if cc1.Name != expected {
|
if cc1.Name != expected {
|
||||||
t.Errorf("Failed to query CustomizeColumn")
|
t.Errorf("Failed to query CustomizeColumn")
|
||||||
}
|
}
|
||||||
|
|
||||||
cc.Name = "bar"
|
cc.Name = "bar"
|
||||||
DB.Save(&cc)
|
DB.Save(&cc)
|
||||||
|
|
||||||
var cc2 CustomizeColumn
|
var cc2 CustomizeColumn
|
||||||
DB.First(&cc2, 666)
|
DB.First(&cc2, 666)
|
||||||
if cc2.Name != "bar" {
|
if cc2.Name != "bar" {
|
||||||
t.Errorf("Failed to query CustomizeColumn")
|
t.Errorf("Failed to query CustomizeColumn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
|
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
|
||||||
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
|
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
|
||||||
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
|
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
|
||||||
t.Errorf("Should not raise error: %s", err)
|
t.Errorf("Should not raise error: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type CustomizePerson struct {
|
type CustomizePerson struct {
|
||||||
IdPerson string `gorm:"column:idPerson;primary_key:true"`
|
IdPerson string `gorm:"column:idPerson;primary_key:true"`
|
||||||
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
|
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CustomizeAccount struct {
|
type CustomizeAccount struct {
|
||||||
IdAccount string `gorm:"column:idAccount;primary_key:true"`
|
IdAccount string `gorm:"column:idAccount;primary_key:true"`
|
||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManyToManyWithCustomizedColumn(t *testing.T) {
|
func TestManyToManyWithCustomizedColumn(t *testing.T) {
|
||||||
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
|
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
|
||||||
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
|
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
|
||||||
|
|
||||||
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
|
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
|
||||||
person := CustomizePerson{
|
person := CustomizePerson{
|
||||||
IdPerson: "person",
|
IdPerson: "person",
|
||||||
Accounts: []CustomizeAccount{account},
|
Accounts: []CustomizeAccount{account},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&account).Error; err != nil {
|
if err := DB.Create(&account).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&person).Error; err != nil {
|
if err := DB.Create(&person).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var person1 CustomizePerson
|
var person1 CustomizePerson
|
||||||
scope := DB.NewScope(nil)
|
scope := DB.NewScope(nil)
|
||||||
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
|
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
|
||||||
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
|
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
|
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
|
||||||
t.Errorf("should preload correct accounts")
|
t.Errorf("should preload correct accounts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type CustomizeUser struct {
|
type CustomizeUser struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Email string `sql:"column:email_address"`
|
Email string `sql:"column:email_address"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CustomizeInvitation struct {
|
type CustomizeInvitation struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Address string `sql:"column:invitation"`
|
Address string `sql:"column:invitation"`
|
||||||
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
|
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOneToOneWithCustomizedColumn(t *testing.T) {
|
func TestOneToOneWithCustomizedColumn(t *testing.T) {
|
||||||
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
|
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
|
||||||
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
|
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
|
||||||
|
|
||||||
user := CustomizeUser{
|
user := CustomizeUser{
|
||||||
Email: "hello@example.com",
|
Email: "hello@example.com",
|
||||||
}
|
}
|
||||||
invitation := CustomizeInvitation{
|
invitation := CustomizeInvitation{
|
||||||
Address: "hello@example.com",
|
Address: "hello@example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Create(&user)
|
DB.Create(&user)
|
||||||
DB.Create(&invitation)
|
DB.Create(&invitation)
|
||||||
|
|
||||||
var invitation2 CustomizeInvitation
|
var invitation2 CustomizeInvitation
|
||||||
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
|
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if invitation2.Person.Email != user.Email {
|
if invitation2.Person.Email != user.Email {
|
||||||
t.Errorf("Should preload one to one relation with customize foreign keys")
|
t.Errorf("Should preload one to one relation with customize foreign keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type PromotionDiscount struct {
|
type PromotionDiscount struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
|
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
|
||||||
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
|
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
|
||||||
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
|
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PromotionBenefit struct {
|
type PromotionBenefit struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
PromotionID uint
|
PromotionID uint
|
||||||
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
|
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PromotionCoupon struct {
|
type PromotionCoupon struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Code string
|
Code string
|
||||||
DiscountID uint
|
DiscountID uint
|
||||||
Discount PromotionDiscount
|
Discount PromotionDiscount
|
||||||
}
|
}
|
||||||
|
|
||||||
type PromotionRule struct {
|
type PromotionRule struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Begin *time.Time
|
Begin *time.Time
|
||||||
End *time.Time
|
End *time.Time
|
||||||
DiscountID uint
|
DiscountID uint
|
||||||
Discount *PromotionDiscount
|
Discount *PromotionDiscount
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOneToManyWithCustomizedColumn(t *testing.T) {
|
func TestOneToManyWithCustomizedColumn(t *testing.T) {
|
||||||
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
|
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
|
||||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
|
||||||
|
|
||||||
discount := PromotionDiscount{
|
discount := PromotionDiscount{
|
||||||
Name: "Happy New Year",
|
Name: "Happy New Year",
|
||||||
Coupons: []*PromotionCoupon{
|
Coupons: []*PromotionCoupon{
|
||||||
{Code: "newyear1"},
|
{Code: "newyear1"},
|
||||||
{Code: "newyear2"},
|
{Code: "newyear2"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&discount).Error; err != nil {
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var discount1 PromotionDiscount
|
var discount1 PromotionDiscount
|
||||||
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(discount.Coupons) != 2 {
|
if len(discount.Coupons) != 2 {
|
||||||
t.Errorf("should find two coupons")
|
t.Errorf("should find two coupons")
|
||||||
}
|
}
|
||||||
|
|
||||||
var coupon PromotionCoupon
|
var coupon PromotionCoupon
|
||||||
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
|
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if coupon.Discount.Name != "Happy New Year" {
|
if coupon.Discount.Name != "Happy New Year" {
|
||||||
t.Errorf("should preload discount from coupon")
|
t.Errorf("should preload discount from coupon")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
|
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
|
||||||
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
|
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
|
||||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
|
||||||
|
|
||||||
var begin = time.Now()
|
var begin = time.Now()
|
||||||
var end = time.Now().Add(24 * time.Hour)
|
var end = time.Now().Add(24 * time.Hour)
|
||||||
discount := PromotionDiscount{
|
discount := PromotionDiscount{
|
||||||
Name: "Happy New Year 2",
|
Name: "Happy New Year 2",
|
||||||
Rule: &PromotionRule{
|
Rule: &PromotionRule{
|
||||||
Name: "time_limited",
|
Name: "time_limited",
|
||||||
Begin: &begin,
|
Begin: &begin,
|
||||||
End: &end,
|
End: &end,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&discount).Error; err != nil {
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var discount1 PromotionDiscount
|
var discount1 PromotionDiscount
|
||||||
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
|
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("Should be able to preload Rule")
|
t.Errorf("Should be able to preload Rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
var rule PromotionRule
|
var rule PromotionRule
|
||||||
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
|
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.Discount.Name != "Happy New Year 2" {
|
if rule.Discount.Name != "Happy New Year 2" {
|
||||||
t.Errorf("should preload discount from rule")
|
t.Errorf("should preload discount from rule")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
||||||
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
|
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
|
||||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
|
||||||
|
|
||||||
discount := PromotionDiscount{
|
discount := PromotionDiscount{
|
||||||
Name: "Happy New Year 3",
|
Name: "Happy New Year 3",
|
||||||
Benefits: []PromotionBenefit{
|
Benefits: []PromotionBenefit{
|
||||||
{Name: "free cod"},
|
{Name: "free cod"},
|
||||||
{Name: "free shipping"},
|
{Name: "free shipping"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&discount).Error; err != nil {
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var discount1 PromotionDiscount
|
var discount1 PromotionDiscount
|
||||||
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(discount.Benefits) != 2 {
|
if len(discount.Benefits) != 2 {
|
||||||
t.Errorf("should find two benefits")
|
t.Errorf("should find two benefits")
|
||||||
}
|
}
|
||||||
|
|
||||||
var benefit PromotionBenefit
|
var benefit PromotionBenefit
|
||||||
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
|
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
|
||||||
t.Errorf("no error should happen but got %v", err)
|
t.Errorf("no error should happen but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if benefit.Discount.Name != "Happy New Year 3" {
|
if benefit.Discount.Name != "Happy New Year 3" {
|
||||||
t.Errorf("should preload discount from coupon")
|
t.Errorf("should preload discount from coupon")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type SelfReferencingUser struct {
|
type SelfReferencingUser struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
|
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||||
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||||
DB.AutoMigrate(&SelfReferencingUser{})
|
DB.AutoMigrate(&SelfReferencingUser{})
|
||||||
if !DB.HasTable("UserFriends") {
|
if !DB.HasTable("UserFriends") {
|
||||||
t.Errorf("auto migrate error, table UserFriends should be created")
|
t.Errorf("auto migrate error, table UserFriends should be created")
|
||||||
}
|
}
|
||||||
|
|
||||||
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
friend1 := SelfReferencingUser{Name: "friend1_m2m"}
|
||||||
if err := DB.Create(&friend1).Error; err != nil {
|
if err := DB.Create(&friend1).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
|
friend2 := SelfReferencingUser{Name: "friend2_m2m"}
|
||||||
if err := DB.Create(&friend2).Error; err != nil {
|
if err := DB.Create(&friend2).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user := SelfReferencingUser{
|
user := SelfReferencingUser{
|
||||||
Name: "self_m2m",
|
Name: "self_m2m",
|
||||||
Friends: []*SelfReferencingUser{&friend1, &friend2},
|
Friends: []*SelfReferencingUser{&friend1, &friend2},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&user).Error; err != nil {
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Model(&user).Association("Friends").Count() != 2 {
|
if DB.Model(&user).Association("Friends").Count() != 2 {
|
||||||
t.Errorf("Should find created friends correctly")
|
t.Errorf("Should find created friends correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
|
if err := DB.Table("UserFriends").Count(&count).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
t.Errorf("table UserFriends should have records")
|
t.Errorf("table UserFriends should have records")
|
||||||
}
|
}
|
||||||
|
|
||||||
var newUser = SelfReferencingUser{}
|
var newUser = SelfReferencingUser{}
|
||||||
|
|
||||||
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
|
||||||
t.Errorf("no error should happen, but got %v", err)
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(newUser.Friends) != 2 {
|
if len(newUser.Friends) != 2 {
|
||||||
t.Errorf("Should preload created frineds for self reference m2m")
|
t.Errorf("Should preload created frineds for self reference m2m")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
|
DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
|
||||||
if DB.Model(&user).Association("Friends").Count() != 3 {
|
if DB.Model(&user).Association("Friends").Count() != 3 {
|
||||||
t.Errorf("Should find created friends correctly")
|
t.Errorf("Should find created friends correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
|
DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
|
||||||
if DB.Model(&user).Association("Friends").Count() != 1 {
|
if DB.Model(&user).Association("Friends").Count() != 1 {
|
||||||
t.Errorf("Should find created friends correctly")
|
t.Errorf("Should find created friends correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
friend := SelfReferencingUser{}
|
friend := SelfReferencingUser{}
|
||||||
DB.Model(&newUser).Association("Friends").Find(&friend)
|
DB.Model(&newUser).Association("Friends").Find(&friend)
|
||||||
if friend.Name != "friend4_m2m" {
|
if friend.Name != "friend4_m2m" {
|
||||||
t.Errorf("Should find created friends correctly")
|
t.Errorf("Should find created friends correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Model(&newUser).Association("Friends").Delete(friend)
|
DB.Model(&newUser).Association("Friends").Delete(friend)
|
||||||
if DB.Model(&user).Association("Friends").Count() != 0 {
|
if DB.Model(&user).Association("Friends").Count() != 0 {
|
||||||
t.Errorf("All friends should be deleted")
|
t.Errorf("All friends should be deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
40
dialect.go
40
dialect.go
@ -25,28 +25,28 @@ type Dialect interface {
|
|||||||
DataTypeOf(field *StructField) string
|
DataTypeOf(field *StructField) string
|
||||||
// HasIndex check has index or not
|
// HasIndex check has index or not
|
||||||
HasIndex(tableName string, indexName string) bool
|
HasIndex(tableName string, indexName string) bool
|
||||||
// HasIndexContext same as HasIndex
|
// HasIndexContext same as HasIndex
|
||||||
HasIndexContext(ctx context.Context, tableName string, indexName string) bool
|
HasIndexContext(ctx context.Context, tableName string, indexName string) bool
|
||||||
// HasForeignKey check has foreign key or not
|
// HasForeignKey check has foreign key or not
|
||||||
HasForeignKey(tableName string, foreignKeyName string) bool
|
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||||
// HasForeignKeyContext same as HasForeignKey
|
// HasForeignKeyContext same as HasForeignKey
|
||||||
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
|
HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool
|
||||||
// RemoveIndex remove index
|
// RemoveIndex remove index
|
||||||
RemoveIndex(tableName string, indexName string) error
|
RemoveIndex(tableName string, indexName string) error
|
||||||
// RemoveIndexContext same as RemoveIndex
|
// RemoveIndexContext same as RemoveIndex
|
||||||
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
|
RemoveIndexContext(ctx context.Context, tableName string, indexName string) error
|
||||||
// HasTable check has table or not
|
// HasTable check has table or not
|
||||||
HasTable(tableName string) bool
|
HasTable(tableName string) bool
|
||||||
// HasTableContext same as HasTable
|
// HasTableContext same as HasTable
|
||||||
HasTableContext(ctx context.Context, tableName string) bool
|
HasTableContext(ctx context.Context, tableName string) bool
|
||||||
// HasColumn check has column or not
|
// HasColumn check has column or not
|
||||||
HasColumn(tableName string, columnName string) bool
|
HasColumn(tableName string, columnName string) bool
|
||||||
// HasColumnContext same as HasColumn
|
// HasColumnContext same as HasColumn
|
||||||
HasColumnContext(ctx context.Context, tableName string, columnName string) bool
|
HasColumnContext(ctx context.Context, tableName string, columnName string) bool
|
||||||
// ModifyColumn modify column's type
|
// ModifyColumn modify column's type
|
||||||
ModifyColumn(tableName string, columnName string, typ string) error
|
ModifyColumn(tableName string, columnName string, typ string) error
|
||||||
// ModifyColumnContext same as ModifyColumn
|
// ModifyColumnContext same as ModifyColumn
|
||||||
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
|
ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error
|
||||||
|
|
||||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||||
@ -55,12 +55,12 @@ type Dialect interface {
|
|||||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||||
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||||
// LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial
|
// LastInsertIDOutputInterstitialContext same as LastInsertIDOutputInterstitial
|
||||||
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
|
LastInsertIDOutputInterstitialContext(ctx context.Context, tableName, columnName string, columns []string) string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
// LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix
|
// LastInsertIDReturningSuffixContext same as LastInsertIDReturningSuffix
|
||||||
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
|
LastInsertIDReturningSuffixContext(ctx context.Context, tableName, columnName string) string
|
||||||
// DefaultValueStr
|
// DefaultValueStr
|
||||||
DefaultValueStr() string
|
DefaultValueStr() string
|
||||||
|
|
||||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
|
||||||
@ -71,8 +71,8 @@ type Dialect interface {
|
|||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase() string
|
||||||
// CurrentDatabaseContext same as CurrentDatabase
|
// CurrentDatabaseContext same as CurrentDatabase
|
||||||
CurrentDatabaseContext(ctx context.Context) string
|
CurrentDatabaseContext(ctx context.Context) string
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
@ -101,8 +101,8 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasIndexContext(ctx, tableName, indexName)
|
return s.HasIndexContext(ctx, tableName, indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
@ -113,8 +113,8 @@ func (s commonDialect) HasIndexContext(ctx context.Context, tableName string, in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.RemoveIndexContext(ctx, tableName, indexName)
|
return s.RemoveIndexContext(ctx, tableName, indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
||||||
@ -123,8 +123,8 @@ func (s commonDialect) RemoveIndexContext(ctx context.Context, tableName string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
@ -132,8 +132,8 @@ func (s commonDialect) HasForeignKeyContext(_ctx context.Context, tableName stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasTable(tableName string) bool {
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasTableContext(ctx, tableName)
|
return s.HasTableContext(ctx, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool {
|
func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
@ -144,8 +144,8 @@ func (s commonDialect) HasTableContext(ctx context.Context, tableName string) bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasColumnContext(ctx, tableName, columnName)
|
return s.HasColumnContext(ctx, tableName, columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
@ -156,8 +156,8 @@ func (s commonDialect) HasColumnContext(ctx context.Context, tableName string, c
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
@ -166,8 +166,8 @@ func (s commonDialect) ModifyColumnContext(ctx context.Context, tableName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) CurrentDatabase() (name string) {
|
func (s commonDialect) CurrentDatabase() (name string) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.CurrentDatabaseContext(ctx)
|
return s.CurrentDatabaseContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) {
|
func (s commonDialect) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
@ -199,8 +199,8 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
func (s commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
@ -208,8 +208,8 @@ func (commonDialect) LastInsertIDOutputInterstitialContext(_ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (s commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
func (commonDialect) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
||||||
|
@ -124,8 +124,8 @@ func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasIndexContext(ctx, tableName, indexName)
|
return s.HasIndexContext(ctx, tableName, indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName string) bool {
|
||||||
@ -135,8 +135,8 @@ func (s mssql) HasIndexContext(ctx context.Context, tableName string, indexName
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.RemoveIndexContext(ctx, tableName, indexName)
|
return s.RemoveIndexContext(ctx, tableName, indexName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexName string) error {
|
||||||
@ -145,8 +145,8 @@ func (s mssql) RemoveIndexContext(ctx context.Context, tableName string, indexNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
return s.HasForeignKeyContext(ctx, tableName, foreignKeyName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
|
func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, foreignKeyName string) bool {
|
||||||
@ -161,8 +161,8 @@ func (s mssql) HasForeignKeyContext(ctx context.Context, tableName string, forei
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTable(tableName string) bool {
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasTableContext(ctx, tableName)
|
return s.HasTableContext(ctx, tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
|
func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
|
||||||
@ -173,8 +173,8 @@ func (s mssql) HasTableContext(ctx context.Context, tableName string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.HasColumnContext(ctx, tableName, columnName)
|
return s.HasColumnContext(ctx, tableName, columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnName string) bool {
|
||||||
@ -185,8 +185,8 @@ func (s mssql) HasColumnContext(ctx context.Context, tableName string, columnNam
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
|
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
return s.ModifyColumnContext(ctx, tableName, columnName, typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, columnName string, typ string) error {
|
||||||
@ -195,9 +195,9 @@ func (s mssql) ModifyColumnContext(ctx context.Context, tableName string, column
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabase() (name string) {
|
func (s mssql) CurrentDatabase() (name string) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
s.CurrentDatabaseContext(ctx)
|
s.CurrentDatabaseContext(ctx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) {
|
func (s mssql) CurrentDatabaseContext(ctx context.Context) (name string) {
|
||||||
@ -236,8 +236,8 @@ func (mssql) SelectFromDummyTable() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
func (s mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
return s.LastInsertIDOutputInterstitialContext(ctx, tableName, columnName, columns)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableName, columnName string, columns []string) string {
|
||||||
@ -249,8 +249,8 @@ func (mssql) LastInsertIDOutputInterstitialContext(_ctx context.Context, tableNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (s mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
return s.LastInsertIDReturningSuffixContext(ctx, tableName, columnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffixContext(_ctx context.Context, tableName, columnName string) string {
|
||||||
|
@ -3,89 +3,89 @@ package gorm_test
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
type BasePost struct {
|
type BasePost struct {
|
||||||
Id int64
|
Id int64
|
||||||
Title string
|
Title string
|
||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Author struct {
|
type Author struct {
|
||||||
ID string
|
ID string
|
||||||
Name string
|
Name string
|
||||||
Email string
|
Email string
|
||||||
}
|
}
|
||||||
|
|
||||||
type HNPost struct {
|
type HNPost struct {
|
||||||
BasePost
|
BasePost
|
||||||
Author `gorm:"embedded_prefix:user_"` // Embedded struct
|
Author `gorm:"embedded_prefix:user_"` // Embedded struct
|
||||||
Upvotes int32
|
Upvotes int32
|
||||||
}
|
}
|
||||||
|
|
||||||
type EngadgetPost struct {
|
type EngadgetPost struct {
|
||||||
BasePost BasePost `gorm:"embedded"`
|
BasePost BasePost `gorm:"embedded"`
|
||||||
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
|
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
|
||||||
ImageUrl string
|
ImageUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
||||||
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
||||||
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
||||||
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
|
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engadgetPostScope.PrimaryFields()) != 1 {
|
if len(engadgetPostScope.PrimaryFields()) != 1 {
|
||||||
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
|
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
|
||||||
}
|
}
|
||||||
|
|
||||||
hnScope := DB.NewScope(&HNPost{})
|
hnScope := DB.NewScope(&HNPost{})
|
||||||
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
|
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
|
||||||
t.Errorf("should has prefix for embedded columns")
|
t.Errorf("should has prefix for embedded columns")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||||
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
|
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
|
||||||
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||||
var news HNPost
|
var news HNPost
|
||||||
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
} else if news.Title != "hn_news" {
|
} else if news.Title != "hn_news" {
|
||||||
t.Errorf("embedded struct's value should be scanned correctly")
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||||
var egNews EngadgetPost
|
var egNews EngadgetPost
|
||||||
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
} else if egNews.BasePost.Title != "engadget_news" {
|
} else if egNews.BasePost.Title != "engadget_news" {
|
||||||
t.Errorf("embedded struct's value should be scanned correctly")
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
||||||
t.Errorf("primary key with embedded struct should works")
|
t.Errorf("primary key with embedded struct should works")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range DB.NewScope(&HNPost{}).Fields() {
|
for _, field := range DB.NewScope(&HNPost{}).Fields() {
|
||||||
if field.Name == "BasePost" {
|
if field.Name == "BasePost" {
|
||||||
t.Errorf("scope Fields should not contain embedded struct")
|
t.Errorf("scope Fields should not contain embedded struct")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||||
type HNPost struct {
|
type HNPost struct {
|
||||||
*BasePost
|
*BasePost
|
||||||
Upvotes int32
|
Upvotes int32
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
|
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
|
||||||
|
|
||||||
var hnPost HNPost
|
var hnPost HNPost
|
||||||
if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
|
if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
|
||||||
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
|
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hnPost.Title != "embedded_pointer_type" {
|
if hnPost.Title != "embedded_pointer_type" {
|
||||||
t.Errorf("Should find correct value for embedded pointer type")
|
t.Errorf("Should find correct value for embedded pointer type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@ type SQLCommon interface {
|
|||||||
Prepare(query string) (*sql.Stmt, error)
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlDb interface {
|
type sqlDb interface {
|
||||||
|
1
main.go
1
main.go
@ -675,7 +675,6 @@ func (s *DB) Begin() *DB {
|
|||||||
return s.BeginTx(context.Background(), &sql.TxOptions{})
|
return s.BeginTx(context.Background(), &sql.TxOptions{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// BeginTx begins a transaction with options
|
// BeginTx begins a transaction with options
|
||||||
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB {
|
||||||
c := s.clone()
|
c := s.clone()
|
||||||
|
@ -1,579 +1,579 @@
|
|||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64
|
Id int64
|
||||||
Age int64
|
Age int64
|
||||||
UserNum Num
|
UserNum Num
|
||||||
Name string `sql:"size:255"`
|
Name string `sql:"size:255"`
|
||||||
Email string
|
Email string
|
||||||
Birthday *time.Time // Time
|
Birthday *time.Time // Time
|
||||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||||
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||||
Emails []Email // Embedded structs
|
Emails []Email // Embedded structs
|
||||||
BillingAddress Address // Embedded struct
|
BillingAddress Address // Embedded struct
|
||||||
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
||||||
ShippingAddress Address // Embedded struct
|
ShippingAddress Address // Embedded struct
|
||||||
ShippingAddressId int64 // Embedded struct's foreign key
|
ShippingAddressId int64 // Embedded struct's foreign key
|
||||||
CreditCard CreditCard
|
CreditCard CreditCard
|
||||||
Latitude float64
|
Latitude float64
|
||||||
Languages []Language `gorm:"many2many:user_languages;"`
|
Languages []Language `gorm:"many2many:user_languages;"`
|
||||||
CompanyID *int
|
CompanyID *int
|
||||||
Company Company
|
Company Company
|
||||||
Role Role
|
Role Role
|
||||||
Password EncryptedData
|
Password EncryptedData
|
||||||
PasswordHash []byte
|
PasswordHash []byte
|
||||||
IgnoreMe int64 `sql:"-"`
|
IgnoreMe int64 `sql:"-"`
|
||||||
IgnoreStringSlice []string `sql:"-"`
|
IgnoreStringSlice []string `sql:"-"`
|
||||||
Ignored struct{ Name string } `sql:"-"`
|
Ignored struct{ Name string } `sql:"-"`
|
||||||
IgnoredPointer *User `sql:"-"`
|
IgnoredPointer *User `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type NotSoLongTableName struct {
|
type NotSoLongTableName struct {
|
||||||
Id int64
|
Id int64
|
||||||
ReallyLongThingID int64
|
ReallyLongThingID int64
|
||||||
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
|
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
|
type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
|
||||||
Id int64
|
Id int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReallyLongThingThatReferencesShort struct {
|
type ReallyLongThingThatReferencesShort struct {
|
||||||
Id int64
|
Id int64
|
||||||
ShortID int64
|
ShortID int64
|
||||||
Short Short
|
Short Short
|
||||||
}
|
}
|
||||||
|
|
||||||
type Short struct {
|
type Short struct {
|
||||||
Id int64
|
Id int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreditCard struct {
|
type CreditCard struct {
|
||||||
ID int8
|
ID int8
|
||||||
Number string
|
Number string
|
||||||
UserId sql.NullInt64
|
UserId sql.NullInt64
|
||||||
CreatedAt time.Time `sql:"not null"`
|
CreatedAt time.Time `sql:"not null"`
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt *time.Time `sql:"column:deleted_time"`
|
DeletedAt *time.Time `sql:"column:deleted_time"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Email struct {
|
type Email struct {
|
||||||
Id int16
|
Id int16
|
||||||
UserId int
|
UserId int
|
||||||
Email string `sql:"type:varchar(100);"`
|
Email string `sql:"type:varchar(100);"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Address struct {
|
type Address struct {
|
||||||
ID int
|
ID int
|
||||||
Address1 string
|
Address1 string
|
||||||
Address2 string
|
Address2 string
|
||||||
Post string
|
Post string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt *time.Time
|
DeletedAt *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Language struct {
|
type Language struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Users []User `gorm:"many2many:user_languages;"`
|
Users []User `gorm:"many2many:user_languages;"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Product struct {
|
type Product struct {
|
||||||
Id int64
|
Id int64
|
||||||
Code string
|
Code string
|
||||||
Price int64
|
Price int64
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
AfterFindCallTimes int64
|
AfterFindCallTimes int64
|
||||||
BeforeCreateCallTimes int64
|
BeforeCreateCallTimes int64
|
||||||
AfterCreateCallTimes int64
|
AfterCreateCallTimes int64
|
||||||
BeforeUpdateCallTimes int64
|
BeforeUpdateCallTimes int64
|
||||||
AfterUpdateCallTimes int64
|
AfterUpdateCallTimes int64
|
||||||
BeforeSaveCallTimes int64
|
BeforeSaveCallTimes int64
|
||||||
AfterSaveCallTimes int64
|
AfterSaveCallTimes int64
|
||||||
BeforeDeleteCallTimes int64
|
BeforeDeleteCallTimes int64
|
||||||
AfterDeleteCallTimes int64
|
AfterDeleteCallTimes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Company struct {
|
type Company struct {
|
||||||
Id int64
|
Id int64
|
||||||
Name string
|
Name string
|
||||||
Owner *User `sql:"-"`
|
Owner *User `sql:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Place struct {
|
type Place struct {
|
||||||
Id int64
|
Id int64
|
||||||
PlaceAddressID int
|
PlaceAddressID int
|
||||||
PlaceAddress *Address `gorm:"save_associations:false"`
|
PlaceAddress *Address `gorm:"save_associations:false"`
|
||||||
OwnerAddressID int
|
OwnerAddressID int
|
||||||
OwnerAddress *Address `gorm:"save_associations:true"`
|
OwnerAddress *Address `gorm:"save_associations:true"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
|
||||||
func (data *EncryptedData) Scan(value interface{}) error {
|
func (data *EncryptedData) Scan(value interface{}) error {
|
||||||
if b, ok := value.([]byte); ok {
|
if b, ok := value.([]byte); ok {
|
||||||
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
|
if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
|
||||||
return errors.New("Too short")
|
return errors.New("Too short")
|
||||||
}
|
}
|
||||||
|
|
||||||
*data = b[3:]
|
*data = b[3:]
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.New("Bytes expected")
|
return errors.New("Bytes expected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data EncryptedData) Value() (driver.Value, error) {
|
func (data EncryptedData) Value() (driver.Value, error) {
|
||||||
if len(data) > 0 && data[0] == 'x' {
|
if len(data) > 0 && data[0] == 'x' {
|
||||||
//needed to test failures
|
//needed to test failures
|
||||||
return nil, errors.New("Should not start with 'x'")
|
return nil, errors.New("Should not start with 'x'")
|
||||||
}
|
}
|
||||||
|
|
||||||
//prepend asterisks
|
//prepend asterisks
|
||||||
return append([]byte("***"), data...), nil
|
return append([]byte("***"), data...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Role struct {
|
type Role struct {
|
||||||
Name string `gorm:"size:256"`
|
Name string `gorm:"size:256"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (role *Role) Scan(value interface{}) error {
|
func (role *Role) Scan(value interface{}) error {
|
||||||
if b, ok := value.([]uint8); ok {
|
if b, ok := value.([]uint8); ok {
|
||||||
role.Name = string(b)
|
role.Name = string(b)
|
||||||
} else {
|
} else {
|
||||||
role.Name = value.(string)
|
role.Name = value.(string)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (role Role) Value() (driver.Value, error) {
|
func (role Role) Value() (driver.Value, error) {
|
||||||
return role.Name, nil
|
return role.Name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (role Role) IsAdmin() bool {
|
func (role Role) IsAdmin() bool {
|
||||||
return role.Name == "admin"
|
return role.Name == "admin"
|
||||||
}
|
}
|
||||||
|
|
||||||
type Num int64
|
type Num int64
|
||||||
|
|
||||||
func (i *Num) Scan(src interface{}) error {
|
func (i *Num) Scan(src interface{}) error {
|
||||||
switch s := src.(type) {
|
switch s := src.(type) {
|
||||||
case []byte:
|
case []byte:
|
||||||
n, _ := strconv.Atoi(string(s))
|
n, _ := strconv.Atoi(string(s))
|
||||||
*i = Num(n)
|
*i = Num(n)
|
||||||
case int64:
|
case int64:
|
||||||
*i = Num(s)
|
*i = Num(s)
|
||||||
default:
|
default:
|
||||||
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Animal struct {
|
type Animal struct {
|
||||||
Counter uint64 `gorm:"primary_key:yes"`
|
Counter uint64 `gorm:"primary_key:yes"`
|
||||||
Name string `sql:"DEFAULT:'galeone'"`
|
Name string `sql:"DEFAULT:'galeone'"`
|
||||||
From string //test reserved sql keyword as field name
|
From string //test reserved sql keyword as field name
|
||||||
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
||||||
unexported string // unexported value
|
unexported string // unexported value
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinTable struct {
|
type JoinTable struct {
|
||||||
From uint64
|
From uint64
|
||||||
To uint64
|
To uint64
|
||||||
Time time.Time `sql:"default: null"`
|
Time time.Time `sql:"default: null"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Post struct {
|
type Post struct {
|
||||||
Id int64
|
Id int64
|
||||||
CategoryId sql.NullInt64
|
CategoryId sql.NullInt64
|
||||||
MainCategoryId int64
|
MainCategoryId int64
|
||||||
Title string
|
Title string
|
||||||
Body string
|
Body string
|
||||||
Comments []*Comment
|
Comments []*Comment
|
||||||
Category Category
|
Category Category
|
||||||
MainCategory Category
|
MainCategory Category
|
||||||
}
|
}
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
|
|
||||||
Categories []Category
|
Categories []Category
|
||||||
CategoryID *uint
|
CategoryID *uint
|
||||||
}
|
}
|
||||||
|
|
||||||
type Comment struct {
|
type Comment struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
PostId int64
|
PostId int64
|
||||||
Content string
|
Content string
|
||||||
Post Post
|
Post Post
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scanner
|
// Scanner
|
||||||
type NullValue struct {
|
type NullValue struct {
|
||||||
Id int64
|
Id int64
|
||||||
Name sql.NullString `sql:"not null"`
|
Name sql.NullString `sql:"not null"`
|
||||||
Gender *sql.NullString `sql:"not null"`
|
Gender *sql.NullString `sql:"not null"`
|
||||||
Age sql.NullInt64
|
Age sql.NullInt64
|
||||||
Male sql.NullBool
|
Male sql.NullBool
|
||||||
Height sql.NullFloat64
|
Height sql.NullFloat64
|
||||||
AddedAt NullTime
|
AddedAt NullTime
|
||||||
}
|
}
|
||||||
|
|
||||||
type NullTime struct {
|
type NullTime struct {
|
||||||
Time time.Time
|
Time time.Time
|
||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nt *NullTime) Scan(value interface{}) error {
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
nt.Valid = false
|
nt.Valid = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
nt.Time, nt.Valid = value.(time.Time), true
|
nt.Time, nt.Valid = value.(time.Time), true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nt NullTime) Value() (driver.Value, error) {
|
func (nt NullTime) Value() (driver.Value, error) {
|
||||||
if !nt.Valid {
|
if !nt.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nt.Time, nil
|
return nt.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPreparedUser(name string, role string) *User {
|
func getPreparedUser(name string, role string) *User {
|
||||||
var company Company
|
var company Company
|
||||||
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||||
|
|
||||||
return &User{
|
return &User{
|
||||||
Name: name,
|
Name: name,
|
||||||
Age: 20,
|
Age: 20,
|
||||||
Role: Role{role},
|
Role: Role{role},
|
||||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||||
Emails: []Email{
|
Emails: []Email{
|
||||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||||
},
|
},
|
||||||
Company: company,
|
Company: company,
|
||||||
Languages: []Language{
|
Languages: []Language{
|
||||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runMigration() {
|
func runMigration() {
|
||||||
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
||||||
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range []string{"animals", "user_languages"} {
|
for _, table := range []string{"animals", "user_languages"} {
|
||||||
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
|
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Place{}}
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
DB.DropTable(value)
|
DB.DropTable(value)
|
||||||
}
|
}
|
||||||
if err := DB.AutoMigrate(values...).Error; err != nil {
|
if err := DB.AutoMigrate(values...).Error; err != nil {
|
||||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexes(t *testing.T) {
|
func TestIndexes(t *testing.T) {
|
||||||
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scope := DB.NewScope(&Email{})
|
scope := DB.NewScope(&Email{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email should have index idx_email_email")
|
t.Errorf("Email should have index idx_email_email")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
t.Errorf("Email's index idx_email_email should be deleted")
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to create index: %+v", err)
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email should have index idx_email_email_and_user_id")
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
|
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
|
||||||
t.Errorf("Should get to create duplicate record when having unique index")
|
t.Errorf("Should get to create duplicate record when having unique index")
|
||||||
}
|
}
|
||||||
|
|
||||||
var user = User{Name: "sample_user"}
|
var user = User{Name: "sample_user"}
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
|
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
|
||||||
t.Errorf("Should get no error when append two emails for user")
|
t.Errorf("Should get no error when append two emails for user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
|
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
|
||||||
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
|
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
||||||
t.Errorf("Got error when tried to remove index: %+v", err)
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
|
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
|
||||||
t.Errorf("Should be able to create duplicated emails after remove unique index")
|
t.Errorf("Should be able to create duplicated emails after remove unique index")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmailWithIdx struct {
|
type EmailWithIdx struct {
|
||||||
Id int64
|
Id int64
|
||||||
UserId int64
|
UserId int64
|
||||||
Email string `sql:"index:idx_email_agent"`
|
Email string `sql:"index:idx_email_agent"`
|
||||||
UserAgent string `sql:"index:idx_email_agent"`
|
UserAgent string `sql:"index:idx_email_agent"`
|
||||||
RegisteredAt *time.Time `sql:"unique_index"`
|
RegisteredAt *time.Time `sql:"unique_index"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoMigration(t *testing.T) {
|
func TestAutoMigration(t *testing.T) {
|
||||||
DB.AutoMigrate(&Address{})
|
DB.AutoMigrate(&Address{})
|
||||||
DB.DropTable(&EmailWithIdx{})
|
DB.DropTable(&EmailWithIdx{})
|
||||||
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
||||||
t.Errorf("Auto Migrate should not raise any error")
|
t.Errorf("Auto Migrate should not raise any error")
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
||||||
|
|
||||||
scope := DB.NewScope(&EmailWithIdx{})
|
scope := DB.NewScope(&EmailWithIdx{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
var bigemail EmailWithIdx
|
var bigemail EmailWithIdx
|
||||||
DB.First(&bigemail, "user_agent = ?", "pc")
|
DB.First(&bigemail, "user_agent = ?", "pc")
|
||||||
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
|
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
|
||||||
t.Error("Big Emails should be saved and fetched correctly")
|
t.Error("Big Emails should be saved and fetched correctly")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateAndAutomigrateTransaction(t *testing.T) {
|
func TestCreateAndAutomigrateTransaction(t *testing.T) {
|
||||||
tx := DB.Begin()
|
tx := DB.Begin()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
type Bar struct {
|
type Bar struct {
|
||||||
ID uint
|
ID uint
|
||||||
}
|
}
|
||||||
DB.DropTableIfExists(&Bar{})
|
DB.DropTableIfExists(&Bar{})
|
||||||
|
|
||||||
if ok := DB.HasTable("bars"); ok {
|
if ok := DB.HasTable("bars"); ok {
|
||||||
t.Errorf("Table should not exist, but does")
|
t.Errorf("Table should not exist, but does")
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok := tx.HasTable("bars"); ok {
|
if ok := tx.HasTable("bars"); ok {
|
||||||
t.Errorf("Table should not exist, but does")
|
t.Errorf("Table should not exist, but does")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
type Bar struct {
|
type Bar struct {
|
||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
err := tx.CreateTable(&Bar{}).Error
|
err := tx.CreateTable(&Bar{}).Error
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
|
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok := tx.HasTable(&Bar{}); !ok {
|
if ok := tx.HasTable(&Bar{}); !ok {
|
||||||
t.Errorf("The transaction should be able to see the table")
|
t.Errorf("The transaction should be able to see the table")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
type Bar struct {
|
type Bar struct {
|
||||||
Stuff string
|
Stuff string
|
||||||
}
|
}
|
||||||
|
|
||||||
err := tx.AutoMigrate(&Bar{}).Error
|
err := tx.AutoMigrate(&Bar{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Should have been able to alter the table, but couldn't")
|
t.Errorf("Should have been able to alter the table, but couldn't")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
type MultipleIndexes struct {
|
type MultipleIndexes struct {
|
||||||
ID int64
|
ID int64
|
||||||
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
||||||
Name string `sql:"unique_index:uix_multipleindexes_user_name"`
|
Name string `sql:"unique_index:uix_multipleindexes_user_name"`
|
||||||
Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
|
Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
|
||||||
Other string `sql:"index:,idx_multipleindexes_user_other"`
|
Other string `sql:"index:,idx_multipleindexes_user_other"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMultipleIndexes(t *testing.T) {
|
func TestMultipleIndexes(t *testing.T) {
|
||||||
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
||||||
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.AutoMigrate(&MultipleIndexes{})
|
DB.AutoMigrate(&MultipleIndexes{})
|
||||||
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
|
||||||
t.Errorf("Auto Migrate should not raise any error")
|
t.Errorf("Auto Migrate should not raise any error")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
||||||
|
|
||||||
scope := DB.NewScope(&MultipleIndexes{})
|
scope := DB.NewScope(&MultipleIndexes{})
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
|
||||||
t.Errorf("Failed to create index")
|
t.Errorf("Failed to create index")
|
||||||
}
|
}
|
||||||
|
|
||||||
var mutipleIndexes MultipleIndexes
|
var mutipleIndexes MultipleIndexes
|
||||||
DB.First(&mutipleIndexes, "name = ?", "jinzhu")
|
DB.First(&mutipleIndexes, "name = ?", "jinzhu")
|
||||||
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
|
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
|
||||||
t.Error("MutipleIndexes should be saved and fetched correctly")
|
t.Error("MutipleIndexes should be saved and fetched correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check unique constraints
|
// Check unique constraints
|
||||||
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
||||||
t.Error("MultipleIndexes unique index failed")
|
t.Error("MultipleIndexes unique index failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
|
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
|
||||||
t.Error("MultipleIndexes unique index failed")
|
t.Error("MultipleIndexes unique index failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
||||||
t.Error("MultipleIndexes unique index failed")
|
t.Error("MultipleIndexes unique index failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
|
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
|
||||||
t.Error("MultipleIndexes unique index failed")
|
t.Error("MultipleIndexes unique index failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModifyColumnType(t *testing.T) {
|
func TestModifyColumnType(t *testing.T) {
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
|
||||||
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
|
t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModifyColumnType struct {
|
type ModifyColumnType struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name1 string `gorm:"length:100"`
|
Name1 string `gorm:"length:100"`
|
||||||
Name2 string `gorm:"length:200"`
|
Name2 string `gorm:"length:200"`
|
||||||
}
|
}
|
||||||
DB.DropTable(&ModifyColumnType{})
|
DB.DropTable(&ModifyColumnType{})
|
||||||
DB.CreateTable(&ModifyColumnType{})
|
DB.CreateTable(&ModifyColumnType{})
|
||||||
|
|
||||||
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
|
name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
|
||||||
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
|
name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
|
||||||
|
|
||||||
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
|
if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
|
||||||
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
t.Errorf("No error should happen when ModifyColumn, but got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIndexWithPrefixLength(t *testing.T) {
|
func TestIndexWithPrefixLength(t *testing.T) {
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" {
|
||||||
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
t.Skip("Skipping this because only mysql support setting an index prefix length")
|
||||||
}
|
}
|
||||||
|
|
||||||
type IndexWithPrefix struct {
|
type IndexWithPrefix struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
}
|
}
|
||||||
type IndexesWithPrefix struct {
|
type IndexesWithPrefix struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
Description1 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
Description2 string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
}
|
}
|
||||||
type IndexesWithPrefixAndWithoutPrefix struct {
|
type IndexesWithPrefixAndWithoutPrefix struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"index:idx_index_with_prefixes_length"`
|
Name string `gorm:"index:idx_index_with_prefixes_length"`
|
||||||
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
Description string `gorm:"type:text;index:idx_index_with_prefixes_length(100)"`
|
||||||
}
|
}
|
||||||
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
|
tables := []interface{}{&IndexWithPrefix{}, &IndexesWithPrefix{}, &IndexesWithPrefixAndWithoutPrefix{}}
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
scope := DB.NewScope(table)
|
scope := DB.NewScope(table)
|
||||||
tableName := scope.TableName()
|
tableName := scope.TableName()
|
||||||
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Create index with prefix length: %s", tableName), func(t *testing.T) {
|
||||||
if err := DB.DropTableIfExists(table).Error; err != nil {
|
if err := DB.DropTableIfExists(table).Error; err != nil {
|
||||||
t.Errorf("Failed to drop %s table: %v", tableName, err)
|
t.Errorf("Failed to drop %s table: %v", tableName, err)
|
||||||
}
|
}
|
||||||
if err := DB.CreateTable(table).Error; err != nil {
|
if err := DB.CreateTable(table).Error; err != nil {
|
||||||
t.Errorf("Failed to create %s table: %v", tableName, err)
|
t.Errorf("Failed to create %s table: %v", tableName, err)
|
||||||
}
|
}
|
||||||
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
if !scope.Dialect().HasIndex(tableName, "idx_index_with_prefixes_length") {
|
||||||
t.Errorf("Failed to create %s table index:", tableName)
|
t.Errorf("Failed to create %s table index:", tableName)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2800
preload_test.go
2800
preload_test.go
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user