Update structure
This commit is contained in:
		
							parent
							
								
									be7a4064ef
								
							
						
					
					
						commit
						da3bbdf2fb
					
				
							
								
								
									
										1050
									
								
								association_test.go
									
									
									
									
									
								
							
							
						
						
									
										1050
									
								
								association_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -1,99 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for querying
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:query", queryCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:preload", preloadCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func queryCallback(scope *Scope) {
 | 
			
		||||
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		isSlice, isPtr bool
 | 
			
		||||
		resultType     reflect.Type
 | 
			
		||||
		results        = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
 | 
			
		||||
		if primaryField := scope.PrimaryField(); primaryField != nil {
 | 
			
		||||
			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
			
		||||
		results = indirect(reflect.ValueOf(value))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if kind := results.Kind(); kind == reflect.Slice {
 | 
			
		||||
		isSlice = true
 | 
			
		||||
		resultType = results.Type().Elem()
 | 
			
		||||
		results.Set(reflect.MakeSlice(results.Type(), 0, 0))
 | 
			
		||||
 | 
			
		||||
		if resultType.Kind() == reflect.Ptr {
 | 
			
		||||
			isPtr = true
 | 
			
		||||
			resultType = resultType.Elem()
 | 
			
		||||
		}
 | 
			
		||||
	} else if kind != reflect.Struct {
 | 
			
		||||
		scope.Err(errors.New("unsupported destination, should be slice or struct"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.RowsAffected = 0
 | 
			
		||||
		if str, ok := scope.Get("gorm:query_option"); ok {
 | 
			
		||||
			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			defer rows.Close()
 | 
			
		||||
 | 
			
		||||
			columns, _ := rows.Columns()
 | 
			
		||||
			for rows.Next() {
 | 
			
		||||
				scope.db.RowsAffected++
 | 
			
		||||
 | 
			
		||||
				elem := results
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					elem = reflect.New(resultType).Elem()
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
 | 
			
		||||
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					if isPtr {
 | 
			
		||||
						results.Set(reflect.Append(results, elem.Addr()))
 | 
			
		||||
					} else {
 | 
			
		||||
						results.Set(reflect.Append(results, elem))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := rows.Err(); err != nil {
 | 
			
		||||
				scope.Err(err)
 | 
			
		||||
			} else if scope.db.RowsAffected == 0 && !isSlice {
 | 
			
		||||
				scope.Err(ErrRecordNotFound)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterQueryCallback will invoke `AfterFind` method after querying
 | 
			
		||||
func afterQueryCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterFind")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,112 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
 | 
			
		||||
	var names []string
 | 
			
		||||
	for _, f := range funcs {
 | 
			
		||||
		fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
 | 
			
		||||
		names = append(names, fnames[len(fnames)-1])
 | 
			
		||||
	}
 | 
			
		||||
	return reflect.DeepEqual(names, fnames)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func create(s *Scope)        {}
 | 
			
		||||
func beforeCreate1(s *Scope) {}
 | 
			
		||||
func beforeCreate2(s *Scope) {}
 | 
			
		||||
func afterCreate1(s *Scope)  {}
 | 
			
		||||
func afterCreate2(s *Scope)  {}
 | 
			
		||||
 | 
			
		||||
func TestRegisterCallback(t *testing.T) {
 | 
			
		||||
	var callback = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback.Create().Register("before_create2", beforeCreate2)
 | 
			
		||||
	callback.Create().Register("create", create)
 | 
			
		||||
	callback.Create().Register("after_create1", afterCreate1)
 | 
			
		||||
	callback.Create().Register("after_create2", afterCreate2)
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
 | 
			
		||||
		t.Errorf("register callback")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRegisterCallbackWithOrder(t *testing.T) {
 | 
			
		||||
	var callback1 = &Callback{}
 | 
			
		||||
	callback1.Create().Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback1.Create().Register("create", create)
 | 
			
		||||
	callback1.Create().Register("after_create1", afterCreate1)
 | 
			
		||||
	callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
 | 
			
		||||
	if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
 | 
			
		||||
		t.Errorf("register callback with order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var callback2 = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback2.Update().Register("create", create)
 | 
			
		||||
	callback2.Update().Before("create").Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
 | 
			
		||||
	callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
 | 
			
		||||
	callback2.Update().Register("after_create2", afterCreate2)
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
 | 
			
		||||
		t.Errorf("register callback with order")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
 | 
			
		||||
	var callback1 = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
 | 
			
		||||
	callback1.Query().Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback1.Query().Register("after_create1", afterCreate1)
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
 | 
			
		||||
		t.Errorf("register callback with order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var callback2 = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
 | 
			
		||||
	callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
 | 
			
		||||
	callback2.Delete().Register("after_create1", afterCreate1)
 | 
			
		||||
	callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
 | 
			
		||||
		t.Errorf("register callback with order")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func replaceCreate(s *Scope) {}
 | 
			
		||||
 | 
			
		||||
func TestReplaceCallback(t *testing.T) {
 | 
			
		||||
	var callback = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
			
		||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback.Create().Register("after_create1", afterCreate1)
 | 
			
		||||
	callback.Create().Replace("create", replaceCreate)
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
 | 
			
		||||
		t.Errorf("replace callback")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRemoveCallback(t *testing.T) {
 | 
			
		||||
	var callback = &Callback{}
 | 
			
		||||
 | 
			
		||||
	callback.Create().Before("after_create1").After("before_create1").Register("create", create)
 | 
			
		||||
	callback.Create().Register("before_create1", beforeCreate1)
 | 
			
		||||
	callback.Create().Register("after_create1", afterCreate1)
 | 
			
		||||
	callback.Create().Remove("create")
 | 
			
		||||
 | 
			
		||||
	if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
 | 
			
		||||
		t.Errorf("remove callback")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -8,6 +8,98 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for querying
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:query", queryCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:preload", preloadCallback)
 | 
			
		||||
	DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func queryCallback(scope *Scope) {
 | 
			
		||||
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		isSlice, isPtr bool
 | 
			
		||||
		resultType     reflect.Type
 | 
			
		||||
		results        = scope.IndirectValue()
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
 | 
			
		||||
		if primaryField := scope.PrimaryField(); primaryField != nil {
 | 
			
		||||
			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
			
		||||
		results = indirect(reflect.ValueOf(value))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if kind := results.Kind(); kind == reflect.Slice {
 | 
			
		||||
		isSlice = true
 | 
			
		||||
		resultType = results.Type().Elem()
 | 
			
		||||
		results.Set(reflect.MakeSlice(results.Type(), 0, 0))
 | 
			
		||||
 | 
			
		||||
		if resultType.Kind() == reflect.Ptr {
 | 
			
		||||
			isPtr = true
 | 
			
		||||
			resultType = resultType.Elem()
 | 
			
		||||
		}
 | 
			
		||||
	} else if kind != reflect.Struct {
 | 
			
		||||
		scope.Err(errors.New("unsupported destination, should be slice or struct"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.RowsAffected = 0
 | 
			
		||||
		if str, ok := scope.Get("gorm:query_option"); ok {
 | 
			
		||||
			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			defer rows.Close()
 | 
			
		||||
 | 
			
		||||
			columns, _ := rows.Columns()
 | 
			
		||||
			for rows.Next() {
 | 
			
		||||
				scope.db.RowsAffected++
 | 
			
		||||
 | 
			
		||||
				elem := results
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					elem = reflect.New(resultType).Elem()
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
 | 
			
		||||
 | 
			
		||||
				if isSlice {
 | 
			
		||||
					if isPtr {
 | 
			
		||||
						results.Set(reflect.Append(results, elem.Addr()))
 | 
			
		||||
					} else {
 | 
			
		||||
						results.Set(reflect.Append(results, elem))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := rows.Err(); err != nil {
 | 
			
		||||
				scope.Err(err)
 | 
			
		||||
			} else if scope.db.RowsAffected == 0 && !isSlice {
 | 
			
		||||
				scope.Err(ErrRecordNotFound)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// afterQueryCallback will invoke `AfterFind` method after querying
 | 
			
		||||
func afterQueryCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.CallMethod("AfterFind")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// preloadCallback used to preload associations
 | 
			
		||||
func preloadCallback(scope *Scope) {
 | 
			
		||||
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
 | 
			
		||||
@ -1,177 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Product) BeforeCreate() (err error) {
 | 
			
		||||
	if s.Code == "Invalid" {
 | 
			
		||||
		err = errors.New("invalid product")
 | 
			
		||||
	}
 | 
			
		||||
	s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) BeforeUpdate() (err error) {
 | 
			
		||||
	if s.Code == "dont_update" {
 | 
			
		||||
		err = errors.New("can't update")
 | 
			
		||||
	}
 | 
			
		||||
	s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) BeforeSave() (err error) {
 | 
			
		||||
	if s.Code == "dont_save" {
 | 
			
		||||
		err = errors.New("can't save")
 | 
			
		||||
	}
 | 
			
		||||
	s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) AfterFind() {
 | 
			
		||||
	s.AfterFindCallTimes = s.AfterFindCallTimes + 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) AfterCreate(tx *gorm.DB) {
 | 
			
		||||
	tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) AfterUpdate() {
 | 
			
		||||
	s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) AfterSave() (err error) {
 | 
			
		||||
	if s.Code == "after_save_error" {
 | 
			
		||||
		err = errors.New("can't save")
 | 
			
		||||
	}
 | 
			
		||||
	s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) BeforeDelete() (err error) {
 | 
			
		||||
	if s.Code == "dont_delete" {
 | 
			
		||||
		err = errors.New("can't delete")
 | 
			
		||||
	}
 | 
			
		||||
	s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) AfterDelete() (err error) {
 | 
			
		||||
	if s.Code == "after_delete_error" {
 | 
			
		||||
		err = errors.New("can't delete")
 | 
			
		||||
	}
 | 
			
		||||
	s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Product) GetCallTimes() []int64 {
 | 
			
		||||
	return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRunCallbacks(t *testing.T) {
 | 
			
		||||
	p := Product{Code: "unique_code", Price: 100}
 | 
			
		||||
	DB.Save(&p)
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
 | 
			
		||||
		t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("Code = ?", "unique_code").First(&p)
 | 
			
		||||
	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
 | 
			
		||||
		t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.Price = 200
 | 
			
		||||
	DB.Save(&p)
 | 
			
		||||
	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
 | 
			
		||||
		t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var products []Product
 | 
			
		||||
	DB.Find(&products, "code = ?", "unique_code")
 | 
			
		||||
	if products[0].AfterFindCallTimes != 2 {
 | 
			
		||||
		t.Errorf("AfterFind callbacks should work with slice")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("Code = ?", "unique_code").First(&p)
 | 
			
		||||
	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
 | 
			
		||||
		t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Delete(&p)
 | 
			
		||||
	if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
 | 
			
		||||
		t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
 | 
			
		||||
		t.Errorf("Can't find a deleted record")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCallbacksWithErrors(t *testing.T) {
 | 
			
		||||
	p := Product{Code: "Invalid", Price: 100}
 | 
			
		||||
	if DB.Save(&p).Error == nil {
 | 
			
		||||
		t.Errorf("An error from before create callbacks happened when create with invalid value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
 | 
			
		||||
		t.Errorf("Should not save record that have errors")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
 | 
			
		||||
		t.Errorf("An error from after create callbacks happened when create with invalid value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p2 := Product{Code: "update_callback", Price: 100}
 | 
			
		||||
	DB.Save(&p2)
 | 
			
		||||
 | 
			
		||||
	p2.Code = "dont_update"
 | 
			
		||||
	if DB.Save(&p2).Error == nil {
 | 
			
		||||
		t.Errorf("An error from before update callbacks happened when update with invalid value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
 | 
			
		||||
		t.Errorf("Record Should not be updated due to errors happened in before update callback")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
 | 
			
		||||
		t.Errorf("Record Should not be updated due to errors happened in before update callback")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p2.Code = "dont_save"
 | 
			
		||||
	if DB.Save(&p2).Error == nil {
 | 
			
		||||
		t.Errorf("An error from before save callbacks happened when update with invalid value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p3 := Product{Code: "dont_delete", Price: 100}
 | 
			
		||||
	DB.Save(&p3)
 | 
			
		||||
	if DB.Delete(&p3).Error == nil {
 | 
			
		||||
		t.Errorf("An error from before delete callbacks happened when delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
 | 
			
		||||
		t.Errorf("An error from before delete callbacks happened")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p4 := Product{Code: "after_save_error", Price: 100}
 | 
			
		||||
	DB.Save(&p4)
 | 
			
		||||
	if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
 | 
			
		||||
		t.Errorf("Record should be reverted if get an error in after save callback")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p5 := Product{Code: "after_delete_error", Price: 100}
 | 
			
		||||
	DB.Save(&p5)
 | 
			
		||||
	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
 | 
			
		||||
		t.Errorf("Record should be found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Delete(&p5)
 | 
			
		||||
	if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
 | 
			
		||||
		t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										231
									
								
								create_test.go
									
									
									
									
									
								
							
							
						
						
									
										231
									
								
								create_test.go
									
									
									
									
									
								
							@ -1,231 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/now"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCreate(t *testing.T) {
 | 
			
		||||
	float := 35.03554004971999
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
 | 
			
		||||
 | 
			
		||||
	if !DB.NewRecord(user) || !DB.NewRecord(&user) {
 | 
			
		||||
		t.Error("User should be new record before create")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count := DB.Save(&user).RowsAffected; count != 1 {
 | 
			
		||||
		t.Error("There should be one record be affected when create record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewRecord(user) || DB.NewRecord(&user) {
 | 
			
		||||
		t.Error("User should not new record after save")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var newUser User
 | 
			
		||||
	if err := DB.First(&newUser, user.Id).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
 | 
			
		||||
		t.Errorf("User's PasswordHash should be saved ([]byte)")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if newUser.Age != 18 {
 | 
			
		||||
		t.Errorf("User's Age should be saved (int)")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if newUser.UserNum != Num(111) {
 | 
			
		||||
		t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if newUser.Latitude != float {
 | 
			
		||||
		t.Errorf("Float64 should not be changed after save")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.CreatedAt.IsZero() {
 | 
			
		||||
		t.Errorf("Should have created_at after create")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if newUser.CreatedAt.IsZero() {
 | 
			
		||||
		t.Errorf("Should have created_at after create")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(user).Update("name", "create_user_new_name")
 | 
			
		||||
	DB.First(&user, user.Id)
 | 
			
		||||
	if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("CreatedAt should not be changed after update")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateEmptyStrut(t *testing.T) {
 | 
			
		||||
	type EmptyStruct struct {
 | 
			
		||||
		ID uint
 | 
			
		||||
	}
 | 
			
		||||
	DB.AutoMigrate(&EmptyStruct{})
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&EmptyStruct{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when creating user, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateWithExistingTimestamp(t *testing.T) {
 | 
			
		||||
	user := User{Name: "CreateUserExistingTimestamp"}
 | 
			
		||||
 | 
			
		||||
	timeA := now.MustParse("2016-01-01")
 | 
			
		||||
	user.CreatedAt = timeA
 | 
			
		||||
	user.UpdatedAt = timeA
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
 | 
			
		||||
		t.Errorf("CreatedAt should not be changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
 | 
			
		||||
		t.Errorf("UpdatedAt should not be changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var newUser User
 | 
			
		||||
	DB.First(&newUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
 | 
			
		||||
		t.Errorf("CreatedAt should not be changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
 | 
			
		||||
		t.Errorf("UpdatedAt should not be changed")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AutoIncrementUser struct {
 | 
			
		||||
	User
 | 
			
		||||
	Sequence uint `gorm:"AUTO_INCREMENT"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateWithAutoIncrement(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
 | 
			
		||||
		t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.AutoMigrate(&AutoIncrementUser{})
 | 
			
		||||
 | 
			
		||||
	user1 := AutoIncrementUser{}
 | 
			
		||||
	user2 := AutoIncrementUser{}
 | 
			
		||||
 | 
			
		||||
	DB.Create(&user1)
 | 
			
		||||
	DB.Create(&user2)
 | 
			
		||||
 | 
			
		||||
	if user2.Sequence-user1.Sequence != 1 {
 | 
			
		||||
		t.Errorf("Auto increment should apply on Sequence")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
 | 
			
		||||
		t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jt := JoinTable{From: 1, To: 2}
 | 
			
		||||
	err := DB.Create(&jt).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
			
		||||
	animal := Animal{Name: "Ferdinand"}
 | 
			
		||||
	if DB.Save(&animal).Error != nil {
 | 
			
		||||
		t.Errorf("No error should happen when create a record without std primary key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if animal.Counter == 0 {
 | 
			
		||||
		t.Errorf("No std primary key should be filled value after create")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if animal.Name != "Ferdinand" {
 | 
			
		||||
		t.Errorf("Default value should be overrided")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test create with default value not overrided
 | 
			
		||||
	an := Animal{From: "nerdz"}
 | 
			
		||||
 | 
			
		||||
	if DB.Save(&an).Error != nil {
 | 
			
		||||
		t.Errorf("No error should happen when create an record without std primary key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We must fetch the value again, to have the default fields updated
 | 
			
		||||
	// (We can't do this in the update statements, since sql default can be expressions
 | 
			
		||||
	// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
 | 
			
		||||
	DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
 | 
			
		||||
 | 
			
		||||
	if an.Name != "galeone" {
 | 
			
		||||
		t.Errorf("Default value should fill the field. But got %v", an.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAnonymousScanner(t *testing.T) {
 | 
			
		||||
	user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var user2 User
 | 
			
		||||
	DB.First(&user2, "name = ?", "anonymous_scanner")
 | 
			
		||||
	if user2.Role.Name != "admin" {
 | 
			
		||||
		t.Errorf("Should be able to get anonymous scanner")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !user2.Role.IsAdmin() {
 | 
			
		||||
		t.Errorf("Should be able to get anonymous scanner")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAnonymousField(t *testing.T) {
 | 
			
		||||
	user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var user2 User
 | 
			
		||||
	DB.First(&user2, "name = ?", "anonymous_field")
 | 
			
		||||
	DB.Model(&user2).Related(&user2.Company)
 | 
			
		||||
	if user2.Company.Name != "company" {
 | 
			
		||||
		t.Errorf("Should be able to get anonymous field")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithCreate(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_create")
 | 
			
		||||
	DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
 | 
			
		||||
 | 
			
		||||
	var queryuser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryuser.Name != user.Name || queryuser.Age == user.Age {
 | 
			
		||||
		t.Errorf("Should only create users with name column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
 | 
			
		||||
		queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
 | 
			
		||||
		t.Errorf("Should only create selected relationships")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOmitWithCreate(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("omit_user", "omit_with_create")
 | 
			
		||||
	DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
 | 
			
		||||
 | 
			
		||||
	var queryuser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryuser.Name == user.Name || queryuser.Age != user.Age {
 | 
			
		||||
		t.Errorf("Should only create users with age column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
 | 
			
		||||
		queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
 | 
			
		||||
		t.Errorf("Should not create omitted relationships")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,346 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CustomizeColumn struct {
 | 
			
		||||
	ID   int64      `gorm:"column:mapped_id; primary_key:yes"`
 | 
			
		||||
	Name string     `gorm:"column:mapped_name"`
 | 
			
		||||
	Date *time.Time `gorm:"column:mapped_time"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Make sure an ignored field does not interfere with another field's custom
 | 
			
		||||
// column name that matches the ignored field.
 | 
			
		||||
type CustomColumnAndIgnoredFieldClash struct {
 | 
			
		||||
	Body    string `sql:"-"`
 | 
			
		||||
	RawBody string `gorm:"column:body"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomizeColumn(t *testing.T) {
 | 
			
		||||
	col := "mapped_name"
 | 
			
		||||
	DB.DropTable(&CustomizeColumn{})
 | 
			
		||||
	DB.AutoMigrate(&CustomizeColumn{})
 | 
			
		||||
 | 
			
		||||
	scope := DB.NewScope(&CustomizeColumn{})
 | 
			
		||||
	if !scope.Dialect().HasColumn(scope.TableName(), col) {
 | 
			
		||||
		t.Errorf("CustomizeColumn should have column %s", col)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	col = "mapped_id"
 | 
			
		||||
	if scope.PrimaryKey() != col {
 | 
			
		||||
		t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expected := "foo"
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
 | 
			
		||||
 | 
			
		||||
	if count := DB.Create(&cc).RowsAffected; count != 1 {
 | 
			
		||||
		t.Error("There should be one record be affected when create record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var cc1 CustomizeColumn
 | 
			
		||||
	DB.First(&cc1, 666)
 | 
			
		||||
 | 
			
		||||
	if cc1.Name != expected {
 | 
			
		||||
		t.Errorf("Failed to query CustomizeColumn")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cc.Name = "bar"
 | 
			
		||||
	DB.Save(&cc)
 | 
			
		||||
 | 
			
		||||
	var cc2 CustomizeColumn
 | 
			
		||||
	DB.First(&cc2, 666)
 | 
			
		||||
	if cc2.Name != "bar" {
 | 
			
		||||
		t.Errorf("Failed to query CustomizeColumn")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
 | 
			
		||||
	if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should not raise error: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CustomizePerson struct {
 | 
			
		||||
	IdPerson string             `gorm:"column:idPerson;primary_key:true"`
 | 
			
		||||
	Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CustomizeAccount struct {
 | 
			
		||||
	IdAccount string `gorm:"column:idAccount;primary_key:true"`
 | 
			
		||||
	Name      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestManyToManyWithCustomizedColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
 | 
			
		||||
	DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
 | 
			
		||||
 | 
			
		||||
	account := CustomizeAccount{IdAccount: "account", Name: "id1"}
 | 
			
		||||
	person := CustomizePerson{
 | 
			
		||||
		IdPerson: "person",
 | 
			
		||||
		Accounts: []CustomizeAccount{account},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&account).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&person).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var person1 CustomizePerson
 | 
			
		||||
	scope := DB.NewScope(nil)
 | 
			
		||||
	if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
 | 
			
		||||
		t.Errorf("should preload correct accounts")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CustomizeUser struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Email string `sql:"column:email_address"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CustomizeInvitation struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Address string         `sql:"column:invitation"`
 | 
			
		||||
	Person  *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOneToOneWithCustomizedColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
 | 
			
		||||
	DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
 | 
			
		||||
 | 
			
		||||
	user := CustomizeUser{
 | 
			
		||||
		Email: "hello@example.com",
 | 
			
		||||
	}
 | 
			
		||||
	invitation := CustomizeInvitation{
 | 
			
		||||
		Address: "hello@example.com",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Create(&user)
 | 
			
		||||
	DB.Create(&invitation)
 | 
			
		||||
 | 
			
		||||
	var invitation2 CustomizeInvitation
 | 
			
		||||
	if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if invitation2.Person.Email != user.Email {
 | 
			
		||||
		t.Errorf("Should preload one to one relation with customize foreign keys")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PromotionDiscount struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name     string
 | 
			
		||||
	Coupons  []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
 | 
			
		||||
	Rule     *PromotionRule     `gorm:"ForeignKey:discount_id"`
 | 
			
		||||
	Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PromotionBenefit struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name        string
 | 
			
		||||
	PromotionID uint
 | 
			
		||||
	Discount    PromotionDiscount `gorm:"ForeignKey:promotion_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PromotionCoupon struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Code       string
 | 
			
		||||
	DiscountID uint
 | 
			
		||||
	Discount   PromotionDiscount
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PromotionRule struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name       string
 | 
			
		||||
	Begin      *time.Time
 | 
			
		||||
	End        *time.Time
 | 
			
		||||
	DiscountID uint
 | 
			
		||||
	Discount   *PromotionDiscount
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOneToManyWithCustomizedColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
 | 
			
		||||
	DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
 | 
			
		||||
 | 
			
		||||
	discount := PromotionDiscount{
 | 
			
		||||
		Name: "Happy New Year",
 | 
			
		||||
		Coupons: []*PromotionCoupon{
 | 
			
		||||
			{Code: "newyear1"},
 | 
			
		||||
			{Code: "newyear2"},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&discount).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var discount1 PromotionDiscount
 | 
			
		||||
	if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(discount.Coupons) != 2 {
 | 
			
		||||
		t.Errorf("should find two coupons")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var coupon PromotionCoupon
 | 
			
		||||
	if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if coupon.Discount.Name != "Happy New Year" {
 | 
			
		||||
		t.Errorf("should preload discount from coupon")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
 | 
			
		||||
	DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
 | 
			
		||||
 | 
			
		||||
	var begin = time.Now()
 | 
			
		||||
	var end = time.Now().Add(24 * time.Hour)
 | 
			
		||||
	discount := PromotionDiscount{
 | 
			
		||||
		Name: "Happy New Year 2",
 | 
			
		||||
		Rule: &PromotionRule{
 | 
			
		||||
			Name:  "time_limited",
 | 
			
		||||
			Begin: &begin,
 | 
			
		||||
			End:   &end,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&discount).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var discount1 PromotionDiscount
 | 
			
		||||
	if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("Should be able to preload Rule")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var rule PromotionRule
 | 
			
		||||
	if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if rule.Discount.Name != "Happy New Year 2" {
 | 
			
		||||
		t.Errorf("should preload discount from rule")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
 | 
			
		||||
	DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
 | 
			
		||||
 | 
			
		||||
	discount := PromotionDiscount{
 | 
			
		||||
		Name: "Happy New Year 3",
 | 
			
		||||
		Benefits: []PromotionBenefit{
 | 
			
		||||
			{Name: "free cod"},
 | 
			
		||||
			{Name: "free shipping"},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&discount).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var discount1 PromotionDiscount
 | 
			
		||||
	if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(discount.Benefits) != 2 {
 | 
			
		||||
		t.Errorf("should find two benefits")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var benefit PromotionBenefit
 | 
			
		||||
	if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if benefit.Discount.Name != "Happy New Year 3" {
 | 
			
		||||
		t.Errorf("should preload discount from coupon")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SelfReferencingUser struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name    string
 | 
			
		||||
	Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&SelfReferencingUser{}, "UserFriends")
 | 
			
		||||
	DB.AutoMigrate(&SelfReferencingUser{})
 | 
			
		||||
 | 
			
		||||
	friend1 := SelfReferencingUser{Name: "friend1_m2m"}
 | 
			
		||||
	if err := DB.Create(&friend1).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	friend2 := SelfReferencingUser{Name: "friend2_m2m"}
 | 
			
		||||
	if err := DB.Create(&friend2).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user := SelfReferencingUser{
 | 
			
		||||
		Name:    "self_m2m",
 | 
			
		||||
		Friends: []*SelfReferencingUser{&friend1, &friend2},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Create(&user).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&user).Association("Friends").Count() != 2 {
 | 
			
		||||
		t.Errorf("Should find created friends correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var newUser = SelfReferencingUser{}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(newUser.Friends) != 2 {
 | 
			
		||||
		t.Errorf("Should preload created frineds for self reference m2m")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"})
 | 
			
		||||
	if DB.Model(&user).Association("Friends").Count() != 3 {
 | 
			
		||||
		t.Errorf("Should find created friends correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"})
 | 
			
		||||
	if DB.Model(&user).Association("Friends").Count() != 1 {
 | 
			
		||||
		t.Errorf("Should find created friends correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	friend := SelfReferencingUser{}
 | 
			
		||||
	DB.Model(&newUser).Association("Friends").Find(&friend)
 | 
			
		||||
	if friend.Name != "friend4_m2m" {
 | 
			
		||||
		t.Errorf("Should find created friends correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&newUser).Association("Friends").Delete(friend)
 | 
			
		||||
	if DB.Model(&user).Association("Friends").Count() != 0 {
 | 
			
		||||
		t.Errorf("All friends should be deleted")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,91 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestDelete(t *testing.T) {
 | 
			
		||||
	user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
 | 
			
		||||
	DB.Save(&user1)
 | 
			
		||||
	DB.Save(&user2)
 | 
			
		||||
 | 
			
		||||
	if err := DB.Delete(&user1).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when delete a record, err=%s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("User can't be found after delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Other users that not deleted should be found-able")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestInlineDelete(t *testing.T) {
 | 
			
		||||
	user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
 | 
			
		||||
	DB.Save(&user1)
 | 
			
		||||
	DB.Save(&user2)
 | 
			
		||||
 | 
			
		||||
	if DB.Delete(&User{}, user1.Id).Error != nil {
 | 
			
		||||
		t.Errorf("No error should happen when delete a record")
 | 
			
		||||
	} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("User can't be found after delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when delete a record, err=%s", err)
 | 
			
		||||
	} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("User can't be found after delete")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSoftDelete(t *testing.T) {
 | 
			
		||||
	type User struct {
 | 
			
		||||
		Id        int64
 | 
			
		||||
		Name      string
 | 
			
		||||
		DeletedAt *time.Time
 | 
			
		||||
	}
 | 
			
		||||
	DB.AutoMigrate(&User{})
 | 
			
		||||
 | 
			
		||||
	user := User{Name: "soft_delete"}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	DB.Delete(&user)
 | 
			
		||||
 | 
			
		||||
	if DB.First(&User{}, "name = ?", user.Name).Error == nil {
 | 
			
		||||
		t.Errorf("Can't find a soft deleted record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Unscoped().Delete(&user)
 | 
			
		||||
	if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Can't find permanently deleted record")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) {
 | 
			
		||||
	creditCard := CreditCard{Number: "411111111234567"}
 | 
			
		||||
	DB.Save(&creditCard)
 | 
			
		||||
	DB.Delete(&creditCard)
 | 
			
		||||
 | 
			
		||||
	if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" {
 | 
			
		||||
		t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil {
 | 
			
		||||
		t.Errorf("Can't find a soft deleted record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Unscoped().Delete(&creditCard)
 | 
			
		||||
	if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Can't find permanently deleted record")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										136
									
								
								dialect.go
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								dialect.go
									
									
									
									
									
								
							@ -1,130 +1,20 @@
 | 
			
		||||
package gorm
 | 
			
		||||
package dialects
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/builder"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Dialect interface contains behaviors that differ across SQL database
 | 
			
		||||
// Dialect GORM dialect interface
 | 
			
		||||
type Dialect interface {
 | 
			
		||||
	// GetName get dialect's name
 | 
			
		||||
	GetName() string
 | 
			
		||||
	// CRUD operations
 | 
			
		||||
	Insert(*gorm.DB, builder.Statement) error
 | 
			
		||||
	Query(*gorm.DB, builder.Statement) error
 | 
			
		||||
	Update(*gorm.DB, builder.Statement) error
 | 
			
		||||
	Delete(*gorm.DB, builder.Statement) error
 | 
			
		||||
 | 
			
		||||
	// SetDB set db for dialect
 | 
			
		||||
	SetDB(db SQLCommon)
 | 
			
		||||
 | 
			
		||||
	// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | 
			
		||||
	BindVar(i int) string
 | 
			
		||||
	// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
 | 
			
		||||
	Quote(key string) string
 | 
			
		||||
	// DataTypeOf return data's sql type
 | 
			
		||||
	DataTypeOf(field *StructField) string
 | 
			
		||||
 | 
			
		||||
	// HasIndex check has index or not
 | 
			
		||||
	HasIndex(tableName string, indexName string) bool
 | 
			
		||||
	// HasForeignKey check has foreign key or not
 | 
			
		||||
	HasForeignKey(tableName string, foreignKeyName string) bool
 | 
			
		||||
	// RemoveIndex remove index
 | 
			
		||||
	RemoveIndex(tableName string, indexName string) error
 | 
			
		||||
	// HasTable check has table or not
 | 
			
		||||
	HasTable(tableName string) bool
 | 
			
		||||
	// HasColumn check has column or not
 | 
			
		||||
	HasColumn(tableName string, columnName string) bool
 | 
			
		||||
	// ModifyColumn modify column's type
 | 
			
		||||
	ModifyColumn(tableName string, columnName string, typ string) error
 | 
			
		||||
 | 
			
		||||
	// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
 | 
			
		||||
	LimitAndOffsetSQL(limit, offset interface{}) string
 | 
			
		||||
	// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
 | 
			
		||||
	SelectFromDummyTable() string
 | 
			
		||||
	// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
 | 
			
		||||
	LastInsertIDReturningSuffix(tableName, columnName string) string
 | 
			
		||||
	// DefaultValueStr
 | 
			
		||||
	DefaultValueStr() string
 | 
			
		||||
 | 
			
		||||
	// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
 | 
			
		||||
	BuildKeyName(kind, tableName string, fields ...string) string
 | 
			
		||||
 | 
			
		||||
	// CurrentDatabase return current database name
 | 
			
		||||
	CurrentDatabase() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var dialectsMap = map[string]Dialect{}
 | 
			
		||||
 | 
			
		||||
func newDialect(name string, db SQLCommon) Dialect {
 | 
			
		||||
	if value, ok := dialectsMap[name]; ok {
 | 
			
		||||
		dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
 | 
			
		||||
		dialect.SetDB(db)
 | 
			
		||||
		return dialect
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
 | 
			
		||||
	commontDialect := &commonDialect{}
 | 
			
		||||
	commontDialect.SetDB(db)
 | 
			
		||||
	return commontDialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RegisterDialect register new dialect
 | 
			
		||||
func RegisterDialect(name string, dialect Dialect) {
 | 
			
		||||
	dialectsMap[name] = dialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseFieldStructForDialect get field's sql data type
 | 
			
		||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
 | 
			
		||||
	// Get redirected field type
 | 
			
		||||
	var (
 | 
			
		||||
		reflectType = field.Struct.Type
 | 
			
		||||
		dataType    = field.TagSettings["TYPE"]
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for reflectType.Kind() == reflect.Ptr {
 | 
			
		||||
		reflectType = reflectType.Elem()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Get redirected field value
 | 
			
		||||
	fieldValue = reflect.Indirect(reflect.New(reflectType))
 | 
			
		||||
 | 
			
		||||
	if gormDataType, ok := fieldValue.Interface().(interface {
 | 
			
		||||
		GormDataType(Dialect) string
 | 
			
		||||
	}); ok {
 | 
			
		||||
		dataType = gormDataType.GormDataType(dialect)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Get scanner's real value
 | 
			
		||||
	if dataType == "" {
 | 
			
		||||
		var getScannerValue func(reflect.Value)
 | 
			
		||||
		getScannerValue = func(value reflect.Value) {
 | 
			
		||||
			fieldValue = value
 | 
			
		||||
			if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
 | 
			
		||||
				getScannerValue(fieldValue.Field(0))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		getScannerValue(fieldValue)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Default Size
 | 
			
		||||
	if num, ok := field.TagSettings["SIZE"]; ok {
 | 
			
		||||
		size, _ = strconv.Atoi(num)
 | 
			
		||||
	} else {
 | 
			
		||||
		size = 255
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Default type from tag setting
 | 
			
		||||
	additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
 | 
			
		||||
	if value, ok := field.TagSettings["DEFAULT"]; ok {
 | 
			
		||||
		additionalType = additionalType + " DEFAULT " + value
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fieldValue, dataType, size, strings.TrimSpace(additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
 | 
			
		||||
	if strings.Contains(tableName, ".") {
 | 
			
		||||
		splitStrings := strings.SplitN(tableName, ".", 2)
 | 
			
		||||
		return splitStrings[0], splitStrings[1]
 | 
			
		||||
	}
 | 
			
		||||
	return dialect.CurrentDatabase(), tableName
 | 
			
		||||
	// DB Driver interface
 | 
			
		||||
	QueryRow(*gorm.DB) error
 | 
			
		||||
	QueryRows(*gorm.DB) error
 | 
			
		||||
	Exec(*gorm.DB) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,176 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
 | 
			
		||||
type DefaultForeignKeyNamer struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type commonDialect struct {
 | 
			
		||||
	db SQLCommon
 | 
			
		||||
	DefaultForeignKeyNamer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("common", &commonDialect{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) GetName() string {
 | 
			
		||||
	return "common"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *commonDialect) SetDB(db SQLCommon) {
 | 
			
		||||
	s.db = db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) BindVar(i int) string {
 | 
			
		||||
	return "$$$" // ?
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) Quote(key string) string {
 | 
			
		||||
	return fmt.Sprintf(`"%s"`, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
 | 
			
		||||
	if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
 | 
			
		||||
		return strings.ToLower(value) != "false"
 | 
			
		||||
	}
 | 
			
		||||
	return field.IsPrimaryKey
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *commonDialect) DataTypeOf(field *StructField) string {
 | 
			
		||||
	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		switch dataValue.Kind() {
 | 
			
		||||
		case reflect.Bool:
 | 
			
		||||
			sqlType = "BOOLEAN"
 | 
			
		||||
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				sqlType = "INTEGER AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "INTEGER"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Int64, reflect.Uint64:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				sqlType = "BIGINT AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "BIGINT"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Float32, reflect.Float64:
 | 
			
		||||
			sqlType = "FLOAT"
 | 
			
		||||
		case reflect.String:
 | 
			
		||||
			if size > 0 && size < 65532 {
 | 
			
		||||
				sqlType = fmt.Sprintf("VARCHAR(%d)", size)
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "VARCHAR(65532)"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct:
 | 
			
		||||
			if _, ok := dataValue.Interface().(time.Time); ok {
 | 
			
		||||
				sqlType = "TIMESTAMP"
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if _, ok := dataValue.Interface().([]byte); ok {
 | 
			
		||||
				if size > 0 && size < 65532 {
 | 
			
		||||
					sqlType = fmt.Sprintf("BINARY(%d)", size)
 | 
			
		||||
				} else {
 | 
			
		||||
					sqlType = "BINARY(65532)"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.TrimSpace(additionalType) == "" {
 | 
			
		||||
		return sqlType
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if limit != nil {
 | 
			
		||||
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
 | 
			
		||||
			sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if offset != nil {
 | 
			
		||||
		if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
 | 
			
		||||
			sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) SelectFromDummyTable() string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) DefaultValueStr() string {
 | 
			
		||||
	return "DEFAULT VALUES"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
 | 
			
		||||
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
 | 
			
		||||
	keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
 | 
			
		||||
	keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
 | 
			
		||||
	return keyName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
 | 
			
		||||
func IsByteArrayOrSlice(value reflect.Value) bool {
 | 
			
		||||
	return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										191
									
								
								dialect_mysql.go
									
									
									
									
									
								
							
							
						
						
									
										191
									
								
								dialect_mysql.go
									
									
									
									
									
								
							@ -1,191 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type mysql struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("mysql", &mysql{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) GetName() string {
 | 
			
		||||
	return "mysql"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) Quote(key string) string {
 | 
			
		||||
	return fmt.Sprintf("`%s`", key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get Data Type for MySQL Dialect
 | 
			
		||||
func (s *mysql) DataTypeOf(field *StructField) string {
 | 
			
		||||
	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
 | 
			
		||||
 | 
			
		||||
	// MySQL allows only one auto increment column per table, and it must
 | 
			
		||||
	// be a KEY column.
 | 
			
		||||
	if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
 | 
			
		||||
		if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
 | 
			
		||||
			delete(field.TagSettings, "AUTO_INCREMENT")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		switch dataValue.Kind() {
 | 
			
		||||
		case reflect.Bool:
 | 
			
		||||
			sqlType = "boolean"
 | 
			
		||||
		case reflect.Int8:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "tinyint AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "tinyint"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Int, reflect.Int16, reflect.Int32:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "int AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "int"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Uint8:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "tinyint unsigned AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "tinyint unsigned"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "int unsigned AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "int unsigned"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Int64:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "bigint AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "bigint"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Uint64:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "bigint unsigned AUTO_INCREMENT"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "bigint unsigned"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Float32, reflect.Float64:
 | 
			
		||||
			sqlType = "double"
 | 
			
		||||
		case reflect.String:
 | 
			
		||||
			if size > 0 && size < 65532 {
 | 
			
		||||
				sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "longtext"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct:
 | 
			
		||||
			if _, ok := dataValue.Interface().(time.Time); ok {
 | 
			
		||||
				precision := ""
 | 
			
		||||
				if p, ok := field.TagSettings["PRECISION"]; ok {
 | 
			
		||||
					precision = fmt.Sprintf("(%s)", p)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if _, ok := field.TagSettings["NOT NULL"]; ok {
 | 
			
		||||
					sqlType = fmt.Sprintf("timestamp%v", precision)
 | 
			
		||||
				} else {
 | 
			
		||||
					sqlType = fmt.Sprintf("timestamp%v NULL", precision)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if IsByteArrayOrSlice(dataValue) {
 | 
			
		||||
				if size > 0 && size < 65532 {
 | 
			
		||||
					sqlType = fmt.Sprintf("varbinary(%d)", size)
 | 
			
		||||
				} else {
 | 
			
		||||
					sqlType = "longblob"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.TrimSpace(additionalType) == "" {
 | 
			
		||||
		return sqlType
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if limit != nil {
 | 
			
		||||
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
 | 
			
		||||
			sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
 | 
			
		||||
 | 
			
		||||
			if offset != nil {
 | 
			
		||||
				if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
 | 
			
		||||
					sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) SelectFromDummyTable() string {
 | 
			
		||||
	return "FROM DUAL"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
 | 
			
		||||
	keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
 | 
			
		||||
	if utf8.RuneCountInString(keyName) <= 64 {
 | 
			
		||||
		return keyName
 | 
			
		||||
	}
 | 
			
		||||
	h := sha1.New()
 | 
			
		||||
	h.Write([]byte(keyName))
 | 
			
		||||
	bs := h.Sum(nil)
 | 
			
		||||
 | 
			
		||||
	// sha1 is 40 characters, keep first 24 characters of destination
 | 
			
		||||
	destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
 | 
			
		||||
	if len(destRunes) > 24 {
 | 
			
		||||
		destRunes = destRunes[:24]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%s%x", string(destRunes), bs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) DefaultValueStr() string {
 | 
			
		||||
	return "VALUES()"
 | 
			
		||||
}
 | 
			
		||||
@ -1,143 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type postgres struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("postgres", &postgres{})
 | 
			
		||||
	RegisterDialect("cloudsqlpostgres", &postgres{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (postgres) GetName() string {
 | 
			
		||||
	return "postgres"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (postgres) BindVar(i int) string {
 | 
			
		||||
	return fmt.Sprintf("$%v", i)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *postgres) DataTypeOf(field *StructField) string {
 | 
			
		||||
	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		switch dataValue.Kind() {
 | 
			
		||||
		case reflect.Bool:
 | 
			
		||||
			sqlType = "boolean"
 | 
			
		||||
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "serial"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "integer"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Int64, reflect.Uint32, reflect.Uint64:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "bigserial"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "bigint"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Float32, reflect.Float64:
 | 
			
		||||
			sqlType = "numeric"
 | 
			
		||||
		case reflect.String:
 | 
			
		||||
			if _, ok := field.TagSettings["SIZE"]; !ok {
 | 
			
		||||
				size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if size > 0 && size < 65532 {
 | 
			
		||||
				sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "text"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct:
 | 
			
		||||
			if _, ok := dataValue.Interface().(time.Time); ok {
 | 
			
		||||
				sqlType = "timestamp with time zone"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Map:
 | 
			
		||||
			if dataValue.Type().Name() == "Hstore" {
 | 
			
		||||
				sqlType = "hstore"
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if IsByteArrayOrSlice(dataValue) {
 | 
			
		||||
				sqlType = "bytea"
 | 
			
		||||
 | 
			
		||||
				if isUUID(dataValue) {
 | 
			
		||||
					sqlType = "uuid"
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if isJSON(dataValue) {
 | 
			
		||||
					sqlType = "jsonb"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.TrimSpace(additionalType) == "" {
 | 
			
		||||
		return sqlType
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
 | 
			
		||||
	return fmt.Sprintf("RETURNING %v.%v", tableName, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (postgres) SupportLastInsertID() bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isUUID(value reflect.Value) bool {
 | 
			
		||||
	if value.Kind() != reflect.Array || value.Type().Len() != 16 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	typename := value.Type().Name()
 | 
			
		||||
	lower := strings.ToLower(typename)
 | 
			
		||||
	return "uuid" == lower || "guid" == lower
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isJSON(value reflect.Value) bool {
 | 
			
		||||
	_, ok := value.Interface().(json.RawMessage)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
@ -1,107 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type sqlite3 struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("sqlite3", &sqlite3{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sqlite3) GetName() string {
 | 
			
		||||
	return "sqlite3"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get Data Type for Sqlite Dialect
 | 
			
		||||
func (s *sqlite3) DataTypeOf(field *StructField) string {
 | 
			
		||||
	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		switch dataValue.Kind() {
 | 
			
		||||
		case reflect.Bool:
 | 
			
		||||
			sqlType = "bool"
 | 
			
		||||
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "integer primary key autoincrement"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "integer"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Int64, reflect.Uint64:
 | 
			
		||||
			if s.fieldCanAutoIncrement(field) {
 | 
			
		||||
				field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
 | 
			
		||||
				sqlType = "integer primary key autoincrement"
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "bigint"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Float32, reflect.Float64:
 | 
			
		||||
			sqlType = "real"
 | 
			
		||||
		case reflect.String:
 | 
			
		||||
			if size > 0 && size < 65532 {
 | 
			
		||||
				sqlType = fmt.Sprintf("varchar(%d)", size)
 | 
			
		||||
			} else {
 | 
			
		||||
				sqlType = "text"
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct:
 | 
			
		||||
			if _, ok := dataValue.Interface().(time.Time); ok {
 | 
			
		||||
				sqlType = "datetime"
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if IsByteArrayOrSlice(dataValue) {
 | 
			
		||||
				sqlType = "blob"
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlType == "" {
 | 
			
		||||
		panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.TrimSpace(additionalType) == "" {
 | 
			
		||||
		return sqlType
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) CurrentDatabase() (name string) {
 | 
			
		||||
	var (
 | 
			
		||||
		ifaces   = make([]interface{}, 3)
 | 
			
		||||
		pointers = make([]*string, 3)
 | 
			
		||||
		i        int
 | 
			
		||||
	)
 | 
			
		||||
	for i = 0; i < 3; i++ {
 | 
			
		||||
		ifaces[i] = &pointers[i]
 | 
			
		||||
	}
 | 
			
		||||
	if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if pointers[1] != nil {
 | 
			
		||||
		name = *pointers[1]
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -1,91 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
type BasePost struct {
 | 
			
		||||
	Id    int64
 | 
			
		||||
	Title string
 | 
			
		||||
	URL   string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Author struct {
 | 
			
		||||
	ID    string
 | 
			
		||||
	Name  string
 | 
			
		||||
	Email string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type HNPost struct {
 | 
			
		||||
	BasePost
 | 
			
		||||
	Author  `gorm:"embedded_prefix:user_"` // Embedded struct
 | 
			
		||||
	Upvotes int32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EngadgetPost struct {
 | 
			
		||||
	BasePost BasePost `gorm:"embedded"`
 | 
			
		||||
	Author   Author   `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
 | 
			
		||||
	ImageUrl string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
 | 
			
		||||
	dialect := DB.NewScope(&EngadgetPost{}).Dialect()
 | 
			
		||||
	engadgetPostScope := DB.NewScope(&EngadgetPost{})
 | 
			
		||||
	if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
 | 
			
		||||
		t.Errorf("should has prefix for embedded columns")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(engadgetPostScope.PrimaryFields()) != 1 {
 | 
			
		||||
		t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hnScope := DB.NewScope(&HNPost{})
 | 
			
		||||
	if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
 | 
			
		||||
		t.Errorf("should has prefix for embedded columns")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
 | 
			
		||||
	DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
 | 
			
		||||
	DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
 | 
			
		||||
	var news HNPost
 | 
			
		||||
	if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen when query with embedded struct, but got %v", err)
 | 
			
		||||
	} else if news.Title != "hn_news" {
 | 
			
		||||
		t.Errorf("embedded struct's value should be scanned correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
 | 
			
		||||
	var egNews EngadgetPost
 | 
			
		||||
	if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should happen when query with embedded struct, but got %v", err)
 | 
			
		||||
	} else if egNews.BasePost.Title != "engadget_news" {
 | 
			
		||||
		t.Errorf("embedded struct's value should be scanned correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&HNPost{}).PrimaryField() == nil {
 | 
			
		||||
		t.Errorf("primary key with embedded struct should works")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, field := range DB.NewScope(&HNPost{}).Fields() {
 | 
			
		||||
		if field.Name == "BasePost" {
 | 
			
		||||
			t.Errorf("scope Fields should not contain embedded struct")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEmbeddedPointerTypeStruct(t *testing.T) {
 | 
			
		||||
	type HNPost struct {
 | 
			
		||||
		*BasePost
 | 
			
		||||
		Upvotes int32
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}})
 | 
			
		||||
 | 
			
		||||
	var hnPost HNPost
 | 
			
		||||
	if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if hnPost.Title != "embedded_pointer_type" {
 | 
			
		||||
		t.Errorf("Should find correct value for embedded pointer type")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,20 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestErrorsCanBeUsedOutsideGorm(t *testing.T) {
 | 
			
		||||
	errs := []error{errors.New("First"), errors.New("Second")}
 | 
			
		||||
 | 
			
		||||
	gErrs := gorm.Errors(errs)
 | 
			
		||||
	gErrs = gErrs.Add(errors.New("Third"))
 | 
			
		||||
	gErrs = gErrs.Add(gErrs)
 | 
			
		||||
 | 
			
		||||
	if gErrs.Error() != "First; Second; Third" {
 | 
			
		||||
		t.Fatalf("Gave wrong error, got %s", gErrs.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,49 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CalculateField struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name     string
 | 
			
		||||
	Children []CalculateFieldChild
 | 
			
		||||
	Category CalculateFieldCategory
 | 
			
		||||
	EmbeddedField
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmbeddedField struct {
 | 
			
		||||
	EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CalculateFieldChild struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	CalculateFieldID uint
 | 
			
		||||
	Name             string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CalculateFieldCategory struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	CalculateFieldID uint
 | 
			
		||||
	Name             string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCalculateField(t *testing.T) {
 | 
			
		||||
	var field CalculateField
 | 
			
		||||
	var scope = DB.NewScope(&field)
 | 
			
		||||
	if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
 | 
			
		||||
		t.Errorf("Should calculate fields correctly for the first time")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
 | 
			
		||||
		t.Errorf("Should calculate fields correctly for the first time")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if field, ok := scope.FieldByName("embedded_name"); !ok {
 | 
			
		||||
		t.Errorf("should find embedded field")
 | 
			
		||||
	} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
 | 
			
		||||
		t.Errorf("should find embedded field's tag settings")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								gorm.go
									
									
									
									
									
								
							@ -1,6 +1,10 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "github.com/jinzhu/gorm/logger"
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Config GORM config
 | 
			
		||||
type Config struct {
 | 
			
		||||
@ -45,3 +49,14 @@ type DB struct {
 | 
			
		||||
	Error        error
 | 
			
		||||
	RowsAffected int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
 | 
			
		||||
//    type User struct {
 | 
			
		||||
//      gorm.Model
 | 
			
		||||
//    }
 | 
			
		||||
type Model struct {
 | 
			
		||||
	ID        uint `gorm:"primary_key"`
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
	DeletedAt *time.Time `sql:"index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								interface.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								interface.go
									
									
									
									
									
								
							@ -1,20 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | 
			
		||||
type SQLCommon interface {
 | 
			
		||||
	Exec(query string, args ...interface{}) (sql.Result, error)
 | 
			
		||||
	Prepare(query string) (*sql.Stmt, error)
 | 
			
		||||
	Query(query string, args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRow(query string, args ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlDb interface {
 | 
			
		||||
	Begin() (*sql.Tx, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlTx interface {
 | 
			
		||||
	Commit() error
 | 
			
		||||
	Rollback() error
 | 
			
		||||
}
 | 
			
		||||
@ -1,117 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Person struct {
 | 
			
		||||
	Id        int
 | 
			
		||||
	Name      string
 | 
			
		||||
	Addresses []*Address `gorm:"many2many:person_addresses;"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PersonAddress struct {
 | 
			
		||||
	gorm.JoinTableHandler
 | 
			
		||||
	PersonID  int
 | 
			
		||||
	AddressID int
 | 
			
		||||
	DeletedAt *time.Time
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
 | 
			
		||||
	foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue()))
 | 
			
		||||
	associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue()))
 | 
			
		||||
	if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{
 | 
			
		||||
		"person_id":  foreignPrimaryKey,
 | 
			
		||||
		"address_id": associationPrimaryKey,
 | 
			
		||||
	}).Update(map[string]interface{}{
 | 
			
		||||
		"person_id":  foreignPrimaryKey,
 | 
			
		||||
		"address_id": associationPrimaryKey,
 | 
			
		||||
		"deleted_at": gorm.Expr("NULL"),
 | 
			
		||||
	}).RowsAffected; result == 0 {
 | 
			
		||||
		return db.Create(&PersonAddress{
 | 
			
		||||
			PersonID:  foreignPrimaryKey,
 | 
			
		||||
			AddressID: associationPrimaryKey,
 | 
			
		||||
		}).Error
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
 | 
			
		||||
	return db.Delete(&PersonAddress{}).Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
 | 
			
		||||
	table := pa.Table(db)
 | 
			
		||||
	return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestJoinTable(t *testing.T) {
 | 
			
		||||
	DB.Exec("drop table person_addresses;")
 | 
			
		||||
	DB.AutoMigrate(&Person{})
 | 
			
		||||
	DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
 | 
			
		||||
 | 
			
		||||
	address1 := &Address{Address1: "address 1"}
 | 
			
		||||
	address2 := &Address{Address1: "address 2"}
 | 
			
		||||
	person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
 | 
			
		||||
	DB.Save(person)
 | 
			
		||||
 | 
			
		||||
	DB.Model(person).Association("Addresses").Delete(address1)
 | 
			
		||||
 | 
			
		||||
	if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
 | 
			
		||||
		t.Errorf("Should found one address")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(person).Association("Addresses").Count() != 1 {
 | 
			
		||||
		t.Errorf("Should found one address")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
 | 
			
		||||
		t.Errorf("Found two addresses with Unscoped")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
 | 
			
		||||
		t.Errorf("Should deleted all addresses")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEmbeddedMany2ManyRelationship(t *testing.T) {
 | 
			
		||||
	type EmbeddedPerson struct {
 | 
			
		||||
		ID        int
 | 
			
		||||
		Name      string
 | 
			
		||||
		Addresses []*Address `gorm:"many2many:person_addresses;"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	type NewPerson struct {
 | 
			
		||||
		EmbeddedPerson
 | 
			
		||||
		ExternalID uint
 | 
			
		||||
	}
 | 
			
		||||
	DB.Exec("drop table person_addresses;")
 | 
			
		||||
	DB.AutoMigrate(&NewPerson{})
 | 
			
		||||
 | 
			
		||||
	address1 := &Address{Address1: "address 1"}
 | 
			
		||||
	address2 := &Address{Address1: "address 2"}
 | 
			
		||||
	person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}}
 | 
			
		||||
	if err := DB.Save(person).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should return when save embedded many2many relationship, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil {
 | 
			
		||||
		t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	association := DB.Model(person).Association("Addresses")
 | 
			
		||||
	if count := association.Count(); count != 1 || association.Error != nil {
 | 
			
		||||
		t.Errorf("Should found one address, but got %v, error is %v", count, association.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if association.Clear(); association.Count() != 0 {
 | 
			
		||||
		t.Errorf("Should deleted all addresses")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										949
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										949
									
								
								main_test.go
									
									
									
									
									
								
							@ -1,949 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/erikstmartin/go-testdb"
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	_ "github.com/jinzhu/gorm/dialects/mssql"
 | 
			
		||||
	_ "github.com/jinzhu/gorm/dialects/mysql"
 | 
			
		||||
	"github.com/jinzhu/gorm/dialects/postgres"
 | 
			
		||||
	_ "github.com/jinzhu/gorm/dialects/sqlite"
 | 
			
		||||
	"github.com/jinzhu/now"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	DB                 *gorm.DB
 | 
			
		||||
	t1, t2, t3, t4, t5 time.Time
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	if DB, err = OpenTestConnection(); err != nil {
 | 
			
		||||
		panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	runMigration()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenTestConnection() (db *gorm.DB, err error) {
 | 
			
		||||
	dbDSN := os.Getenv("GORM_DSN")
 | 
			
		||||
	switch os.Getenv("GORM_DIALECT") {
 | 
			
		||||
	case "mysql":
 | 
			
		||||
		fmt.Println("testing mysql...")
 | 
			
		||||
		if dbDSN == "" {
 | 
			
		||||
			dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
 | 
			
		||||
		}
 | 
			
		||||
		db, err = gorm.Open("mysql", dbDSN)
 | 
			
		||||
	case "postgres":
 | 
			
		||||
		fmt.Println("testing postgres...")
 | 
			
		||||
		if dbDSN == "" {
 | 
			
		||||
			dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
 | 
			
		||||
		}
 | 
			
		||||
		db, err = gorm.Open("postgres", dbDSN)
 | 
			
		||||
	case "mssql":
 | 
			
		||||
		// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
 | 
			
		||||
		// CREATE DATABASE gorm;
 | 
			
		||||
		// USE gorm;
 | 
			
		||||
		// CREATE USER gorm FROM LOGIN gorm;
 | 
			
		||||
		// sp_changedbowner 'gorm';
 | 
			
		||||
		fmt.Println("testing mssql...")
 | 
			
		||||
		if dbDSN == "" {
 | 
			
		||||
			dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
 | 
			
		||||
		}
 | 
			
		||||
		db, err = gorm.Open("mssql", dbDSN)
 | 
			
		||||
	default:
 | 
			
		||||
		fmt.Println("testing sqlite3...")
 | 
			
		||||
		db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
 | 
			
		||||
	// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
 | 
			
		||||
	if debug := os.Getenv("DEBUG"); debug == "true" {
 | 
			
		||||
		db.LogMode(true)
 | 
			
		||||
	} else if debug == "false" {
 | 
			
		||||
		db.LogMode(false)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.DB().SetMaxIdleConns(10)
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStringPrimaryKey(t *testing.T) {
 | 
			
		||||
	type UUIDStruct struct {
 | 
			
		||||
		ID   string `gorm:"primary_key"`
 | 
			
		||||
		Name string
 | 
			
		||||
	}
 | 
			
		||||
	DB.DropTable(&UUIDStruct{})
 | 
			
		||||
	DB.AutoMigrate(&UUIDStruct{})
 | 
			
		||||
 | 
			
		||||
	data := UUIDStruct{ID: "uuid", Name: "hello"}
 | 
			
		||||
	if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" {
 | 
			
		||||
		t.Errorf("string primary key should not be populated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	data = UUIDStruct{ID: "uuid", Name: "hello world"}
 | 
			
		||||
	if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" {
 | 
			
		||||
		t.Errorf("string primary key should not be populated")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestExceptionsWithInvalidSql(t *testing.T) {
 | 
			
		||||
	var columns []string
 | 
			
		||||
	if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
 | 
			
		||||
		t.Errorf("Should got error with invalid SQL")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
 | 
			
		||||
		t.Errorf("Should got error with invalid SQL")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
 | 
			
		||||
		t.Errorf("Should got error with invalid SQL")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var count1, count2 int64
 | 
			
		||||
	DB.Model(&User{}).Count(&count1)
 | 
			
		||||
	if count1 <= 0 {
 | 
			
		||||
		t.Errorf("Should find some users")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
 | 
			
		||||
		t.Errorf("Should got error with invalid SQL")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&User{}).Count(&count2)
 | 
			
		||||
	if count1 != count2 {
 | 
			
		||||
		t.Errorf("No user should not be deleted by invalid SQL")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetTable(t *testing.T) {
 | 
			
		||||
	DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
 | 
			
		||||
	DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
 | 
			
		||||
	DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
 | 
			
		||||
 | 
			
		||||
	if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
 | 
			
		||||
		t.Error("No errors should happen if set table for pluck", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	if DB.Table("users").Find(&[]User{}).Error != nil {
 | 
			
		||||
		t.Errorf("No errors should happen if set table for find")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Table("invalid_table").Find(&users).Error == nil {
 | 
			
		||||
		t.Errorf("Should got error when table is set to an invalid table")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Exec("drop table deleted_users;")
 | 
			
		||||
	if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
 | 
			
		||||
		t.Errorf("Create table with specified table")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Table("deleted_users").Save(&User{Name: "DeletedUser"})
 | 
			
		||||
 | 
			
		||||
	var deletedUsers []User
 | 
			
		||||
	DB.Table("deleted_users").Find(&deletedUsers)
 | 
			
		||||
	if len(deletedUsers) != 1 {
 | 
			
		||||
		t.Errorf("Query from specified table")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(getPreparedUser("normal_user", "reset_table"))
 | 
			
		||||
	DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
 | 
			
		||||
	var user1, user2, user3 User
 | 
			
		||||
	DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
 | 
			
		||||
	if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
 | 
			
		||||
		t.Errorf("unset specified table with blank string")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Order struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Cart struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c Cart) TableName() string {
 | 
			
		||||
	return "shopping_cart"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHasTable(t *testing.T) {
 | 
			
		||||
	type Foo struct {
 | 
			
		||||
		Id    int
 | 
			
		||||
		Stuff string
 | 
			
		||||
	}
 | 
			
		||||
	DB.DropTable(&Foo{})
 | 
			
		||||
 | 
			
		||||
	// Table should not exist at this point, HasTable should return false
 | 
			
		||||
	if ok := DB.HasTable("foos"); ok {
 | 
			
		||||
		t.Errorf("Table should not exist, but does")
 | 
			
		||||
	}
 | 
			
		||||
	if ok := DB.HasTable(&Foo{}); ok {
 | 
			
		||||
		t.Errorf("Table should not exist, but does")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We create the table
 | 
			
		||||
	if err := DB.CreateTable(&Foo{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Table should be created")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// And now it should exits, and HasTable should return true
 | 
			
		||||
	if ok := DB.HasTable("foos"); !ok {
 | 
			
		||||
		t.Errorf("Table should exist, but HasTable informs it does not")
 | 
			
		||||
	}
 | 
			
		||||
	if ok := DB.HasTable(&Foo{}); !ok {
 | 
			
		||||
		t.Errorf("Table should exist, but HasTable informs it does not")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTableName(t *testing.T) {
 | 
			
		||||
	DB := DB.Model("")
 | 
			
		||||
	if DB.NewScope(Order{}).TableName() != "orders" {
 | 
			
		||||
		t.Errorf("Order's table name should be orders")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&Order{}).TableName() != "orders" {
 | 
			
		||||
		t.Errorf("&Order's table name should be orders")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope([]Order{}).TableName() != "orders" {
 | 
			
		||||
		t.Errorf("[]Order's table name should be orders")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&[]Order{}).TableName() != "orders" {
 | 
			
		||||
		t.Errorf("&[]Order's table name should be orders")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.SingularTable(true)
 | 
			
		||||
	if DB.NewScope(Order{}).TableName() != "order" {
 | 
			
		||||
		t.Errorf("Order's singular table name should be order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&Order{}).TableName() != "order" {
 | 
			
		||||
		t.Errorf("&Order's singular table name should be order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope([]Order{}).TableName() != "order" {
 | 
			
		||||
		t.Errorf("[]Order's singular table name should be order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&[]Order{}).TableName() != "order" {
 | 
			
		||||
		t.Errorf("&[]Order's singular table name should be order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&Cart{}).TableName() != "shopping_cart" {
 | 
			
		||||
		t.Errorf("&Cart's singular table name should be shopping_cart")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(Cart{}).TableName() != "shopping_cart" {
 | 
			
		||||
		t.Errorf("Cart's singular table name should be shopping_cart")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" {
 | 
			
		||||
		t.Errorf("&[]Cart's singular table name should be shopping_cart")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
 | 
			
		||||
		t.Errorf("[]Cart's singular table name should be shopping_cart")
 | 
			
		||||
	}
 | 
			
		||||
	DB.SingularTable(false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNullValues(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&NullValue{})
 | 
			
		||||
	DB.AutoMigrate(&NullValue{})
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&NullValue{
 | 
			
		||||
		Name:    sql.NullString{String: "hello", Valid: true},
 | 
			
		||||
		Gender:  &sql.NullString{String: "M", Valid: true},
 | 
			
		||||
		Age:     sql.NullInt64{Int64: 18, Valid: true},
 | 
			
		||||
		Male:    sql.NullBool{Bool: true, Valid: true},
 | 
			
		||||
		Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
 | 
			
		||||
		AddedAt: NullTime{Time: time.Now(), Valid: true},
 | 
			
		||||
	}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Not error should raise when test null value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nv NullValue
 | 
			
		||||
	DB.First(&nv, "name = ?", "hello")
 | 
			
		||||
 | 
			
		||||
	if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
 | 
			
		||||
		t.Errorf("Should be able to fetch null value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&NullValue{
 | 
			
		||||
		Name:    sql.NullString{String: "hello-2", Valid: true},
 | 
			
		||||
		Gender:  &sql.NullString{String: "F", Valid: true},
 | 
			
		||||
		Age:     sql.NullInt64{Int64: 18, Valid: false},
 | 
			
		||||
		Male:    sql.NullBool{Bool: true, Valid: true},
 | 
			
		||||
		Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
 | 
			
		||||
		AddedAt: NullTime{Time: time.Now(), Valid: false},
 | 
			
		||||
	}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Not error should raise when test null value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nv2 NullValue
 | 
			
		||||
	DB.First(&nv2, "name = ?", "hello-2")
 | 
			
		||||
	if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
 | 
			
		||||
		t.Errorf("Should be able to fetch null value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&NullValue{
 | 
			
		||||
		Name:    sql.NullString{String: "hello-3", Valid: false},
 | 
			
		||||
		Gender:  &sql.NullString{String: "M", Valid: true},
 | 
			
		||||
		Age:     sql.NullInt64{Int64: 18, Valid: false},
 | 
			
		||||
		Male:    sql.NullBool{Bool: true, Valid: true},
 | 
			
		||||
		Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
 | 
			
		||||
		AddedAt: NullTime{Time: time.Now(), Valid: false},
 | 
			
		||||
	}).Error; err == nil {
 | 
			
		||||
		t.Errorf("Can't save because of name can't be null")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNullValuesWithFirstOrCreate(t *testing.T) {
 | 
			
		||||
	var nv1 = NullValue{
 | 
			
		||||
		Name:   sql.NullString{String: "first_or_create", Valid: true},
 | 
			
		||||
		Gender: &sql.NullString{String: "M", Valid: true},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nv2 NullValue
 | 
			
		||||
	result := DB.Where(nv1).FirstOrCreate(&nv2)
 | 
			
		||||
 | 
			
		||||
	if result.RowsAffected != 1 {
 | 
			
		||||
		t.Errorf("RowsAffected should be 1 after create some record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if result.Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error, but got %v", result.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
 | 
			
		||||
		t.Errorf("first or create with nullvalues")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if nv2.Age.Int64 != 18 {
 | 
			
		||||
		t.Errorf("should update age to 18")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTransaction(t *testing.T) {
 | 
			
		||||
	tx := DB.Begin()
 | 
			
		||||
	u := User{Name: "transcation"}
 | 
			
		||||
	if err := tx.Save(&u).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should raise")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
 | 
			
		||||
		t.Errorf("Should find saved record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
 | 
			
		||||
		t.Errorf("Should return the underlying sql.Tx")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.Rollback()
 | 
			
		||||
 | 
			
		||||
	if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
 | 
			
		||||
		t.Errorf("Should not find record after rollback")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx2 := DB.Begin()
 | 
			
		||||
	u2 := User{Name: "transcation-2"}
 | 
			
		||||
	if err := tx2.Save(&u2).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should raise")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
 | 
			
		||||
		t.Errorf("Should find saved record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx2.Commit()
 | 
			
		||||
 | 
			
		||||
	if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
 | 
			
		||||
		t.Errorf("Should be able to find committed record")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRow(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
 | 
			
		||||
	var age int64
 | 
			
		||||
	row.Scan(&age)
 | 
			
		||||
	if age != 10 {
 | 
			
		||||
		t.Errorf("Scan with Row")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRows(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Not error should happen, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	count := 0
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var name string
 | 
			
		||||
		var age int64
 | 
			
		||||
		rows.Scan(&name, &age)
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count != 2 {
 | 
			
		||||
		t.Errorf("Should found two records")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestScanRows(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Not error should happen, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	type Result struct {
 | 
			
		||||
		Name string
 | 
			
		||||
		Age  int
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var results []Result
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		var result Result
 | 
			
		||||
		if err := DB.ScanRows(rows, &result); err != nil {
 | 
			
		||||
			t.Errorf("should get no error, but got %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		results = append(results, result)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
 | 
			
		||||
		t.Errorf("Should find expected results")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestScan(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	type result struct {
 | 
			
		||||
		Name string
 | 
			
		||||
		Age  int
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res result
 | 
			
		||||
	DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res)
 | 
			
		||||
	if res.Name != user3.Name {
 | 
			
		||||
		t.Errorf("Scan into struct should work")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var doubleAgeRes = &result{}
 | 
			
		||||
	if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil {
 | 
			
		||||
		t.Errorf("Scan to pointer of pointer")
 | 
			
		||||
	}
 | 
			
		||||
	if doubleAgeRes.Age != res.Age*2 {
 | 
			
		||||
		t.Errorf("Scan double age as age")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ress []result
 | 
			
		||||
	DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress)
 | 
			
		||||
	if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
 | 
			
		||||
		t.Errorf("Scan into struct map")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRaw(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	type result struct {
 | 
			
		||||
		Name  string
 | 
			
		||||
		Email string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ress []result
 | 
			
		||||
	DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress)
 | 
			
		||||
	if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
 | 
			
		||||
		t.Errorf("Raw with scan")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows()
 | 
			
		||||
	count := 0
 | 
			
		||||
	for rows.Next() {
 | 
			
		||||
		count++
 | 
			
		||||
	}
 | 
			
		||||
	if count != 1 {
 | 
			
		||||
		t.Errorf("Raw with Rows should find one record with name 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
 | 
			
		||||
	if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
 | 
			
		||||
		t.Error("Raw sql to update records")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGroup(t *testing.T) {
 | 
			
		||||
	rows, err := DB.Select("name").Table("users").Group("name").Rows()
 | 
			
		||||
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		defer rows.Close()
 | 
			
		||||
		for rows.Next() {
 | 
			
		||||
			var name string
 | 
			
		||||
			rows.Scan(&name)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Errorf("Should not raise any error")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestJoins(t *testing.T) {
 | 
			
		||||
	var user = User{
 | 
			
		||||
		Name:       "joins",
 | 
			
		||||
		CreditCard: CreditCard{Number: "411111111111"},
 | 
			
		||||
		Emails:     []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
 | 
			
		||||
	}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var users1 []User
 | 
			
		||||
	DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
 | 
			
		||||
	if len(users1) != 2 {
 | 
			
		||||
		t.Errorf("should find two users using left join")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users2 []User
 | 
			
		||||
	DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
 | 
			
		||||
	if len(users2) != 1 {
 | 
			
		||||
		t.Errorf("should find one users using left join with conditions")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users3 []User
 | 
			
		||||
	DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
 | 
			
		||||
	if len(users3) != 1 {
 | 
			
		||||
		t.Errorf("should find one users using multiple left join conditions")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users4 []User
 | 
			
		||||
	DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
 | 
			
		||||
	if len(users4) != 0 {
 | 
			
		||||
		t.Errorf("should find no user when searching with unexisting credit card")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users5 []User
 | 
			
		||||
	db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5)
 | 
			
		||||
	if db5.Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestJoinsWithSelect(t *testing.T) {
 | 
			
		||||
	type result struct {
 | 
			
		||||
		Name  string
 | 
			
		||||
		Email string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user := User{
 | 
			
		||||
		Name:   "joins_with_select",
 | 
			
		||||
		Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
 | 
			
		||||
	}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var results []result
 | 
			
		||||
	DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
 | 
			
		||||
	if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
 | 
			
		||||
		t.Errorf("Should find all two emails with Join select")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHaving(t *testing.T) {
 | 
			
		||||
	rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()
 | 
			
		||||
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		defer rows.Close()
 | 
			
		||||
		for rows.Next() {
 | 
			
		||||
			var name string
 | 
			
		||||
			var total int64
 | 
			
		||||
			rows.Scan(&name, &total)
 | 
			
		||||
 | 
			
		||||
			if name == "2" && total != 1 {
 | 
			
		||||
				t.Errorf("Should have one user having name 2")
 | 
			
		||||
			}
 | 
			
		||||
			if name == "3" && total != 2 {
 | 
			
		||||
				t.Errorf("Should have two users having name 3")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Errorf("Should not raise any error")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryBuilderSubselectInWhere(t *testing.T) {
 | 
			
		||||
	user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Select("*").Where("name IN (?)", DB.
 | 
			
		||||
		Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
 | 
			
		||||
 | 
			
		||||
	if len(users) != 4 {
 | 
			
		||||
		t.Errorf("Four users should be found, instead found %d", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB.
 | 
			
		||||
		Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
 | 
			
		||||
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Two users should be found, instead found %d", len(users))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
 | 
			
		||||
	user := User{Name: "subquery_test_user1", Age: 10}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "subquery_test_user2", Age: 11}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "subquery_test_user3", Age: 12}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var count int
 | 
			
		||||
	err := DB.Raw("select count(*) from (?) tmp",
 | 
			
		||||
		DB.Table("users").
 | 
			
		||||
			Select("name").
 | 
			
		||||
			Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}).
 | 
			
		||||
			Group("name").
 | 
			
		||||
			QueryExpr(),
 | 
			
		||||
	).Count(&count).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected to get no errors, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if count != 2 {
 | 
			
		||||
		t.Errorf("Row count must be 2, instead got %d", count)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = DB.Raw("select count(*) from (?) tmp",
 | 
			
		||||
		DB.Table("users").
 | 
			
		||||
			Select("name").
 | 
			
		||||
			Where("name LIKE ?", "subquery_test%").
 | 
			
		||||
			Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}).
 | 
			
		||||
			Group("name").
 | 
			
		||||
			QueryExpr(),
 | 
			
		||||
	).Count(&count).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected to get no errors, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if count != 1 {
 | 
			
		||||
		t.Errorf("Row count must be 1, instead got %d", count)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryBuilderSubselectInHaving(t *testing.T) {
 | 
			
		||||
	user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB.
 | 
			
		||||
		Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users)
 | 
			
		||||
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Two user group should be found, instead found %d", len(users))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DialectHasTzSupport() bool {
 | 
			
		||||
	// NB: mssql and FoundationDB do not support time zones.
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTimeWithZone(t *testing.T) {
 | 
			
		||||
	var format = "2006-01-02 15:04:05 -0700"
 | 
			
		||||
	var times []time.Time
 | 
			
		||||
	GMT8, _ := time.LoadLocation("Asia/Shanghai")
 | 
			
		||||
	times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
 | 
			
		||||
	times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
 | 
			
		||||
 | 
			
		||||
	for index, vtime := range times {
 | 
			
		||||
		name := "time_with_zone_" + strconv.Itoa(index)
 | 
			
		||||
		user := User{Name: name, Birthday: &vtime}
 | 
			
		||||
 | 
			
		||||
		if !DialectHasTzSupport() {
 | 
			
		||||
			// If our driver dialect doesn't support TZ's, just use UTC for everything here.
 | 
			
		||||
			utcBirthday := user.Birthday.UTC()
 | 
			
		||||
			user.Birthday = &utcBirthday
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Save(&user)
 | 
			
		||||
		expectedBirthday := "2013-02-18 17:51:49 +0000"
 | 
			
		||||
		foundBirthday := user.Birthday.UTC().Format(format)
 | 
			
		||||
		if foundBirthday != expectedBirthday {
 | 
			
		||||
			t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var findUser, findUser2, findUser3 User
 | 
			
		||||
		DB.First(&findUser, "name = ?", name)
 | 
			
		||||
		foundBirthday = findUser.Birthday.UTC().Format(format)
 | 
			
		||||
		if foundBirthday != expectedBirthday {
 | 
			
		||||
			t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
 | 
			
		||||
			t.Errorf("User should be found")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
 | 
			
		||||
			t.Errorf("User should not be found")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHstore(t *testing.T) {
 | 
			
		||||
	type Details struct {
 | 
			
		||||
		Id   int64
 | 
			
		||||
		Bulk postgres.Hstore
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
 | 
			
		||||
		t.Skip()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
 | 
			
		||||
		fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
 | 
			
		||||
		panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Exec("drop table details")
 | 
			
		||||
 | 
			
		||||
	if err := DB.CreateTable(&Details{}).Error; err != nil {
 | 
			
		||||
		panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
 | 
			
		||||
	bulk := map[string]*string{
 | 
			
		||||
		"bankAccountId": &bankAccountId,
 | 
			
		||||
		"phoneNumber":   &phoneNumber,
 | 
			
		||||
		"opinion":       &opinion,
 | 
			
		||||
	}
 | 
			
		||||
	d := Details{Bulk: bulk}
 | 
			
		||||
	DB.Save(&d)
 | 
			
		||||
 | 
			
		||||
	var d2 Details
 | 
			
		||||
	if err := DB.First(&d2).Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to fetch details: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k := range bulk {
 | 
			
		||||
		if r, ok := d2.Bulk[k]; ok {
 | 
			
		||||
			if res, _ := bulk[k]; *res != *r {
 | 
			
		||||
				t.Errorf("Details should be equal")
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Errorf("Details should be existed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSetAndGet(t *testing.T) {
 | 
			
		||||
	if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
 | 
			
		||||
		t.Errorf("Should be able to get setting after set")
 | 
			
		||||
	} else {
 | 
			
		||||
		if value.(string) != "world" {
 | 
			
		||||
			t.Errorf("Setted value should not be changed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, ok := DB.Get("non_existing"); ok {
 | 
			
		||||
		t.Errorf("Get non existing key should return error")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCompatibilityMode(t *testing.T) {
 | 
			
		||||
	DB, _ := gorm.Open("testdb", "")
 | 
			
		||||
	testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
 | 
			
		||||
		columns := []string{"id", "name", "age"}
 | 
			
		||||
		result := `
 | 
			
		||||
		1,Tim,20
 | 
			
		||||
		2,Joe,25
 | 
			
		||||
		3,Bob,30
 | 
			
		||||
		`
 | 
			
		||||
		return testdb.RowsFromCSVString(columns, result), nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Find(&users)
 | 
			
		||||
	if (users[0].Name != "Tim") || len(users) != 3 {
 | 
			
		||||
		t.Errorf("Unexcepted result returned")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOpenExistingDB(t *testing.T) {
 | 
			
		||||
	DB.Save(&User{Name: "jnfeinstein"})
 | 
			
		||||
	dialect := os.Getenv("GORM_DIALECT")
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open(dialect, DB.DB())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Should have wrapped the existing DB connection")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
 | 
			
		||||
		t.Errorf("Should have found existing record")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDdlErrors(t *testing.T) {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	if err = DB.Close(); err != nil {
 | 
			
		||||
		t.Errorf("Closing DDL test db connection err=%s", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		// Reopen DB connection.
 | 
			
		||||
		if DB, err = OpenTestConnection(); err != nil {
 | 
			
		||||
			t.Fatalf("Failed re-opening db connection: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err := DB.Find(&User{}).Error; err == nil {
 | 
			
		||||
		t.Errorf("Expected operation on closed db to produce an error, but err was nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOpenWithOneParameter(t *testing.T) {
 | 
			
		||||
	db, err := gorm.Open("dialect")
 | 
			
		||||
	if db != nil {
 | 
			
		||||
		t.Error("Open with one parameter returned non nil for db")
 | 
			
		||||
	}
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Open with one parameter returned err as nil")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestBlockGlobalUpdate(t *testing.T) {
 | 
			
		||||
	db := DB.New()
 | 
			
		||||
	db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
 | 
			
		||||
 | 
			
		||||
	err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error("Unexpected error on global update")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.Delete(&Toy{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error("Unexpected error on global delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.BlockGlobalUpdate(true)
 | 
			
		||||
 | 
			
		||||
	db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
 | 
			
		||||
 | 
			
		||||
	err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Expected error on global update")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error("Unxpected error on conditional update")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.Delete(&Toy{}).Error
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("Expected error on global delete")
 | 
			
		||||
	}
 | 
			
		||||
	err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error("Unexpected error on conditional delete")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkGorm(b *testing.B) {
 | 
			
		||||
	b.N = 2000
 | 
			
		||||
	for x := 0; x < b.N; x++ {
 | 
			
		||||
		e := strconv.Itoa(x) + "benchmark@example.org"
 | 
			
		||||
		now := time.Now()
 | 
			
		||||
		email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
 | 
			
		||||
		// Insert
 | 
			
		||||
		DB.Save(&email)
 | 
			
		||||
		// Query
 | 
			
		||||
		DB.First(&EmailWithIdx{}, "email = ?", e)
 | 
			
		||||
		// Update
 | 
			
		||||
		DB.Model(&email).UpdateColumn("email", "new-"+e)
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Delete(&email)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkRawSql(b *testing.B) {
 | 
			
		||||
	DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
 | 
			
		||||
	DB.SetMaxIdleConns(10)
 | 
			
		||||
	insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
 | 
			
		||||
	querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
 | 
			
		||||
	updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
 | 
			
		||||
	deleteSql := "DELETE FROM orders WHERE id = $1"
 | 
			
		||||
 | 
			
		||||
	b.N = 2000
 | 
			
		||||
	for x := 0; x < b.N; x++ {
 | 
			
		||||
		var id int64
 | 
			
		||||
		e := strconv.Itoa(x) + "benchmark@example.org"
 | 
			
		||||
		now := time.Now()
 | 
			
		||||
		email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
 | 
			
		||||
		// Insert
 | 
			
		||||
		DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
 | 
			
		||||
		// Query
 | 
			
		||||
		rows, _ := DB.Query(querySql, email.Email)
 | 
			
		||||
		rows.Close()
 | 
			
		||||
		// Update
 | 
			
		||||
		DB.Exec(updateSql, "new-"+e, time.Now(), id)
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Exec(deleteSql, id)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseTime(str string) *time.Time {
 | 
			
		||||
	t := now.New(time.Now().UTC()).MustParse(str)
 | 
			
		||||
	return &t
 | 
			
		||||
}
 | 
			
		||||
@ -1,485 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type User struct {
 | 
			
		||||
	Id                int64
 | 
			
		||||
	Age               int64
 | 
			
		||||
	UserNum           Num
 | 
			
		||||
	Name              string `sql:"size:255"`
 | 
			
		||||
	Email             string
 | 
			
		||||
	Birthday          *time.Time    // Time
 | 
			
		||||
	CreatedAt         time.Time     // CreatedAt: Time of record is created, will be insert automatically
 | 
			
		||||
	UpdatedAt         time.Time     // UpdatedAt: Time of record is updated, will be updated automatically
 | 
			
		||||
	Emails            []Email       // Embedded structs
 | 
			
		||||
	BillingAddress    Address       // Embedded struct
 | 
			
		||||
	BillingAddressID  sql.NullInt64 // Embedded struct's foreign key
 | 
			
		||||
	ShippingAddress   Address       // Embedded struct
 | 
			
		||||
	ShippingAddressId int64         // Embedded struct's foreign key
 | 
			
		||||
	CreditCard        CreditCard
 | 
			
		||||
	Latitude          float64
 | 
			
		||||
	Languages         []Language `gorm:"many2many:user_languages;"`
 | 
			
		||||
	CompanyID         *int
 | 
			
		||||
	Company           Company
 | 
			
		||||
	Role              Role
 | 
			
		||||
	Password          EncryptedData
 | 
			
		||||
	PasswordHash      []byte
 | 
			
		||||
	IgnoreMe          int64                 `sql:"-"`
 | 
			
		||||
	IgnoreStringSlice []string              `sql:"-"`
 | 
			
		||||
	Ignored           struct{ Name string } `sql:"-"`
 | 
			
		||||
	IgnoredPointer    *User                 `sql:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NotSoLongTableName struct {
 | 
			
		||||
	Id                int64
 | 
			
		||||
	ReallyLongThingID int64
 | 
			
		||||
	ReallyLongThing   ReallyLongTableNameToTestMySQLNameLengthLimit
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
 | 
			
		||||
	Id int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ReallyLongThingThatReferencesShort struct {
 | 
			
		||||
	Id      int64
 | 
			
		||||
	ShortID int64
 | 
			
		||||
	Short   Short
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Short struct {
 | 
			
		||||
	Id int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CreditCard struct {
 | 
			
		||||
	ID        int8
 | 
			
		||||
	Number    string
 | 
			
		||||
	UserId    sql.NullInt64
 | 
			
		||||
	CreatedAt time.Time `sql:"not null"`
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
	DeletedAt *time.Time `sql:"column:deleted_time"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Email struct {
 | 
			
		||||
	Id        int16
 | 
			
		||||
	UserId    int
 | 
			
		||||
	Email     string `sql:"type:varchar(100);"`
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Address struct {
 | 
			
		||||
	ID        int
 | 
			
		||||
	Address1  string
 | 
			
		||||
	Address2  string
 | 
			
		||||
	Post      string
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
	DeletedAt *time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Language struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name  string
 | 
			
		||||
	Users []User `gorm:"many2many:user_languages;"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Product struct {
 | 
			
		||||
	Id                    int64
 | 
			
		||||
	Code                  string
 | 
			
		||||
	Price                 int64
 | 
			
		||||
	CreatedAt             time.Time
 | 
			
		||||
	UpdatedAt             time.Time
 | 
			
		||||
	AfterFindCallTimes    int64
 | 
			
		||||
	BeforeCreateCallTimes int64
 | 
			
		||||
	AfterCreateCallTimes  int64
 | 
			
		||||
	BeforeUpdateCallTimes int64
 | 
			
		||||
	AfterUpdateCallTimes  int64
 | 
			
		||||
	BeforeSaveCallTimes   int64
 | 
			
		||||
	AfterSaveCallTimes    int64
 | 
			
		||||
	BeforeDeleteCallTimes int64
 | 
			
		||||
	AfterDeleteCallTimes  int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Company struct {
 | 
			
		||||
	Id    int64
 | 
			
		||||
	Name  string
 | 
			
		||||
	Owner *User `sql:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EncryptedData []byte
 | 
			
		||||
 | 
			
		||||
func (data *EncryptedData) Scan(value interface{}) error {
 | 
			
		||||
	if b, ok := value.([]byte); ok {
 | 
			
		||||
		if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' {
 | 
			
		||||
			return errors.New("Too short")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		*data = b[3:]
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errors.New("Bytes expected")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (data EncryptedData) Value() (driver.Value, error) {
 | 
			
		||||
	if len(data) > 0 && data[0] == 'x' {
 | 
			
		||||
		//needed to test failures
 | 
			
		||||
		return nil, errors.New("Should not start with 'x'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//prepend asterisks
 | 
			
		||||
	return append([]byte("***"), data...), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Role struct {
 | 
			
		||||
	Name string `gorm:"size:256"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (role *Role) Scan(value interface{}) error {
 | 
			
		||||
	if b, ok := value.([]uint8); ok {
 | 
			
		||||
		role.Name = string(b)
 | 
			
		||||
	} else {
 | 
			
		||||
		role.Name = value.(string)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (role Role) Value() (driver.Value, error) {
 | 
			
		||||
	return role.Name, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (role Role) IsAdmin() bool {
 | 
			
		||||
	return role.Name == "admin"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Num int64
 | 
			
		||||
 | 
			
		||||
func (i *Num) Scan(src interface{}) error {
 | 
			
		||||
	switch s := src.(type) {
 | 
			
		||||
	case []byte:
 | 
			
		||||
		n, _ := strconv.Atoi(string(s))
 | 
			
		||||
		*i = Num(n)
 | 
			
		||||
	case int64:
 | 
			
		||||
		*i = Num(s)
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Animal struct {
 | 
			
		||||
	Counter    uint64    `gorm:"primary_key:yes"`
 | 
			
		||||
	Name       string    `sql:"DEFAULT:'galeone'"`
 | 
			
		||||
	From       string    //test reserved sql keyword as field name
 | 
			
		||||
	Age        time.Time `sql:"DEFAULT:current_timestamp"`
 | 
			
		||||
	unexported string    // unexported value
 | 
			
		||||
	CreatedAt  time.Time
 | 
			
		||||
	UpdatedAt  time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type JoinTable struct {
 | 
			
		||||
	From uint64
 | 
			
		||||
	To   uint64
 | 
			
		||||
	Time time.Time `sql:"default: null"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Post struct {
 | 
			
		||||
	Id             int64
 | 
			
		||||
	CategoryId     sql.NullInt64
 | 
			
		||||
	MainCategoryId int64
 | 
			
		||||
	Title          string
 | 
			
		||||
	Body           string
 | 
			
		||||
	Comments       []*Comment
 | 
			
		||||
	Category       Category
 | 
			
		||||
	MainCategory   Category
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Category struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	Name string
 | 
			
		||||
 | 
			
		||||
	Categories []Category
 | 
			
		||||
	CategoryID *uint
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Comment struct {
 | 
			
		||||
	gorm.Model
 | 
			
		||||
	PostId  int64
 | 
			
		||||
	Content string
 | 
			
		||||
	Post    Post
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Scanner
 | 
			
		||||
type NullValue struct {
 | 
			
		||||
	Id      int64
 | 
			
		||||
	Name    sql.NullString  `sql:"not null"`
 | 
			
		||||
	Gender  *sql.NullString `sql:"not null"`
 | 
			
		||||
	Age     sql.NullInt64
 | 
			
		||||
	Male    sql.NullBool
 | 
			
		||||
	Height  sql.NullFloat64
 | 
			
		||||
	AddedAt NullTime
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NullTime struct {
 | 
			
		||||
	Time  time.Time
 | 
			
		||||
	Valid bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (nt *NullTime) Scan(value interface{}) error {
 | 
			
		||||
	if value == nil {
 | 
			
		||||
		nt.Valid = false
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	nt.Time, nt.Valid = value.(time.Time), true
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (nt NullTime) Value() (driver.Value, error) {
 | 
			
		||||
	if !nt.Valid {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return nt.Time, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPreparedUser(name string, role string) *User {
 | 
			
		||||
	var company Company
 | 
			
		||||
	DB.Where(Company{Name: role}).FirstOrCreate(&company)
 | 
			
		||||
 | 
			
		||||
	return &User{
 | 
			
		||||
		Name:            name,
 | 
			
		||||
		Age:             20,
 | 
			
		||||
		Role:            Role{role},
 | 
			
		||||
		BillingAddress:  Address{Address1: fmt.Sprintf("Billing Address %v", name)},
 | 
			
		||||
		ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
 | 
			
		||||
		CreditCard:      CreditCard{Number: fmt.Sprintf("123456%v", name)},
 | 
			
		||||
		Emails: []Email{
 | 
			
		||||
			{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
 | 
			
		||||
		},
 | 
			
		||||
		Company: company,
 | 
			
		||||
		Languages: []Language{
 | 
			
		||||
			{Name: fmt.Sprintf("lang_1_%v", name)},
 | 
			
		||||
			{Name: fmt.Sprintf("lang_2_%v", name)},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runMigration() {
 | 
			
		||||
	if err := DB.DropTableIfExists(&User{}).Error; err != nil {
 | 
			
		||||
		fmt.Printf("Got error when try to delete table users, %+v\n", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, table := range []string{"animals", "user_languages"} {
 | 
			
		||||
		DB.Exec(fmt.Sprintf("drop table %v;", table))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
 | 
			
		||||
	for _, value := range values {
 | 
			
		||||
		DB.DropTable(value)
 | 
			
		||||
	}
 | 
			
		||||
	if err := DB.AutoMigrate(values...).Error; err != nil {
 | 
			
		||||
		panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIndexes(t *testing.T) {
 | 
			
		||||
	if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to create index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope := DB.NewScope(&Email{})
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
 | 
			
		||||
		t.Errorf("Email should have index idx_email_email")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
 | 
			
		||||
		t.Errorf("Email's index idx_email_email should be deleted")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to create index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
			
		||||
		t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
			
		||||
		t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to create index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
			
		||||
		t.Errorf("Email should have index idx_email_email_and_user_id")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
 | 
			
		||||
		t.Errorf("Should get to create duplicate record when having unique index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user = User{Name: "sample_user"}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
 | 
			
		||||
		t.Errorf("Should get no error when append two emails for user")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
 | 
			
		||||
		t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
 | 
			
		||||
		t.Errorf("Got error when tried to remove index: %+v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
 | 
			
		||||
		t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
 | 
			
		||||
		t.Errorf("Should be able to create duplicated emails after remove unique index")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EmailWithIdx struct {
 | 
			
		||||
	Id           int64
 | 
			
		||||
	UserId       int64
 | 
			
		||||
	Email        string     `sql:"index:idx_email_agent"`
 | 
			
		||||
	UserAgent    string     `sql:"index:idx_email_agent"`
 | 
			
		||||
	RegisteredAt *time.Time `sql:"unique_index"`
 | 
			
		||||
	CreatedAt    time.Time
 | 
			
		||||
	UpdatedAt    time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAutoMigration(t *testing.T) {
 | 
			
		||||
	DB.AutoMigrate(&Address{})
 | 
			
		||||
	DB.DropTable(&EmailWithIdx{})
 | 
			
		||||
	if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Auto Migrate should not raise any error")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
 | 
			
		||||
 | 
			
		||||
	scope := DB.NewScope(&EmailWithIdx{})
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var bigemail EmailWithIdx
 | 
			
		||||
	DB.First(&bigemail, "user_agent = ?", "pc")
 | 
			
		||||
	if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
 | 
			
		||||
		t.Error("Big Emails should be saved and fetched correctly")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MultipleIndexes struct {
 | 
			
		||||
	ID     int64
 | 
			
		||||
	UserID int64  `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
 | 
			
		||||
	Name   string `sql:"unique_index:uix_multipleindexes_user_name"`
 | 
			
		||||
	Email  string `sql:"unique_index:,uix_multipleindexes_user_email"`
 | 
			
		||||
	Other  string `sql:"index:,idx_multipleindexes_user_other"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMultipleIndexes(t *testing.T) {
 | 
			
		||||
	if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
 | 
			
		||||
		fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.AutoMigrate(&MultipleIndexes{})
 | 
			
		||||
	if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Auto Migrate should not raise any error")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
 | 
			
		||||
 | 
			
		||||
	scope := DB.NewScope(&MultipleIndexes{})
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
 | 
			
		||||
		t.Errorf("Failed to create index")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var mutipleIndexes MultipleIndexes
 | 
			
		||||
	DB.First(&mutipleIndexes, "name = ?", "jinzhu")
 | 
			
		||||
	if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
 | 
			
		||||
		t.Error("MutipleIndexes should be saved and fetched correctly")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check unique constraints
 | 
			
		||||
	if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
 | 
			
		||||
		t.Error("MultipleIndexes unique index failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
 | 
			
		||||
		t.Error("MultipleIndexes unique index failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
 | 
			
		||||
		t.Error("MultipleIndexes unique index failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
 | 
			
		||||
		t.Error("MultipleIndexes unique index failed")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestModifyColumnType(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" {
 | 
			
		||||
		t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	type ModifyColumnType struct {
 | 
			
		||||
		gorm.Model
 | 
			
		||||
		Name1 string `gorm:"length:100"`
 | 
			
		||||
		Name2 string `gorm:"length:200"`
 | 
			
		||||
	}
 | 
			
		||||
	DB.DropTable(&ModifyColumnType{})
 | 
			
		||||
	DB.CreateTable(&ModifyColumnType{})
 | 
			
		||||
 | 
			
		||||
	name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2")
 | 
			
		||||
	name2Type := DB.Dialect().DataTypeOf(name2Field.StructField)
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when ModifyColumn, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										14
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								model.go
									
									
									
									
									
								
							@ -1,14 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
 | 
			
		||||
//    type User struct {
 | 
			
		||||
//      gorm.Model
 | 
			
		||||
//    }
 | 
			
		||||
type Model struct {
 | 
			
		||||
	ID        uint `gorm:"primary_key"`
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	UpdatedAt time.Time
 | 
			
		||||
	DeletedAt *time.Time `sql:"index"`
 | 
			
		||||
}
 | 
			
		||||
@ -1,381 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Blog struct {
 | 
			
		||||
	ID         uint   `gorm:"primary_key"`
 | 
			
		||||
	Locale     string `gorm:"primary_key"`
 | 
			
		||||
	Subject    string
 | 
			
		||||
	Body       string
 | 
			
		||||
	Tags       []Tag `gorm:"many2many:blog_tags;"`
 | 
			
		||||
	SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"`
 | 
			
		||||
	LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tag struct {
 | 
			
		||||
	ID     uint   `gorm:"primary_key"`
 | 
			
		||||
	Locale string `gorm:"primary_key"`
 | 
			
		||||
	Value  string
 | 
			
		||||
	Blogs  []*Blog `gorm:"many2many:blogs_tags"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func compareTags(tags []Tag, contents []string) bool {
 | 
			
		||||
	var tagContents []string
 | 
			
		||||
	for _, tag := range tags {
 | 
			
		||||
		tagContents = append(tagContents, tag.Value)
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(tagContents)
 | 
			
		||||
	sort.Strings(contents)
 | 
			
		||||
	return reflect.DeepEqual(tagContents, contents)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
 | 
			
		||||
		DB.DropTable(&Blog{}, &Tag{})
 | 
			
		||||
		DB.DropTable("blog_tags")
 | 
			
		||||
		DB.CreateTable(&Blog{}, &Tag{})
 | 
			
		||||
		blog := Blog{
 | 
			
		||||
			Locale:  "ZH",
 | 
			
		||||
			Subject: "subject",
 | 
			
		||||
			Body:    "body",
 | 
			
		||||
			Tags: []Tag{
 | 
			
		||||
				{Locale: "ZH", Value: "tag1"},
 | 
			
		||||
				{Locale: "ZH", Value: "tag2"},
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Save(&blog)
 | 
			
		||||
		if !compareTags(blog.Tags, []string{"tag1", "tag2"}) {
 | 
			
		||||
			t.Errorf("Blog should has two tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Append
 | 
			
		||||
		var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Append([]*Tag{tag3})
 | 
			
		||||
		if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 3 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tags []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "Tags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog1 Blog
 | 
			
		||||
		DB.Preload("Tags").Find(&blog1)
 | 
			
		||||
		if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Preload many2many relations")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Replace
 | 
			
		||||
		var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
 | 
			
		||||
		var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Replace(tag5, tag6)
 | 
			
		||||
		var tags2 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags2, "Tags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 2 tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 2 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Delete(tag5)
 | 
			
		||||
		var tags3 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags3, "Tags")
 | 
			
		||||
		if !compareTags(tags3, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 1 tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 1 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Delete(tag3)
 | 
			
		||||
		var tags4 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags4, "Tags")
 | 
			
		||||
		if !compareTags(tags4, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Clear
 | 
			
		||||
		DB.Model(&blog).Association("Tags").Clear()
 | 
			
		||||
		if DB.Model(&blog).Association("Tags").Count() != 0 {
 | 
			
		||||
			t.Errorf("All tags should be cleared")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
 | 
			
		||||
		DB.DropTable(&Blog{}, &Tag{})
 | 
			
		||||
		DB.DropTable("shared_blog_tags")
 | 
			
		||||
		DB.CreateTable(&Blog{}, &Tag{})
 | 
			
		||||
		blog := Blog{
 | 
			
		||||
			Locale:  "ZH",
 | 
			
		||||
			Subject: "subject",
 | 
			
		||||
			Body:    "body",
 | 
			
		||||
			SharedTags: []Tag{
 | 
			
		||||
				{Locale: "ZH", Value: "tag1"},
 | 
			
		||||
				{Locale: "ZH", Value: "tag2"},
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		DB.Save(&blog)
 | 
			
		||||
 | 
			
		||||
		blog2 := Blog{
 | 
			
		||||
			ID:     blog.ID,
 | 
			
		||||
			Locale: "EN",
 | 
			
		||||
		}
 | 
			
		||||
		DB.Create(&blog2)
 | 
			
		||||
 | 
			
		||||
		if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) {
 | 
			
		||||
			t.Errorf("Blog should has two tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Append
 | 
			
		||||
		var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
 | 
			
		||||
		DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3})
 | 
			
		||||
		if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("SharedTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("SharedTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tags []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "SharedTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags, "SharedTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog1 Blog
 | 
			
		||||
		DB.Preload("SharedTags").Find(&blog1)
 | 
			
		||||
		if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Preload many2many relations")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
 | 
			
		||||
		DB.Model(&blog2).Association("SharedTags").Append(tag4)
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "SharedTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags, "SharedTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Replace
 | 
			
		||||
		var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
 | 
			
		||||
		var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
 | 
			
		||||
		DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6)
 | 
			
		||||
		var tags2 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags2, "SharedTags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 2 tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags2, "SharedTags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 2 tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("SharedTags").Count() != 2 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Model(&blog).Association("SharedTags").Delete(tag5)
 | 
			
		||||
		var tags3 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags3, "SharedTags")
 | 
			
		||||
		if !compareTags(tags3, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 1 tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("SharedTags").Count() != 1 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Delete")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Association("SharedTags").Delete(tag3)
 | 
			
		||||
		var tags4 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags4, "SharedTags")
 | 
			
		||||
		if !compareTags(tags4, []string{"tag6"}) {
 | 
			
		||||
			t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Clear
 | 
			
		||||
		DB.Model(&blog2).Association("SharedTags").Clear()
 | 
			
		||||
		if DB.Model(&blog).Association("SharedTags").Count() != 0 {
 | 
			
		||||
			t.Errorf("All tags should be cleared")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
 | 
			
		||||
	if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" {
 | 
			
		||||
		DB.DropTable(&Blog{}, &Tag{})
 | 
			
		||||
		DB.DropTable("locale_blog_tags")
 | 
			
		||||
		DB.CreateTable(&Blog{}, &Tag{})
 | 
			
		||||
		blog := Blog{
 | 
			
		||||
			Locale:  "ZH",
 | 
			
		||||
			Subject: "subject",
 | 
			
		||||
			Body:    "body",
 | 
			
		||||
			LocaleTags: []Tag{
 | 
			
		||||
				{Locale: "ZH", Value: "tag1"},
 | 
			
		||||
				{Locale: "ZH", Value: "tag2"},
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		DB.Save(&blog)
 | 
			
		||||
 | 
			
		||||
		blog2 := Blog{
 | 
			
		||||
			ID:     blog.ID,
 | 
			
		||||
			Locale: "EN",
 | 
			
		||||
		}
 | 
			
		||||
		DB.Create(&blog2)
 | 
			
		||||
 | 
			
		||||
		// Append
 | 
			
		||||
		var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
 | 
			
		||||
		DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3})
 | 
			
		||||
		if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("Blog should has three tags after Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
 | 
			
		||||
			t.Errorf("EN Blog should has 0 tags after ZH Blog Append")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tags []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "LocaleTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags, "LocaleTags")
 | 
			
		||||
		if len(tags) != 0 {
 | 
			
		||||
			t.Errorf("Should find 0 tags with Related for EN Blog")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog1 Blog
 | 
			
		||||
		DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID)
 | 
			
		||||
		if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Preload many2many relations")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
 | 
			
		||||
		DB.Model(&blog2).Association("LocaleTags").Append(tag4)
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog).Related(&tags, "LocaleTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("Should find 3 tags with Related for EN Blog")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags, "LocaleTags")
 | 
			
		||||
		if !compareTags(tags, []string{"tag4"}) {
 | 
			
		||||
			t.Errorf("Should find 1 tags with Related for EN Blog")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Replace
 | 
			
		||||
		var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
 | 
			
		||||
		var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
 | 
			
		||||
		DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6)
 | 
			
		||||
 | 
			
		||||
		var tags2 []Tag
 | 
			
		||||
		DB.Model(&blog).Related(&tags2, "LocaleTags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog11 Blog
 | 
			
		||||
		DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale)
 | 
			
		||||
		if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
 | 
			
		||||
			t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Related(&tags2, "LocaleTags")
 | 
			
		||||
		if !compareTags(tags2, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("Should find 2 tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var blog21 Blog
 | 
			
		||||
		DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale)
 | 
			
		||||
		if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) {
 | 
			
		||||
			t.Errorf("EN Blog's tags should be changed after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("ZH Blog should has three tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
 | 
			
		||||
			t.Errorf("EN Blog should has two tags after Replace")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Delete
 | 
			
		||||
		DB.Model(&blog).Association("LocaleTags").Delete(tag5)
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("ZH Blog should has three tags after Delete with EN's tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
 | 
			
		||||
			t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog2).Association("LocaleTags").Delete(tag5)
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 1 {
 | 
			
		||||
			t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Clear
 | 
			
		||||
		DB.Model(&blog2).Association("LocaleTags").Clear()
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
 | 
			
		||||
			t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
 | 
			
		||||
			t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		DB.Model(&blog).Association("LocaleTags").Clear()
 | 
			
		||||
		if DB.Model(&blog).Association("LocaleTags").Count() != 0 {
 | 
			
		||||
			t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
 | 
			
		||||
			t.Errorf("EN Blog's tags should be cleared")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,84 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
type PointerStruct struct {
 | 
			
		||||
	ID   int64
 | 
			
		||||
	Name *string
 | 
			
		||||
	Num  *int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NormalStruct struct {
 | 
			
		||||
	ID   int64
 | 
			
		||||
	Name string
 | 
			
		||||
	Num  int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPointerFields(t *testing.T) {
 | 
			
		||||
	DB.DropTable(&PointerStruct{})
 | 
			
		||||
	DB.AutoMigrate(&PointerStruct{})
 | 
			
		||||
	var name = "pointer struct 1"
 | 
			
		||||
	var num = 100
 | 
			
		||||
	pointerStruct := PointerStruct{Name: &name, Num: &num}
 | 
			
		||||
	if DB.Create(&pointerStruct).Error != nil {
 | 
			
		||||
		t.Errorf("Failed to save pointer struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var pointerStructResult PointerStruct
 | 
			
		||||
	if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
 | 
			
		||||
		t.Errorf("Failed to query saved pointer struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var tableName = DB.NewScope(&PointerStruct{}).TableName()
 | 
			
		||||
 | 
			
		||||
	var normalStruct NormalStruct
 | 
			
		||||
	DB.Table(tableName).First(&normalStruct)
 | 
			
		||||
	if normalStruct.Name != name || normalStruct.Num != num {
 | 
			
		||||
		t.Errorf("Failed to query saved Normal struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var nilPointerStruct = PointerStruct{}
 | 
			
		||||
	if err := DB.Create(&nilPointerStruct).Error; err != nil {
 | 
			
		||||
		t.Error("Failed to save nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var pointerStruct2 PointerStruct
 | 
			
		||||
	if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
			
		||||
		t.Error("Failed to query saved nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var normalStruct2 NormalStruct
 | 
			
		||||
	if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
 | 
			
		||||
		t.Error("Failed to query saved nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var partialNilPointerStruct1 = PointerStruct{Num: &num}
 | 
			
		||||
	if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
 | 
			
		||||
		t.Error("Failed to save partial nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var pointerStruct3 PointerStruct
 | 
			
		||||
	if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
 | 
			
		||||
		t.Error("Failed to query saved partial nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var normalStruct3 NormalStruct
 | 
			
		||||
	if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
 | 
			
		||||
		t.Error("Failed to query saved partial pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var partialNilPointerStruct2 = PointerStruct{Name: &name}
 | 
			
		||||
	if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
 | 
			
		||||
		t.Error("Failed to save partial nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var pointerStruct4 PointerStruct
 | 
			
		||||
	if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
 | 
			
		||||
		t.Error("Failed to query saved partial nil pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var normalStruct4 NormalStruct
 | 
			
		||||
	if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
 | 
			
		||||
		t.Error("Failed to query saved partial pointer struct", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,366 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Cat struct {
 | 
			
		||||
	Id   int
 | 
			
		||||
	Name string
 | 
			
		||||
	Toy  Toy `gorm:"polymorphic:Owner;"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Dog struct {
 | 
			
		||||
	Id   int
 | 
			
		||||
	Name string
 | 
			
		||||
	Toys []Toy `gorm:"polymorphic:Owner;"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Hamster struct {
 | 
			
		||||
	Id           int
 | 
			
		||||
	Name         string
 | 
			
		||||
	PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"`
 | 
			
		||||
	OtherToy     Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Toy struct {
 | 
			
		||||
	Id        int
 | 
			
		||||
	Name      string
 | 
			
		||||
	OwnerId   int
 | 
			
		||||
	OwnerType string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var compareToys = func(toys []Toy, contents []string) bool {
 | 
			
		||||
	var toyContents []string
 | 
			
		||||
	for _, toy := range toys {
 | 
			
		||||
		toyContents = append(toyContents, toy.Name)
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(toyContents)
 | 
			
		||||
	sort.Strings(contents)
 | 
			
		||||
	return reflect.DeepEqual(toyContents, contents)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPolymorphic(t *testing.T) {
 | 
			
		||||
	cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}}
 | 
			
		||||
	dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}}
 | 
			
		||||
	DB.Save(&cat).Save(&dog)
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Cat's toys count should be 1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 2 {
 | 
			
		||||
		t.Errorf("Dog's toys count should be 2")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Query
 | 
			
		||||
	var catToys []Toy
 | 
			
		||||
	if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Did not find any has one polymorphic association")
 | 
			
		||||
	} else if len(catToys) != 1 {
 | 
			
		||||
		t.Errorf("Should have found only one polymorphic has one association")
 | 
			
		||||
	} else if catToys[0].Name != cat.Toy.Name {
 | 
			
		||||
		t.Errorf("Should have found the proper has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var dogToys []Toy
 | 
			
		||||
	if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Did not find any polymorphic has many associations")
 | 
			
		||||
	} else if len(dogToys) != len(dog.Toys) {
 | 
			
		||||
		t.Errorf("Should have found all polymorphic has many associations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var catToy Toy
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Find(&catToy)
 | 
			
		||||
	if catToy.Name != cat.Toy.Name {
 | 
			
		||||
		t.Errorf("Should find has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var dogToys1 []Toy
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Find(&dogToys1)
 | 
			
		||||
	if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) {
 | 
			
		||||
		t.Errorf("Should find has many polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Append
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Append(&Toy{
 | 
			
		||||
		Name: "cat toy 2",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var catToy2 Toy
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Find(&catToy2)
 | 
			
		||||
	if catToy2.Name != "cat toy 2" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Cat's toys count should be 1 after Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 2 {
 | 
			
		||||
		t.Errorf("Should return two polymorphic has many associations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Append(&Toy{
 | 
			
		||||
		Name: "dog toy 3",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var dogToys2 []Toy
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Find(&dogToys2)
 | 
			
		||||
	if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) {
 | 
			
		||||
		t.Errorf("Dog's toys should be updated with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 3 {
 | 
			
		||||
		t.Errorf("Should return three polymorphic has many associations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Replace
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Replace(&Toy{
 | 
			
		||||
		Name: "cat toy 3",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var catToy3 Toy
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Find(&catToy3)
 | 
			
		||||
	if catToy3.Name != "cat toy 3" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Cat's toys count should be 1 after Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 3 {
 | 
			
		||||
		t.Errorf("Should return three polymorphic has many associations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Replace(&Toy{
 | 
			
		||||
		Name: "dog toy 4",
 | 
			
		||||
	}, []Toy{
 | 
			
		||||
		{Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var dogToys3 []Toy
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Find(&dogToys3)
 | 
			
		||||
	if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) {
 | 
			
		||||
		t.Errorf("Dog's toys should be updated with Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 4 {
 | 
			
		||||
		t.Errorf("Should return three polymorphic has many associations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Delete
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Delete(&catToy2)
 | 
			
		||||
 | 
			
		||||
	var catToy4 Toy
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Find(&catToy4)
 | 
			
		||||
	if catToy4.Name != "cat toy 3" {
 | 
			
		||||
		t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Cat's toys count should be 1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 4 {
 | 
			
		||||
		t.Errorf("Dog's toys count should be 4")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Delete(&catToy3)
 | 
			
		||||
 | 
			
		||||
	if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Toy should be deleted with Delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 0 {
 | 
			
		||||
		t.Errorf("Cat's toys count should be 0 after Delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 4 {
 | 
			
		||||
		t.Errorf("Dog's toys count should not be changed when delete cat's toy")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Delete(&dogToys2)
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 4 {
 | 
			
		||||
		t.Errorf("Dog's toys count should not be changed when delete unrelated toys")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Delete(&dogToys3)
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 0 {
 | 
			
		||||
		t.Errorf("Dog's toys count should be deleted with Delete")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clear
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Append(&Toy{
 | 
			
		||||
		Name: "cat toy 2",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Cat's toys should be added with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&cat).Association("Toy").Clear()
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&cat).Association("Toy").Count() != 0 {
 | 
			
		||||
		t.Errorf("Cat's toys should be cleared with Clear")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Append(&Toy{
 | 
			
		||||
		Name: "dog toy 8",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 1 {
 | 
			
		||||
		t.Errorf("Dog's toys should be added with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&dog).Association("Toys").Clear()
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&dog).Association("Toys").Count() != 0 {
 | 
			
		||||
		t.Errorf("Dog's toys should be cleared with Clear")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNamedPolymorphic(t *testing.T) {
 | 
			
		||||
	hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}}
 | 
			
		||||
	DB.Save(&hamster)
 | 
			
		||||
 | 
			
		||||
	hamster2 := Hamster{}
 | 
			
		||||
	DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id)
 | 
			
		||||
	if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name {
 | 
			
		||||
		t.Errorf("Hamster's preferred toy couldn't be preloaded")
 | 
			
		||||
	}
 | 
			
		||||
	if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name {
 | 
			
		||||
		t.Errorf("Hamster's other toy couldn't be preloaded")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// clear to omit Toy.Id in count
 | 
			
		||||
	hamster2.PreferredToy = Toy{}
 | 
			
		||||
	hamster2.OtherToy = Toy{}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's preferred toy count should be 1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's other toy count should be 1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Query
 | 
			
		||||
	var hamsterToys []Toy
 | 
			
		||||
	if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Did not find any has one polymorphic association")
 | 
			
		||||
	} else if len(hamsterToys) != 1 {
 | 
			
		||||
		t.Errorf("Should have found only one polymorphic has one association")
 | 
			
		||||
	} else if hamsterToys[0].Name != hamster.PreferredToy.Name {
 | 
			
		||||
		t.Errorf("Should have found the proper has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Did not find any has one polymorphic association")
 | 
			
		||||
	} else if len(hamsterToys) != 1 {
 | 
			
		||||
		t.Errorf("Should have found only one polymorphic has one association")
 | 
			
		||||
	} else if hamsterToys[0].Name != hamster.OtherToy.Name {
 | 
			
		||||
		t.Errorf("Should have found the proper has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hamsterToy := Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != hamster.PreferredToy.Name {
 | 
			
		||||
		t.Errorf("Should find has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
	hamsterToy = Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != hamster.OtherToy.Name {
 | 
			
		||||
		t.Errorf("Should find has one polymorphic association")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Append
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
 | 
			
		||||
		Name: "bike 2",
 | 
			
		||||
	})
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Append(&Toy{
 | 
			
		||||
		Name: "treadmill 2",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	hamsterToy = Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != "bike 2" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hamsterToy = Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != "treadmill 2" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's toys count should be 1 after Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's toys count should be 1 after Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Replace
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{
 | 
			
		||||
		Name: "bike 3",
 | 
			
		||||
	})
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Replace(&Toy{
 | 
			
		||||
		Name: "treadmill 3",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	hamsterToy = Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != "bike 3" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hamsterToy = Toy{}
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
 | 
			
		||||
	if hamsterToy.Name != "treadmill 3" {
 | 
			
		||||
		t.Errorf("Should update has one polymorphic association with Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("hamster's toys count should be 1 after Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("hamster's toys count should be 1 after Replace")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clear
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
 | 
			
		||||
		Name: "bike 2",
 | 
			
		||||
	})
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Append(&Toy{
 | 
			
		||||
		Name: "treadmill 2",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster).Association("PreferredToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's toys should be added with Append")
 | 
			
		||||
	}
 | 
			
		||||
	if DB.Model(&hamster).Association("OtherToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's toys should be added with Append")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&hamster).Association("PreferredToy").Clear()
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 {
 | 
			
		||||
		t.Errorf("Hamster's preferred toy should be cleared with Clear")
 | 
			
		||||
	}
 | 
			
		||||
	if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
 | 
			
		||||
		t.Errorf("Hamster's other toy should be still available")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&hamster).Association("OtherToy").Clear()
 | 
			
		||||
	if DB.Model(&hamster).Association("OtherToy").Count() != 0 {
 | 
			
		||||
		t.Errorf("Hamster's other toy should be cleared with Clear")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1675
									
								
								preload_test.go
									
									
									
									
									
								
							
							
						
						
									
										1675
									
								
								preload_test.go
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										773
									
								
								query_test.go
									
									
									
									
									
								
							
							
						
						
									
										773
									
								
								query_test.go
									
									
									
									
									
								
							@ -1,773 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestFirstAndLast(t *testing.T) {
 | 
			
		||||
	DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}})
 | 
			
		||||
	DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}})
 | 
			
		||||
 | 
			
		||||
	var user1, user2, user3, user4 User
 | 
			
		||||
	DB.First(&user1)
 | 
			
		||||
	DB.Order("id").Limit(1).Find(&user2)
 | 
			
		||||
 | 
			
		||||
	ptrOfUser3 := &user3
 | 
			
		||||
	DB.Last(&ptrOfUser3)
 | 
			
		||||
	DB.Order("id desc").Limit(1).Find(&user4)
 | 
			
		||||
	if user1.Id != user2.Id || user3.Id != user4.Id {
 | 
			
		||||
		t.Errorf("First and Last should by order by primary key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.First(&users)
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Find first record as slice")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error when order with Join table")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Email != "" {
 | 
			
		||||
		t.Errorf("User's Email should be blank as no one set it")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
 | 
			
		||||
	DB.Save(&Animal{Name: "animal1"})
 | 
			
		||||
	DB.Save(&Animal{Name: "animal2"})
 | 
			
		||||
 | 
			
		||||
	var animal1, animal2, animal3, animal4 Animal
 | 
			
		||||
	DB.First(&animal1)
 | 
			
		||||
	DB.Order("counter").Limit(1).Find(&animal2)
 | 
			
		||||
 | 
			
		||||
	DB.Last(&animal3)
 | 
			
		||||
	DB.Order("counter desc").Limit(1).Find(&animal4)
 | 
			
		||||
	if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
 | 
			
		||||
		t.Errorf("First and Last should work correctly")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFirstAndLastWithRaw(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}}
 | 
			
		||||
	user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}}
 | 
			
		||||
	DB.Save(&user1)
 | 
			
		||||
	DB.Save(&user2)
 | 
			
		||||
 | 
			
		||||
	var user3, user4 User
 | 
			
		||||
	DB.Raw("select * from users WHERE name = ?", "user").First(&user3)
 | 
			
		||||
	if user3.Id != user1.Id {
 | 
			
		||||
		t.Errorf("Find first record with raw")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Raw("select * from users WHERE name = ?", "user").Last(&user4)
 | 
			
		||||
	if user4.Id != user2.Id {
 | 
			
		||||
		t.Errorf("Find last record with raw")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUIntPrimaryKey(t *testing.T) {
 | 
			
		||||
	var animal Animal
 | 
			
		||||
	DB.First(&animal, uint64(1))
 | 
			
		||||
	if animal.Counter != 1 {
 | 
			
		||||
		t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
 | 
			
		||||
	if animal.Counter != 2 {
 | 
			
		||||
		t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomizedTypePrimaryKey(t *testing.T) {
 | 
			
		||||
	type ID uint
 | 
			
		||||
	type CustomizedTypePrimaryKey struct {
 | 
			
		||||
		ID   ID
 | 
			
		||||
		Name string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.AutoMigrate(&CustomizedTypePrimaryKey{})
 | 
			
		||||
 | 
			
		||||
	p1 := CustomizedTypePrimaryKey{Name: "p1"}
 | 
			
		||||
	p2 := CustomizedTypePrimaryKey{Name: "p2"}
 | 
			
		||||
	p3 := CustomizedTypePrimaryKey{Name: "p3"}
 | 
			
		||||
	DB.Create(&p1)
 | 
			
		||||
	DB.Create(&p2)
 | 
			
		||||
	DB.Create(&p3)
 | 
			
		||||
 | 
			
		||||
	var p CustomizedTypePrimaryKey
 | 
			
		||||
 | 
			
		||||
	if err := DB.First(&p, p2.ID).Error; err == nil {
 | 
			
		||||
		t.Errorf("Should return error for invalid query condition")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if p.Name != "p2" {
 | 
			
		||||
		t.Errorf("Should find correct value when querying with customized type for primary key")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
 | 
			
		||||
	type AddressByZipCode struct {
 | 
			
		||||
		ZipCode string `gorm:"primary_key"`
 | 
			
		||||
		Address string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.AutoMigrate(&AddressByZipCode{})
 | 
			
		||||
	DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"})
 | 
			
		||||
 | 
			
		||||
	var address AddressByZipCode
 | 
			
		||||
	DB.First(&address, "00501")
 | 
			
		||||
	if address.ZipCode != "00501" {
 | 
			
		||||
		t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindAsSliceOfPointers(t *testing.T) {
 | 
			
		||||
	DB.Save(&User{Name: "user"})
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Find(&users)
 | 
			
		||||
 | 
			
		||||
	var userPointers []*User
 | 
			
		||||
	DB.Find(&userPointers)
 | 
			
		||||
 | 
			
		||||
	if len(users) == 0 || len(users) != len(userPointers) {
 | 
			
		||||
		t.Errorf("Find slice of pointers")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSearchWithPlainSQL(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
	scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%")
 | 
			
		||||
 | 
			
		||||
	if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Search with plain SQL")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Search with plan SQL (regexp)")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
 | 
			
		||||
	if len(users) != 3 {
 | 
			
		||||
		t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopedb.Where("age <> ?", 20).Find(&users)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Should found 2 users age != 20, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Should found 2 users, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
 | 
			
		||||
	if len(users) != 3 {
 | 
			
		||||
		t.Errorf("Should found 3 users, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where("id in (?)", user1.Id).Find(&users)
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Should found 1 users, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil {
 | 
			
		||||
		t.Error("no error should happen when query with empty slice, but got: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil {
 | 
			
		||||
		t.Error("no error should happen when query with empty slice, but got: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Should not get RecordNotFound error when looking for none existing records")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSearchWithTwoDimensionalArray(t *testing.T) {
 | 
			
		||||
	var users []User
 | 
			
		||||
	user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Create(&user1)
 | 
			
		||||
	DB.Create(&user2)
 | 
			
		||||
	DB.Create(&user3)
 | 
			
		||||
 | 
			
		||||
	if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" {
 | 
			
		||||
		if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil {
 | 
			
		||||
			t.Errorf("No error should happen when query with 2D array, but got %v", err)
 | 
			
		||||
 | 
			
		||||
			if len(users) != 2 {
 | 
			
		||||
				t.Errorf("Should find 2 users with 2D array, but got %v", len(users))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dialect := DB.Dialect().GetName(); dialect == "mssql" {
 | 
			
		||||
		if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil {
 | 
			
		||||
			t.Errorf("No error should happen when query with 2D array, but got %v", err)
 | 
			
		||||
 | 
			
		||||
			if len(users) != 2 {
 | 
			
		||||
				t.Errorf("Should find 2 users with 2D array, but got %v", len(users))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSearchWithStruct(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	if DB.Where(user1.Id).First(&User{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Search with primary key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&User{}, user1.Id).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Search with primary key as inline condition")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Search with primary key as inline condition")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users)
 | 
			
		||||
	if len(users) != 3 {
 | 
			
		||||
		t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	DB.First(&user, &User{Name: user1.Name})
 | 
			
		||||
	if user.Id == 0 || user.Name != user1.Name {
 | 
			
		||||
		t.Errorf("Search first record with inline pointer of struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.First(&user, User{Name: user1.Name})
 | 
			
		||||
	if user.Id == 0 || user.Name != user1.Name {
 | 
			
		||||
		t.Errorf("Search first record with inline struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: user1.Name}).First(&user)
 | 
			
		||||
	if user.Id == 0 || user.Name != user1.Name {
 | 
			
		||||
		t.Errorf("Search first record with where struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Find(&users, &User{Name: user2.Name})
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Search all records with inline struct")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSearchWithMap(t *testing.T) {
 | 
			
		||||
	companyID := 1
 | 
			
		||||
	user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	DB.First(&user, map[string]interface{}{"name": user1.Name})
 | 
			
		||||
	if user.Id == 0 || user.Name != user1.Name {
 | 
			
		||||
		t.Errorf("Search first record with inline map")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user = User{}
 | 
			
		||||
	DB.Where(map[string]interface{}{"name": user2.Name}).First(&user)
 | 
			
		||||
	if user.Id == 0 || user.Name != user2.Name {
 | 
			
		||||
		t.Errorf("Search first record with where map")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users)
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Search all records with inline map")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Find(&users, map[string]interface{}{"name": user3.Name})
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Search all records with inline map")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil})
 | 
			
		||||
	if len(users) != 0 {
 | 
			
		||||
		t.Errorf("Search all records with inline map containing null value finding 0 records")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil})
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Search all records with inline map containing null value finding 1 record")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID})
 | 
			
		||||
	if len(users) != 1 {
 | 
			
		||||
		t.Errorf("Search all records with inline multiple value map")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSearchWithEmptyChain(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
 | 
			
		||||
	user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	if DB.Where("").Where("").First(&User{}).Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error if searching with empty strings")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error if searching with empty struct")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
 | 
			
		||||
		t.Errorf("Should not raise any error if searching with empty map")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelect(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "SelectUser1"}
 | 
			
		||||
	DB.Save(&user1)
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	DB.Where("name = ?", user1.Name).Select("name").Find(&user)
 | 
			
		||||
	if user.Id != 0 {
 | 
			
		||||
		t.Errorf("Should not have ID because only selected name, %+v", user.Id)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Name != user1.Name {
 | 
			
		||||
		t.Errorf("Should have user Name when selected it")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOrderAndPluck(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "OrderPluckUser1", Age: 1}
 | 
			
		||||
	user2 := User{Name: "OrderPluckUser2", Age: 10}
 | 
			
		||||
	user3 := User{Name: "OrderPluckUser3", Age: 20}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
	scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user)
 | 
			
		||||
	if user.Name != "OrderPluckUser2" {
 | 
			
		||||
		t.Errorf("Order with sql expression")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ages []int64
 | 
			
		||||
	scopedb.Order("age desc").Pluck("age", &ages)
 | 
			
		||||
	if ages[0] != 20 {
 | 
			
		||||
		t.Errorf("The first age should be 20 when order with age desc")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ages1, ages2 []int64
 | 
			
		||||
	scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2)
 | 
			
		||||
	if !reflect.DeepEqual(ages1, ages2) {
 | 
			
		||||
		t.Errorf("The first order is the primary order")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ages3, ages4 []int64
 | 
			
		||||
	scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
 | 
			
		||||
	if reflect.DeepEqual(ages3, ages4) {
 | 
			
		||||
		t.Errorf("Reorder should work")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var names []string
 | 
			
		||||
	var ages5 []int64
 | 
			
		||||
	scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names)
 | 
			
		||||
	if names != nil && ages5 != nil {
 | 
			
		||||
		if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) {
 | 
			
		||||
			t.Errorf("Order with multiple orders")
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Errorf("Order with multiple orders")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ages6 []int64
 | 
			
		||||
	if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil {
 | 
			
		||||
		t.Errorf("An empty string as order clause produces invalid queries")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(User{}).Select("name, age").Find(&[]User{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestLimit(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "LimitUser1", Age: 1}
 | 
			
		||||
	user2 := User{Name: "LimitUser2", Age: 10}
 | 
			
		||||
	user3 := User{Name: "LimitUser3", Age: 20}
 | 
			
		||||
	user4 := User{Name: "LimitUser4", Age: 10}
 | 
			
		||||
	user5 := User{Name: "LimitUser5", Age: 20}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5)
 | 
			
		||||
 | 
			
		||||
	var users1, users2, users3 []User
 | 
			
		||||
	DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
 | 
			
		||||
 | 
			
		||||
	if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
 | 
			
		||||
		t.Errorf("Limit should works")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOffset(t *testing.T) {
 | 
			
		||||
	for i := 0; i < 20; i++ {
 | 
			
		||||
		DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
 | 
			
		||||
	}
 | 
			
		||||
	var users1, users2, users3, users4 []User
 | 
			
		||||
	DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
 | 
			
		||||
 | 
			
		||||
	if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
 | 
			
		||||
		t.Errorf("Offset should work")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOr(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "OrUser1", Age: 1}
 | 
			
		||||
	user2 := User{Name: "OrUser2", Age: 10}
 | 
			
		||||
	user3 := User{Name: "OrUser3", Age: 20}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	var users []User
 | 
			
		||||
	DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users)
 | 
			
		||||
	if len(users) != 2 {
 | 
			
		||||
		t.Errorf("Find users with or")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCount(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "CountUser1", Age: 1}
 | 
			
		||||
	user2 := User{Name: "CountUser2", Age: 10}
 | 
			
		||||
	user3 := User{Name: "CountUser3", Age: 20}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
	var count, count1, count2 int64
 | 
			
		||||
	var users []User
 | 
			
		||||
 | 
			
		||||
	if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
 | 
			
		||||
		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count != int64(len(users)) {
 | 
			
		||||
		t.Errorf("Count() method should get correct value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2)
 | 
			
		||||
	if count1 != 1 || count2 != 3 {
 | 
			
		||||
		t.Errorf("Multiple count in chain")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var count3 int
 | 
			
		||||
	if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil {
 | 
			
		||||
		t.Errorf("Not error should happen, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if count3 != 2 {
 | 
			
		||||
		t.Errorf("Should get correct count, but got %v", count3)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNot(t *testing.T) {
 | 
			
		||||
	DB.Create(getPreparedUser("user1", "not"))
 | 
			
		||||
	DB.Create(getPreparedUser("user2", "not"))
 | 
			
		||||
	DB.Create(getPreparedUser("user3", "not"))
 | 
			
		||||
 | 
			
		||||
	user4 := getPreparedUser("user4", "not")
 | 
			
		||||
	user4.Company = Company{}
 | 
			
		||||
	DB.Create(user4)
 | 
			
		||||
 | 
			
		||||
	DB := DB.Where("role = ?", "not")
 | 
			
		||||
 | 
			
		||||
	var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User
 | 
			
		||||
	if DB.Find(&users1).RowsAffected != 4 {
 | 
			
		||||
		t.Errorf("should find 4 not users")
 | 
			
		||||
	}
 | 
			
		||||
	DB.Not(users1[0].Id).Find(&users2)
 | 
			
		||||
 | 
			
		||||
	if len(users1)-len(users2) != 1 {
 | 
			
		||||
		t.Errorf("Should ignore the first users with Not")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not([]int{}).Find(&users3)
 | 
			
		||||
	if len(users1)-len(users3) != 0 {
 | 
			
		||||
		t.Errorf("Should find all users with a blank condition")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var name3Count int64
 | 
			
		||||
	DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
 | 
			
		||||
	DB.Not("name", "user3").Find(&users4)
 | 
			
		||||
	if len(users1)-len(users4) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not("name = ?", "user3").Find(&users4)
 | 
			
		||||
	if len(users1)-len(users4) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not("name <> ?", "user3").Find(&users4)
 | 
			
		||||
	if len(users4) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not(User{Name: "user3"}).Find(&users5)
 | 
			
		||||
 | 
			
		||||
	if len(users1)-len(users5) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
 | 
			
		||||
	if len(users1)-len(users6) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
 | 
			
		||||
	if len(users1)-len(users7) != 2 { // not user3 or user4
 | 
			
		||||
		t.Errorf("Should find all user's name not equal to 3 who do not have company id")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Not("name", []string{"user3"}).Find(&users8)
 | 
			
		||||
	if len(users1)-len(users8) != int(name3Count) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var name2Count int64
 | 
			
		||||
	DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
 | 
			
		||||
	DB.Not("name", []string{"user3", "user2"}).Find(&users9)
 | 
			
		||||
	if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
 | 
			
		||||
		t.Errorf("Should find all users's name not equal 3")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFillSmallerStruct(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "SmallerUser", Age: 100}
 | 
			
		||||
	DB.Save(&user1)
 | 
			
		||||
	type SimpleUser struct {
 | 
			
		||||
		Name      string
 | 
			
		||||
		Id        int64
 | 
			
		||||
		UpdatedAt time.Time
 | 
			
		||||
		CreatedAt time.Time
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var simpleUser SimpleUser
 | 
			
		||||
	DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser)
 | 
			
		||||
 | 
			
		||||
	if simpleUser.Id == 0 || simpleUser.Name == "" {
 | 
			
		||||
		t.Errorf("Should fill data correctly into smaller struct")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindOrInitialize(t *testing.T) {
 | 
			
		||||
	var user1, user2, user3, user4, user5, user6 User
 | 
			
		||||
	DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1)
 | 
			
		||||
	if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be initialized with search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
 | 
			
		||||
	if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be initialized with search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
 | 
			
		||||
	if user3.Name != "find or init 2" || user3.Id != 0 {
 | 
			
		||||
		t.Errorf("user should be initialized with inline search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
 | 
			
		||||
	if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be initialized with search value and attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
 | 
			
		||||
	if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be initialized with search value and assign attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&User{Name: "find or init", Age: 33})
 | 
			
		||||
	DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
 | 
			
		||||
	if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be found and not initialized by Attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
 | 
			
		||||
	if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be found with FirstOrInit")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
 | 
			
		||||
	if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be found and updated with assigned attrs")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindOrCreate(t *testing.T) {
 | 
			
		||||
	var user1, user2, user3, user4, user5, user6, user7, user8 User
 | 
			
		||||
	DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
 | 
			
		||||
	if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be created with search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2)
 | 
			
		||||
	if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be created with search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"})
 | 
			
		||||
	if user3.Name != "find or create 2" || user3.Id == 0 {
 | 
			
		||||
		t.Errorf("user should be created with inline search value")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
 | 
			
		||||
	if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be created with search value and attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	updatedAt1 := user4.UpdatedAt
 | 
			
		||||
	DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
 | 
			
		||||
	if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("UpdateAt should be changed when update values with assign")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4)
 | 
			
		||||
	if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be created with search value and assigned attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
 | 
			
		||||
	if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 {
 | 
			
		||||
		t.Errorf("user should be found and not initialized by Attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6)
 | 
			
		||||
	if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be found and updated with assigned attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create"}).Find(&user7)
 | 
			
		||||
	if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 {
 | 
			
		||||
		t.Errorf("user should be found and updated with assigned attrs")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8)
 | 
			
		||||
	if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("embedded struct email should be saved")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() {
 | 
			
		||||
		t.Errorf("embedded struct credit card should be saved")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithEscapedFieldName(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "EscapedFieldNameUser", Age: 1}
 | 
			
		||||
	user2 := User{Name: "EscapedFieldNameUser", Age: 10}
 | 
			
		||||
	user3 := User{Name: "EscapedFieldNameUser", Age: 20}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	var names []string
 | 
			
		||||
	DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names)
 | 
			
		||||
 | 
			
		||||
	if len(names) != 3 {
 | 
			
		||||
		t.Errorf("Expected 3 name, but got: %d", len(names))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithVariables(t *testing.T) {
 | 
			
		||||
	DB.Save(&User{Name: "jinzhu"})
 | 
			
		||||
 | 
			
		||||
	rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows()
 | 
			
		||||
 | 
			
		||||
	if !rows.Next() {
 | 
			
		||||
		t.Errorf("Should have returned at least one row")
 | 
			
		||||
	} else {
 | 
			
		||||
		columns, _ := rows.Columns()
 | 
			
		||||
		if !reflect.DeepEqual(columns, []string{"fake"}) {
 | 
			
		||||
			t.Errorf("Should only contains one column")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithArrayInput(t *testing.T) {
 | 
			
		||||
	DB.Save(&User{Name: "jinzhu", Age: 42})
 | 
			
		||||
 | 
			
		||||
	var user User
 | 
			
		||||
	DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
 | 
			
		||||
 | 
			
		||||
	if user.Name != "jinzhu" || user.Age != 42 {
 | 
			
		||||
		t.Errorf("Should have selected both age and name")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestPluckWithSelect(t *testing.T) {
 | 
			
		||||
	var (
 | 
			
		||||
		user              = User{Name: "matematik7_pluck_with_select", Age: 25}
 | 
			
		||||
		combinedName      = fmt.Sprintf("%v%v", user.Name, user.Age)
 | 
			
		||||
		combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age"))
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if dialect := DB.Dialect().GetName(); dialect == "sqlite3" {
 | 
			
		||||
		combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	selectStr := combineUserAgeSQL + " as user_age"
 | 
			
		||||
	var userAges []string
 | 
			
		||||
	err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(userAges) != 1 || userAges[0] != combinedName {
 | 
			
		||||
		t.Errorf("Should correctly pluck with select, got: %s", userAges)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age"))
 | 
			
		||||
	userAges = userAges[:0]
 | 
			
		||||
	err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(userAges) != 1 || userAges[0] != combinedName {
 | 
			
		||||
		t.Errorf("Should correctly pluck with select, got: %s", userAges)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										139
									
								
								scaner_test.go
									
									
									
									
									
								
							
							
						
						
									
										139
									
								
								scaner_test.go
									
									
									
									
									
								
							@ -1,139 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestScannableSlices(t *testing.T) {
 | 
			
		||||
	if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should create table with slice values correctly: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r1 := RecordWithSlice{
 | 
			
		||||
		Strings: ExampleStringSlice{"a", "b", "c"},
 | 
			
		||||
		Structs: ExampleStructSlice{
 | 
			
		||||
			{"name1", "value1"},
 | 
			
		||||
			{"name2", "value2"},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Save(&r1).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should save record with slice values")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var r2 RecordWithSlice
 | 
			
		||||
 | 
			
		||||
	if err := DB.Find(&r2).Error; err != nil {
 | 
			
		||||
		t.Errorf("Should fetch record with slice values")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
 | 
			
		||||
		t.Errorf("Should have serialised and deserialised a string array")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
 | 
			
		||||
		t.Errorf("Should have serialised and deserialised a struct array")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RecordWithSlice struct {
 | 
			
		||||
	ID      uint64
 | 
			
		||||
	Strings ExampleStringSlice `sql:"type:text"`
 | 
			
		||||
	Structs ExampleStructSlice `sql:"type:text"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ExampleStringSlice []string
 | 
			
		||||
 | 
			
		||||
func (l ExampleStringSlice) Value() (driver.Value, error) {
 | 
			
		||||
	bytes, err := json.Marshal(l)
 | 
			
		||||
	return string(bytes), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *ExampleStringSlice) Scan(input interface{}) error {
 | 
			
		||||
	switch value := input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(value), l)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(value, l)
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("not supported")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ExampleStruct struct {
 | 
			
		||||
	Name  string
 | 
			
		||||
	Value string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ExampleStructSlice []ExampleStruct
 | 
			
		||||
 | 
			
		||||
func (l ExampleStructSlice) Value() (driver.Value, error) {
 | 
			
		||||
	bytes, err := json.Marshal(l)
 | 
			
		||||
	return string(bytes), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *ExampleStructSlice) Scan(input interface{}) error {
 | 
			
		||||
	switch value := input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return json.Unmarshal([]byte(value), l)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		return json.Unmarshal(value, l)
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("not supported")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ScannerDataType struct {
 | 
			
		||||
	Street string `sql:"TYPE:varchar(24)"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ScannerDataType) Value() (driver.Value, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*ScannerDataType) Scan(input interface{}) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ScannerDataTypeTestStruct struct {
 | 
			
		||||
	Field1          int
 | 
			
		||||
	ScannerDataType *ScannerDataType `sql:"TYPE:json"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ScannerDataType2 struct {
 | 
			
		||||
	Street string `sql:"TYPE:varchar(24)"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ScannerDataType2) Value() (driver.Value, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*ScannerDataType2) Scan(input interface{}) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ScannerDataTypeTestStruct2 struct {
 | 
			
		||||
	Field1          int
 | 
			
		||||
	ScannerDataType *ScannerDataType2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestScannerDataType(t *testing.T) {
 | 
			
		||||
	scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}}
 | 
			
		||||
	if field, ok := scope.FieldByName("ScannerDataType"); ok {
 | 
			
		||||
		if DB.Dialect().DataTypeOf(field.StructField) != "json" {
 | 
			
		||||
			t.Errorf("data type for scanner is wrong")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}}
 | 
			
		||||
	if field, ok := scope.FieldByName("ScannerDataType"); ok {
 | 
			
		||||
		if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" {
 | 
			
		||||
			t.Errorf("data type for scanner is wrong")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,80 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
 | 
			
		||||
	return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NameIn2And3(d *gorm.DB) *gorm.DB {
 | 
			
		||||
	return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NameIn(names []string) func(d *gorm.DB) *gorm.DB {
 | 
			
		||||
	return func(d *gorm.DB) *gorm.DB {
 | 
			
		||||
		return d.Where("name in (?)", names)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestScopes(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "ScopeUser1", Age: 1}
 | 
			
		||||
	user2 := User{Name: "ScopeUser2", Age: 1}
 | 
			
		||||
	user3 := User{Name: "ScopeUser3", Age: 2}
 | 
			
		||||
	DB.Save(&user1).Save(&user2).Save(&user3)
 | 
			
		||||
 | 
			
		||||
	var users1, users2, users3 []User
 | 
			
		||||
	DB.Scopes(NameIn1And2).Find(&users1)
 | 
			
		||||
	if len(users1) != 2 {
 | 
			
		||||
		t.Errorf("Should found two users's name in 1, 2")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2)
 | 
			
		||||
	if len(users2) != 1 {
 | 
			
		||||
		t.Errorf("Should found one user's name is 2")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3)
 | 
			
		||||
	if len(users3) != 2 {
 | 
			
		||||
		t.Errorf("Should found two users's name in 1, 3")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func randName() string {
 | 
			
		||||
	data := make([]byte, 8)
 | 
			
		||||
	rand.Read(data)
 | 
			
		||||
 | 
			
		||||
	return "n-" + hex.EncodeToString(data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValuer(t *testing.T) {
 | 
			
		||||
	name := randName()
 | 
			
		||||
 | 
			
		||||
	origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")}
 | 
			
		||||
	if err := DB.Save(&origUser).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when saving user, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user2 User
 | 
			
		||||
	if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil {
 | 
			
		||||
		t.Errorf("No error should happen when querying user with valuer, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFailedValuer(t *testing.T) {
 | 
			
		||||
	name := randName()
 | 
			
		||||
 | 
			
		||||
	err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error
 | 
			
		||||
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Errorf("There should be an error should happen when insert data")
 | 
			
		||||
	} else if !strings.HasPrefix(err.Error(), "Should not start with") {
 | 
			
		||||
		t.Errorf("The error should be returned from Valuer, but get %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										153
									
								
								search.go
									
									
									
									
									
								
							
							
						
						
									
										153
									
								
								search.go
									
									
									
									
									
								
							@ -1,153 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type search struct {
 | 
			
		||||
	db               *DB
 | 
			
		||||
	whereConditions  []map[string]interface{}
 | 
			
		||||
	orConditions     []map[string]interface{}
 | 
			
		||||
	notConditions    []map[string]interface{}
 | 
			
		||||
	havingConditions []map[string]interface{}
 | 
			
		||||
	joinConditions   []map[string]interface{}
 | 
			
		||||
	initAttrs        []interface{}
 | 
			
		||||
	assignAttrs      []interface{}
 | 
			
		||||
	selects          map[string]interface{}
 | 
			
		||||
	omits            []string
 | 
			
		||||
	orders           []interface{}
 | 
			
		||||
	preload          []searchPreload
 | 
			
		||||
	offset           interface{}
 | 
			
		||||
	limit            interface{}
 | 
			
		||||
	group            string
 | 
			
		||||
	tableName        string
 | 
			
		||||
	raw              bool
 | 
			
		||||
	Unscoped         bool
 | 
			
		||||
	ignoreOrderQuery bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type searchPreload struct {
 | 
			
		||||
	schema     string
 | 
			
		||||
	conditions []interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) clone() *search {
 | 
			
		||||
	clone := *s
 | 
			
		||||
	return &clone
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Where(query interface{}, values ...interface{}) *search {
 | 
			
		||||
	s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Not(query interface{}, values ...interface{}) *search {
 | 
			
		||||
	s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Or(query interface{}, values ...interface{}) *search {
 | 
			
		||||
	s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Attrs(attrs ...interface{}) *search {
 | 
			
		||||
	s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Assign(attrs ...interface{}) *search {
 | 
			
		||||
	s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Order(value interface{}, reorder ...bool) *search {
 | 
			
		||||
	if len(reorder) > 0 && reorder[0] {
 | 
			
		||||
		s.orders = []interface{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if value != nil && value != "" {
 | 
			
		||||
		s.orders = append(s.orders, value)
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Select(query interface{}, args ...interface{}) *search {
 | 
			
		||||
	s.selects = map[string]interface{}{"query": query, "args": args}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Omit(columns ...string) *search {
 | 
			
		||||
	s.omits = columns
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Limit(limit interface{}) *search {
 | 
			
		||||
	s.limit = limit
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Offset(offset interface{}) *search {
 | 
			
		||||
	s.offset = offset
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Group(query string) *search {
 | 
			
		||||
	s.group = s.getInterfaceAsSQL(query)
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Having(query interface{}, values ...interface{}) *search {
 | 
			
		||||
	if val, ok := query.(*expr); ok {
 | 
			
		||||
		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
 | 
			
		||||
	} else {
 | 
			
		||||
		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Joins(query string, values ...interface{}) *search {
 | 
			
		||||
	s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Preload(schema string, values ...interface{}) *search {
 | 
			
		||||
	var preloads []searchPreload
 | 
			
		||||
	for _, preload := range s.preload {
 | 
			
		||||
		if preload.schema != schema {
 | 
			
		||||
			preloads = append(preloads, preload)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	preloads = append(preloads, searchPreload{schema, values})
 | 
			
		||||
	s.preload = preloads
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Raw(b bool) *search {
 | 
			
		||||
	s.raw = b
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) unscoped() *search {
 | 
			
		||||
	s.Unscoped = true
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) Table(name string) *search {
 | 
			
		||||
	s.tableName = name
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
 | 
			
		||||
	switch value.(type) {
 | 
			
		||||
	case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
 | 
			
		||||
		str = fmt.Sprintf("%v", value)
 | 
			
		||||
	default:
 | 
			
		||||
		s.db.AddError(ErrInvalidSQL)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if str == "-1" {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -1,30 +0,0 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCloneSearch(t *testing.T) {
 | 
			
		||||
	s := new(search)
 | 
			
		||||
	s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age")
 | 
			
		||||
 | 
			
		||||
	s1 := s.clone()
 | 
			
		||||
	s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email")
 | 
			
		||||
 | 
			
		||||
	if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
 | 
			
		||||
		t.Errorf("Where should be copied")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reflect.DeepEqual(s.orders, s1.orders) {
 | 
			
		||||
		t.Errorf("Order should be copied")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
 | 
			
		||||
		t.Errorf("InitAttrs should be copied")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reflect.DeepEqual(s.Select, s1.Select) {
 | 
			
		||||
		t.Errorf("selectStr should be copied")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										465
									
								
								update_test.go
									
									
									
									
									
								
							
							
						
						
									
										465
									
								
								update_test.go
									
									
									
									
									
								
							@ -1,465 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestUpdate(t *testing.T) {
 | 
			
		||||
	product1 := Product{Code: "product1code"}
 | 
			
		||||
	product2 := Product{Code: "product2code"}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&product1).Save(&product2).Update("code", "product2newcode")
 | 
			
		||||
 | 
			
		||||
	if product2.Code != "product2newcode" {
 | 
			
		||||
		t.Errorf("Record should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.First(&product1, product1.Id)
 | 
			
		||||
	DB.First(&product2, product2.Id)
 | 
			
		||||
	updatedAt1 := product1.UpdatedAt
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product1 should not be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product2's code should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product2's code should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode")
 | 
			
		||||
 | 
			
		||||
	var product4 Product
 | 
			
		||||
	DB.First(&product4, product1.Id)
 | 
			
		||||
	if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("updatedAt should be updated if something changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product1's code should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product should not be changed to 789")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
 | 
			
		||||
		t.Error("No error should raise when update with CamelCase")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
 | 
			
		||||
		t.Error("No error should raise when update_column with CamelCase")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var products []Product
 | 
			
		||||
	DB.Find(&products)
 | 
			
		||||
	if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
 | 
			
		||||
		t.Error("RowsAffected should be correct when do batch update")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.First(&product4, product4.Id)
 | 
			
		||||
	updatedAt4 := product4.UpdatedAt
 | 
			
		||||
	DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
 | 
			
		||||
	var product5 Product
 | 
			
		||||
	DB.First(&product5, product4.Id)
 | 
			
		||||
	if product5.Price != product4.Price+100-50 {
 | 
			
		||||
		t.Errorf("Update with expression")
 | 
			
		||||
	}
 | 
			
		||||
	if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("Update with expression should update UpdatedAt")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
			
		||||
	animal := Animal{Name: "Ferdinand"}
 | 
			
		||||
	DB.Save(&animal)
 | 
			
		||||
	updatedAt1 := animal.UpdatedAt
 | 
			
		||||
 | 
			
		||||
	DB.Save(&animal).Update("name", "Francis")
 | 
			
		||||
 | 
			
		||||
	if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("updatedAt should not be updated if nothing changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var animals []Animal
 | 
			
		||||
	DB.Find(&animals)
 | 
			
		||||
	if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) {
 | 
			
		||||
		t.Error("RowsAffected should be correct when do batch update")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | 
			
		||||
	DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
 | 
			
		||||
	DB.First(&animal, animal.Counter)
 | 
			
		||||
	if animal.Name != "galeone" {
 | 
			
		||||
		t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// When changing a field with a default value, the change must occur
 | 
			
		||||
	animal.Name = "amazing horse"
 | 
			
		||||
	DB.Save(&animal)
 | 
			
		||||
	DB.First(&animal, animal.Counter)
 | 
			
		||||
	if animal.Name != "amazing horse" {
 | 
			
		||||
		t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// When changing a field with a default value with blank value
 | 
			
		||||
	animal.Name = ""
 | 
			
		||||
	DB.Save(&animal)
 | 
			
		||||
	DB.First(&animal, animal.Counter)
 | 
			
		||||
	if animal.Name != "" {
 | 
			
		||||
		t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdates(t *testing.T) {
 | 
			
		||||
	product1 := Product{Code: "product1code", Price: 10}
 | 
			
		||||
	product2 := Product{Code: "product2code", Price: 10}
 | 
			
		||||
	DB.Save(&product1).Save(&product2)
 | 
			
		||||
	DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100})
 | 
			
		||||
	if product1.Code != "product1newcode" || product1.Price != 100 {
 | 
			
		||||
		t.Errorf("Record should be updated also with map")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.First(&product1, product1.Id)
 | 
			
		||||
	DB.First(&product2, product2.Id)
 | 
			
		||||
	updatedAt2 := product2.UpdatedAt
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product2 should not be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product1 should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"})
 | 
			
		||||
	if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() {
 | 
			
		||||
		t.Errorf("Product2's code should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var product4 Product
 | 
			
		||||
	DB.First(&product4, product2.Id)
 | 
			
		||||
	if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("updatedAt should be updated if something changed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
 | 
			
		||||
		t.Errorf("product2's code should be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	updatedAt4 := product4.UpdatedAt
 | 
			
		||||
	DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
 | 
			
		||||
	var product5 Product
 | 
			
		||||
	DB.First(&product5, product4.Id)
 | 
			
		||||
	if product5.Price != product4.Price+100 {
 | 
			
		||||
		t.Errorf("Updates with expression")
 | 
			
		||||
	}
 | 
			
		||||
	// product4's UpdatedAt will be reset when updating
 | 
			
		||||
	if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("Updates with expression should update UpdatedAt")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdateColumn(t *testing.T) {
 | 
			
		||||
	product1 := Product{Code: "product1code", Price: 10}
 | 
			
		||||
	product2 := Product{Code: "product2code", Price: 20}
 | 
			
		||||
	DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100})
 | 
			
		||||
	if product2.Code != "product2newcode" || product2.Price != 100 {
 | 
			
		||||
		t.Errorf("product 2 should be updated with update column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var product3 Product
 | 
			
		||||
	DB.First(&product3, product1.Id)
 | 
			
		||||
	if product3.Code != "product1code" || product3.Price != 10 {
 | 
			
		||||
		t.Errorf("product 1 should not be updated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.First(&product2, product2.Id)
 | 
			
		||||
	updatedAt2 := product2.UpdatedAt
 | 
			
		||||
	DB.Model(product2).UpdateColumn("code", "update_column_new")
 | 
			
		||||
	var product4 Product
 | 
			
		||||
	DB.First(&product4, product2.Id)
 | 
			
		||||
	if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("updatedAt should not be updated with update column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
 | 
			
		||||
	var product5 Product
 | 
			
		||||
	DB.First(&product5, product4.Id)
 | 
			
		||||
	if product5.Price != product4.Price+100-50 {
 | 
			
		||||
		t.Errorf("UpdateColumn with expression")
 | 
			
		||||
	}
 | 
			
		||||
	if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
 | 
			
		||||
		t.Errorf("UpdateColumn with expression should not update UpdatedAt")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithUpdate(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_update")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	reloadUser.Name = "new_name"
 | 
			
		||||
	reloadUser.Age = 50
 | 
			
		||||
	reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
 | 
			
		||||
	reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
 | 
			
		||||
	reloadUser.CreditCard = CreditCard{Number: "987654321"}
 | 
			
		||||
	reloadUser.Emails = []Email{
 | 
			
		||||
		{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
 | 
			
		||||
	}
 | 
			
		||||
	reloadUser.Company = Company{Name: "new company"}
 | 
			
		||||
 | 
			
		||||
	DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name == user.Name || queryUser.Age != user.Age {
 | 
			
		||||
		t.Errorf("Should only update users with name column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
 | 
			
		||||
		queryUser.ShippingAddressId != user.ShippingAddressId ||
 | 
			
		||||
		queryUser.CreditCard.ID == user.CreditCard.ID ||
 | 
			
		||||
		len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
 | 
			
		||||
		t.Errorf("Should only update selected relationships")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithUpdateWithMap(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_update_map")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	updateValues := map[string]interface{}{
 | 
			
		||||
		"Name":            "new_name",
 | 
			
		||||
		"Age":             50,
 | 
			
		||||
		"BillingAddress":  Address{Address1: "New Billing Address"},
 | 
			
		||||
		"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
 | 
			
		||||
		"CreditCard":      CreditCard{Number: "987654321"},
 | 
			
		||||
		"Emails": []Email{
 | 
			
		||||
			{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
 | 
			
		||||
		},
 | 
			
		||||
		"Company": Company{Name: "new company"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name == user.Name || queryUser.Age != user.Age {
 | 
			
		||||
		t.Errorf("Should only update users with name column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
 | 
			
		||||
		queryUser.ShippingAddressId != user.ShippingAddressId ||
 | 
			
		||||
		queryUser.CreditCard.ID == user.CreditCard.ID ||
 | 
			
		||||
		len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
 | 
			
		||||
		t.Errorf("Should only update selected relationships")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOmitWithUpdate(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("omit_user", "omit_with_update")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	reloadUser.Name = "new_name"
 | 
			
		||||
	reloadUser.Age = 50
 | 
			
		||||
	reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
 | 
			
		||||
	reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
 | 
			
		||||
	reloadUser.CreditCard = CreditCard{Number: "987654321"}
 | 
			
		||||
	reloadUser.Emails = []Email{
 | 
			
		||||
		{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
 | 
			
		||||
	}
 | 
			
		||||
	reloadUser.Company = Company{Name: "new company"}
 | 
			
		||||
 | 
			
		||||
	DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name != user.Name || queryUser.Age == user.Age {
 | 
			
		||||
		t.Errorf("Should only update users with name column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
 | 
			
		||||
		queryUser.ShippingAddressId == user.ShippingAddressId ||
 | 
			
		||||
		queryUser.CreditCard.ID != user.CreditCard.ID ||
 | 
			
		||||
		len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
 | 
			
		||||
		t.Errorf("Should only update relationships that not omitted")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOmitWithUpdateWithMap(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_update_map")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	updateValues := map[string]interface{}{
 | 
			
		||||
		"Name":            "new_name",
 | 
			
		||||
		"Age":             50,
 | 
			
		||||
		"BillingAddress":  Address{Address1: "New Billing Address"},
 | 
			
		||||
		"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
 | 
			
		||||
		"CreditCard":      CreditCard{Number: "987654321"},
 | 
			
		||||
		"Emails": []Email{
 | 
			
		||||
			{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
 | 
			
		||||
		},
 | 
			
		||||
		"Company": Company{Name: "new company"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").
 | 
			
		||||
		Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name != user.Name || queryUser.Age == user.Age {
 | 
			
		||||
		t.Errorf("Should only update users with name column")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
 | 
			
		||||
		queryUser.ShippingAddressId == user.ShippingAddressId ||
 | 
			
		||||
		queryUser.CreditCard.ID != user.CreditCard.ID ||
 | 
			
		||||
		len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
 | 
			
		||||
		t.Errorf("Should only update relationships not omitted")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSelectWithUpdateColumn(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_update_map")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name == user.Name || queryUser.Age != user.Age {
 | 
			
		||||
		t.Errorf("Should only update users with name column")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOmitWithUpdateColumn(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("select_user", "select_with_update_map")
 | 
			
		||||
	DB.Create(user)
 | 
			
		||||
 | 
			
		||||
	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
 | 
			
		||||
 | 
			
		||||
	var reloadUser User
 | 
			
		||||
	DB.First(&reloadUser, user.Id)
 | 
			
		||||
	DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues)
 | 
			
		||||
 | 
			
		||||
	var queryUser User
 | 
			
		||||
	DB.First(&queryUser, user.Id)
 | 
			
		||||
 | 
			
		||||
	if queryUser.Name != user.Name || queryUser.Age == user.Age {
 | 
			
		||||
		t.Errorf("Should omit name column when update user")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
 | 
			
		||||
	user := getPreparedUser("update_columns_user", "special_role")
 | 
			
		||||
	user.Age = 99
 | 
			
		||||
	address1 := "first street"
 | 
			
		||||
	user.BillingAddress = Address{Address1: address1}
 | 
			
		||||
	DB.Save(user)
 | 
			
		||||
 | 
			
		||||
	// Update a single field of the user and verify that the changed address is not stored.
 | 
			
		||||
	newAge := int64(100)
 | 
			
		||||
	user.BillingAddress.Address1 = "second street"
 | 
			
		||||
	db := DB.Model(user).UpdateColumns(User{Age: newAge})
 | 
			
		||||
	if db.RowsAffected != 1 {
 | 
			
		||||
		t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify that Age now=`newAge`.
 | 
			
		||||
	freshUser := &User{Id: user.Id}
 | 
			
		||||
	DB.First(freshUser)
 | 
			
		||||
	if freshUser.Age != newAge {
 | 
			
		||||
		t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
 | 
			
		||||
	DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
 | 
			
		||||
	if freshUser.BillingAddress.Address1 != address1 {
 | 
			
		||||
		t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdatesWithBlankValues(t *testing.T) {
 | 
			
		||||
	product := Product{Code: "product1", Price: 10}
 | 
			
		||||
	DB.Save(&product)
 | 
			
		||||
 | 
			
		||||
	DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
 | 
			
		||||
 | 
			
		||||
	var product1 Product
 | 
			
		||||
	DB.First(&product1, product.Id)
 | 
			
		||||
 | 
			
		||||
	if product1.Code != "product1" || product1.Price != 100 {
 | 
			
		||||
		t.Errorf("product's code should not be updated")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ElementWithIgnoredField struct {
 | 
			
		||||
	Id           int64
 | 
			
		||||
	Value        string
 | 
			
		||||
	IgnoredField int64 `sql:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e ElementWithIgnoredField) TableName() string {
 | 
			
		||||
	return "element_with_ignored_field"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdatesTableWithIgnoredValues(t *testing.T) {
 | 
			
		||||
	elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
 | 
			
		||||
	DB.Save(&elem)
 | 
			
		||||
 | 
			
		||||
	DB.Table(elem.TableName()).
 | 
			
		||||
		Where("id = ?", elem.Id).
 | 
			
		||||
		// DB.Model(&ElementWithIgnoredField{Id: elem.Id}).
 | 
			
		||||
		Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})
 | 
			
		||||
 | 
			
		||||
	var elem1 ElementWithIgnoredField
 | 
			
		||||
	err := DB.First(&elem1, elem.Id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("error getting an element from database: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if elem1.IgnoredField != 0 {
 | 
			
		||||
		t.Errorf("element's ignored field should not be updated")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUpdateDecodeVirtualAttributes(t *testing.T) {
 | 
			
		||||
	var user = User{
 | 
			
		||||
		Name:     "jinzhu",
 | 
			
		||||
		IgnoreMe: 88,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
 | 
			
		||||
 | 
			
		||||
	if user.IgnoreMe != 100 {
 | 
			
		||||
		t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,32 +0,0 @@
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestToDBNameGenerateFriendlyName(t *testing.T) {
 | 
			
		||||
	var maps = map[string]string{
 | 
			
		||||
		"":                          "",
 | 
			
		||||
		"X":                         "x",
 | 
			
		||||
		"ThisIsATest":               "this_is_a_test",
 | 
			
		||||
		"PFAndESI":                  "pf_and_esi",
 | 
			
		||||
		"AbcAndJkl":                 "abc_and_jkl",
 | 
			
		||||
		"EmployeeID":                "employee_id",
 | 
			
		||||
		"SKU_ID":                    "sku_id",
 | 
			
		||||
		"FieldX":                    "field_x",
 | 
			
		||||
		"HTTPAndSMTP":               "http_and_smtp",
 | 
			
		||||
		"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
 | 
			
		||||
		"UUID":     "uuid",
 | 
			
		||||
		"HTTPURL":  "http_url",
 | 
			
		||||
		"HTTP_URL": "http_url",
 | 
			
		||||
		"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for key, value := range maps {
 | 
			
		||||
		if gorm.ToDBName(key) != value {
 | 
			
		||||
			t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user