Add extract a interface based on gorm.DB methods

This commit is contained in:
Alvaro Viebrantz 2016-01-17 14:41:34 -03:00
parent 341d047aa7
commit cd40e372d3
30 changed files with 430 additions and 312 deletions

View File

@ -23,7 +23,7 @@ func (association *Association) setErr(err error) *Association {
func (association *Association) Find(value interface{}) *Association { func (association *Association) Find(value interface{}) *Association {
association.Scope.related(value, association.Column) association.Scope.related(value, association.Column)
return association.setErr(association.Scope.db.Error) return association.setErr(association.Scope.db.GetError())
} }
func (association *Association) saveAssociations(values ...interface{}) *Association { func (association *Association) saveAssociations(values ...interface{}) *Association {
@ -42,7 +42,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
// value has to been saved for many2many // value has to been saved for many2many
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
if scope.New(reflectValue.Interface()).PrimaryKeyZero() { if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) association.setErr(scope.NewDB().Save(reflectValue.Interface()).GetError())
} }
} }
@ -68,7 +68,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else { } else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).GetError())
if setFieldBackToValue { if setFieldBackToValue {
reflectValue.Elem().Set(field.Field) reflectValue.Elem().Set(field.Field)
@ -106,7 +106,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
relationship = association.Field.Relationship relationship = association.Field.Relationship
scope = association.Scope scope = association.Scope
field = association.Field.Field field = association.Field.Field
newDB = scope.NewDB() newDB Database = scope.NewDB()
) )
// Append new values // Append new values
@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
for _, foreignKey := range relationship.ForeignDBNames { for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil foreignKeyMap[foreignKey] = nil
} }
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).GetError())
} }
} else { } else {
// Relations // Relations
@ -173,7 +173,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
} }
fieldValue := reflect.New(association.Field.Field.Type()).Interface() fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError())
} }
} }
return association return association
@ -184,7 +184,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
relationship = association.Field.Relationship relationship = association.Field.Relationship
scope = association.Scope scope = association.Scope
field = association.Field.Field field = association.Field.Field
newDB = scope.NewDB() newDB Database = scope.NewDB()
) )
if len(values) == 0 { if len(values) == 0 {
@ -231,12 +231,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
// set foreign key to be null // set foreign key to be null
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.GetError() == nil {
if results.RowsAffected > 0 { if results.GetRowsAffected() > 0 {
scope.updatedAttrsWithValues(foreignKeyMap, false) scope.updatedAttrsWithValues(foreignKeyMap, false)
} }
} else { } else {
association.setErr(results.Error) association.setErr(results.GetError())
} }
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations // find all relations
@ -254,7 +254,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
// set matched relation's foreign key to be null // set matched relation's foreign key to be null
fieldValue := reflect.New(association.Field.Field.Type()).Interface() fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError())
} }
} }
@ -308,7 +308,7 @@ func (association *Association) Count() int {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := scope.DB() var query Database = scope.DB()
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
@ -321,7 +321,7 @@ func (association *Association) Count() int {
} }
query.Model(fieldValue).Count(&count) query.Model(fieldValue).Count(&count)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
query := scope.DB() var query Database = scope.DB()
for idx, primaryKey := range relationship.AssociationForeignDBNames { for idx, primaryKey := range relationship.AssociationForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),

View File

@ -15,7 +15,7 @@ func TestBelongsTo(t *testing.T) {
MainCategory: Category{Name: "Main Category 1"}, MainCategory: Category{Name: "Main Category 1"},
} }
if err := DB.Save(&post).Error; err != nil { if err := DB.Save(&post).GetError(); err != nil {
t.Errorf("Got errors when save post", err.Error()) t.Errorf("Got errors when save post", err.Error())
} }
@ -183,7 +183,7 @@ func TestHasOne(t *testing.T) {
CreditCard: CreditCard{Number: "411111111111"}, CreditCard: CreditCard{Number: "411111111111"},
} }
if err := DB.Save(&user).Error; err != nil { if err := DB.Save(&user).GetError(); err != nil {
t.Errorf("Got errors when save user", err.Error()) t.Errorf("Got errors when save user", err.Error())
} }
@ -330,7 +330,7 @@ func TestHasMany(t *testing.T) {
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
} }
if err := DB.Save(&post).Error; err != nil { if err := DB.Save(&post).GetError(); err != nil {
t.Errorf("Got errors when save post", err.Error()) t.Errorf("Got errors when save post", err.Error())
} }
@ -351,7 +351,7 @@ func TestHasMany(t *testing.T) {
} }
// Query // Query
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { if DB.First(&Comment{}, "content = ?", "Comment 1").GetError() != nil {
t.Errorf("Comment 1 should be saved") t.Errorf("Comment 1 should be saved")
} }

View File

@ -78,7 +78,8 @@ func Create(scope *Scope) {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId() id, err := result.LastInsertId()
if scope.Err(err) == nil { if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected() rowsAffected, _ := result.RowsAffected()
scope.db.SetRowsAffected(rowsAffected)
if primaryField != nil && primaryField.IsBlank { if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id)) scope.Err(scope.SetColumn(primaryField, id))
} }
@ -87,13 +88,14 @@ func Create(scope *Scope) {
} else { } else {
if primaryField == nil { if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected() rowsAffected, _ := results.RowsAffected()
scope.db.SetRowsAffected(rowsAffected)
} else { } else {
scope.Err(err) scope.Err(err)
} }
} else { } else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1 scope.db.SetRowsAffected(1)
} else { } else {
scope.Err(err) scope.Err(err)
} }

View File

@ -45,7 +45,7 @@ func Query(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0 scope.db.SetRowsAffected(0)
if scope.Err(err) != nil { if scope.Err(err) != nil {
return return
@ -54,7 +54,8 @@ func Query(scope *Scope) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
for rows.Next() { for rows.Next() {
scope.db.RowsAffected++ rowsAffected := scope.db.GetRowsAffected()+1
scope.db.SetRowsAffected(rowsAffected)
anyRecordFound = true anyRecordFound = true
elem := dest elem := dest

View File

@ -18,7 +18,7 @@ func SaveBeforeAssociations(scope *Scope) {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) scope.Err(scope.NewDB().Save(value.Addr().Interface()).GetError())
if len(relationship.ForeignFieldNames) != 0 { if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames { for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx] associationForeignName := relationship.AssociationForeignDBNames[idx]
@ -62,7 +62,7 @@ func SaveAfterAssociations(scope *Scope) {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
} }
scope.Err(newDB.Save(elem).Error) scope.Err(newDB.Save(elem).GetError())
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
@ -83,7 +83,7 @@ func SaveAfterAssociations(scope *Scope) {
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
} }
scope.Err(scope.NewDB().Save(elem).Error) scope.Err(scope.NewDB().Save(elem).GetError())
} }
} }
} }

View File

@ -108,22 +108,22 @@ func TestRunCallbacks(t *testing.T) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
} }
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { if DB.Where("Code = ?", "unique_code").First(&p).GetError() == nil {
t.Errorf("Can't find a deleted record") t.Errorf("Can't find a deleted record")
} }
} }
func TestCallbacksWithErrors(t *testing.T) { func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100} p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil { if DB.Save(&p).GetError() == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value") t.Errorf("An error from before create callbacks happened when create with invalid value")
} }
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { if DB.Where("code = ?", "Invalid").First(&Product{}).GetError() == nil {
t.Errorf("Should not save record that have errors") t.Errorf("Should not save record that have errors")
} }
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { if DB.Save(&Product{Code: "dont_save", Price: 100}).GetError() == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value") t.Errorf("An error from after create callbacks happened when create with invalid value")
} }
@ -131,47 +131,47 @@ func TestCallbacksWithErrors(t *testing.T) {
DB.Save(&p2) DB.Save(&p2)
p2.Code = "dont_update" p2.Code = "dont_update"
if DB.Save(&p2).Error == nil { if DB.Save(&p2).GetError() == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value") t.Errorf("An error from before update callbacks happened when update with invalid value")
} }
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { if DB.Where("code = ?", "update_callback").First(&Product{}).GetError() != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback") t.Errorf("Record Should not be updated due to errors happened in before update callback")
} }
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { if DB.Where("code = ?", "dont_update").First(&Product{}).GetError() == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback") t.Errorf("Record Should not be updated due to errors happened in before update callback")
} }
p2.Code = "dont_save" p2.Code = "dont_save"
if DB.Save(&p2).Error == nil { if DB.Save(&p2).GetError() == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value") t.Errorf("An error from before save callbacks happened when update with invalid value")
} }
p3 := Product{Code: "dont_delete", Price: 100} p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3) DB.Save(&p3)
if DB.Delete(&p3).Error == nil { if DB.Delete(&p3).GetError() == nil {
t.Errorf("An error from before delete callbacks happened when delete") t.Errorf("An error from before delete callbacks happened when delete")
} }
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { if DB.Where("Code = ?", "dont_delete").First(&p3).GetError() != nil {
t.Errorf("An error from before delete callbacks happened") t.Errorf("An error from before delete callbacks happened")
} }
p4 := Product{Code: "after_save_error", Price: 100} p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4) DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { if err := DB.First(&Product{}, "code = ?", "after_save_error").GetError(); err == nil {
t.Errorf("Record should be reverted if get an error in after save callback") t.Errorf("Record should be reverted if get an error in after save callback")
} }
p5 := Product{Code: "after_delete_error", Price: 100} p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5) DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil {
t.Errorf("Record should be found") t.Errorf("Record should be found")
} }
DB.Delete(&p5) DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
} }
} }

View File

@ -96,7 +96,7 @@ func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string
} }
func (commonDialect) RemoveIndex(scope *Scope, indexName string) { func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).GetError())
} }
// RawScanInt scans the first column of the first row into the `scan' int pointer. // RawScanInt scans the first column of the first row into the `scan' int pointer.

View File

@ -15,7 +15,7 @@ func TestCreate(t *testing.T) {
t.Error("User should be new record before create") t.Error("User should be new record before create")
} }
if count := DB.Save(&user).RowsAffected; count != 1 { if count := DB.Save(&user).GetRowsAffected(); count != 1 {
t.Error("There should be one record be affected when create record") t.Error("There should be one record be affected when create record")
} }
@ -63,7 +63,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) {
} }
jt := JoinTable{From: 1, To: 2} jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error err := DB.Create(&jt).GetError()
if err != nil { if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
} }
@ -71,7 +71,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) {
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"} animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil { if DB.Save(&animal).GetError() != nil {
t.Errorf("No error should happen when create a record without std primary key") t.Errorf("No error should happen when create a record without std primary key")
} }
@ -86,7 +86,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
// Test create with default value not overrided // Test create with default value not overrided
an := Animal{From: "nerdz"} an := Animal{From: "nerdz"}
if DB.Save(&an).Error != nil { if DB.Save(&an).GetError() != nil {
t.Errorf("No error should happen when create an record without std primary key") t.Errorf("No error should happen when create an record without std primary key")
} }

View File

@ -38,7 +38,7 @@ func TestCustomizeColumn(t *testing.T) {
expected := "foo" expected := "foo"
cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
if count := DB.Create(&cc).RowsAffected; count != 1 { if count := DB.Create(&cc).GetRowsAffected(); count != 1 {
t.Error("There should be one record be affected when create record") t.Error("There should be one record be affected when create record")
} }
@ -61,7 +61,7 @@ func TestCustomizeColumn(t *testing.T) {
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).GetError(); err != nil {
t.Errorf("Should not raise error: %s", err) t.Errorf("Should not raise error: %s", err)
} }
} }
@ -86,17 +86,17 @@ func TestManyToManyWithCustomizedColumn(t *testing.T) {
Accounts: []CustomizeAccount{account}, Accounts: []CustomizeAccount{account},
} }
if err := DB.Create(&account).Error; err != nil { if err := DB.Create(&account).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
if err := DB.Create(&person).Error; err != nil { if err := DB.Create(&person).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
var person1 CustomizePerson var person1 CustomizePerson
scope := DB.NewScope(nil) scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).GetError(); err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
} }
@ -131,7 +131,7 @@ func TestOneToOneWithCustomizedColumn(t *testing.T) {
DB.Create(&invitation) DB.Create(&invitation)
var invitation2 CustomizeInvitation var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { if err := DB.Preload("Person").Find(&invitation2, invitation.ID).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err) t.Errorf("no error should happen, but got %v", err)
} }
@ -183,12 +183,12 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) {
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
@ -197,7 +197,7 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) {
} }
var coupon PromotionCoupon var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
@ -221,12 +221,12 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
@ -235,7 +235,7 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
} }
var rule PromotionRule var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
@ -256,12 +256,12 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
}, },
} }
if err := DB.Create(&discount).Error; err != nil { if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
var discount1 PromotionDiscount var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }
@ -270,7 +270,7 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
} }
var benefit PromotionBenefit var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err) t.Errorf("no error should happen but got %v", err)
} }

View File

@ -18,7 +18,7 @@ func TestDdlErrors(t *testing.T) {
}() }()
DB.HasTable("foobarbaz") DB.HasTable("foobarbaz")
if DB.Error == nil { if DB.GetError() == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil") t.Errorf("Expected operation on closed db to produce an error, but err was nil")
} }
} }

View File

@ -10,7 +10,7 @@ func TestDelete(t *testing.T) {
DB.Save(&user1) DB.Save(&user1)
DB.Save(&user2) DB.Save(&user2)
if err := DB.Delete(&user1).Error; err != nil { if err := DB.Delete(&user1).GetError(); err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err) t.Errorf("No error should happen when delete a record, err=%s", err)
} }
@ -28,13 +28,13 @@ func TestInlineDelete(t *testing.T) {
DB.Save(&user1) DB.Save(&user1)
DB.Save(&user2) DB.Save(&user2)
if DB.Delete(&User{}, user1.Id).Error != nil { if DB.Delete(&User{}, user1.Id).GetError() != nil {
t.Errorf("No error should happen when delete a record") t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete") t.Errorf("User can't be found after delete")
} }
if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { if err := DB.Delete(&User{}, "name = ?", user2.Name).GetError(); err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err) t.Errorf("No error should happen when delete a record, err=%s", err)
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete") t.Errorf("User can't be found after delete")
@ -53,11 +53,11 @@ func TestSoftDelete(t *testing.T) {
DB.Save(&user) DB.Save(&user)
DB.Delete(&user) DB.Delete(&user)
if DB.First(&User{}, "name = ?", user.Name).Error == nil { if DB.First(&User{}, "name = ?", user.Name).GetError() == nil {
t.Errorf("Can't find a soft deleted record") t.Errorf("Can't find a soft deleted record")
} }
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).GetError(); err != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
} }

View File

@ -22,7 +22,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { if err := DB.First(&news, "title = ?", "hn_news").GetError(); err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" { } else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")
@ -30,7 +30,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { if err := DB.First(&egNews, "title = ?", "engadget_news").GetError(); err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err) t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" { } else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly") t.Errorf("embedded struct's value should be scanned correctly")

View File

@ -1,6 +1,8 @@
package gorm package gorm
import "database/sql" import (
"database/sql"
)
type sqlCommon interface { type sqlCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
@ -17,3 +19,95 @@ type sqlTx interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
type Database interface {
Close() error
DB() *sql.DB
New() Database
NewScope(value interface{}) *Scope
CommonDB() sqlCommon
Callback() *callback
SetLogger(l logger)
LogMode(enable bool) Database
SingularTable(enable bool)
Where(query interface{}, args ...interface{}) Database
Or(query interface{}, args ...interface{}) Database
Not(query interface{}, args ...interface{}) Database
Limit(value interface{}) Database
Offset(value interface{}) Database
Order(value string, reorder ...bool) Database
Select(query interface{}, args ...interface{}) Database
Omit(columns ...string) Database
Group(query string) Database
Having(query string, values ...interface{}) Database
Joins(query string) Database
//Scopes(funcs ...func(Database) Database) Database
Scopes(funcs ...func(*DB) *DB) *DB
Unscoped() Database
Attrs(attrs ...interface{}) Database
Assign(attrs ...interface{}) Database
First(out interface{}, where ...interface{}) Database
Last(out interface{}, where ...interface{}) Database
Find(out interface{}, where ...interface{}) Database
Scan(dest interface{}) Database
Row() *sql.Row
Rows() (*sql.Rows, error)
Pluck(column string, value interface{}) Database
Count(value interface{}) Database
Related(value interface{}, foreignKeys ...string) Database
FirstOrInit(out interface{}, where ...interface{}) Database
FirstOrCreate(out interface{}, where ...interface{}) Database
Update(attrs ...interface{}) Database
Updates(values interface{}, ignoreProtectedAttrs ...bool) Database
UpdateColumn(attrs ...interface{}) Database
UpdateColumns(values interface{}) Database
Save(value interface{}) Database
Create(value interface{}) Database
Delete(value interface{}, where ...interface{}) Database
Raw(sql string, values ...interface{}) Database
Exec(sql string, values ...interface{}) Database
Model(value interface{}) Database
Table(name string) Database
Debug() Database
Begin() Database
Commit() Database
Rollback() Database
NewRecord(value interface{}) bool
RecordNotFound() bool
CreateTable(values ...interface{}) Database
DropTable(values ...interface{}) Database
DropTableIfExists(values ...interface{}) Database
HasTable(value interface{}) bool
AutoMigrate(values ...interface{}) Database
ModifyColumn(column string, typ string) Database
DropColumn(column string) Database
AddIndex(indexName string, column ...string) Database
AddUniqueIndex(indexName string, column ...string) Database
RemoveIndex(indexName string) Database
CurrentDatabase() string
AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database
Association(column string) *Association
Preload(column string, conditions ...interface{}) Database
Set(name string, value interface{}) Database
InstantSet(name string, value interface{}) Database
Get(name string) (value interface{}, ok bool)
SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface)
AddError(err error) error
GetError() error
GetErrors() (errors []error)
GetRowsAffected() int64
SetRowsAffected(num int64)
}

View File

@ -9,10 +9,10 @@ import (
type JoinTableHandlerInterface interface { type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string Table(db Database) string
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Add(handler JoinTableHandlerInterface, db Database, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database
SourceForeignKeys() []JoinTableForeignKey SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey
} }
@ -61,11 +61,11 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
} }
} }
func (s JoinTableHandler) Table(db *DB) string { func (s JoinTableHandler) Table(db Database) string {
return s.TableName return s.TableName
} }
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { func (s JoinTableHandler) GetSearchMap(db Database, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{} values := map[string]interface{}{}
for _, source := range sources { for _, source := range sources {
@ -85,7 +85,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
return values return values
} }
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db Database, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("") scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2) searchMap := s.GetSearchMap(db, source1, source2)
@ -113,10 +113,10 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
strings.Join(conditions, " AND "), strings.Join(conditions, " AND "),
) )
return db.Exec(sql, values...).Error return db.Exec(sql, values...).GetError()
} }
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error {
var ( var (
scope = db.NewScope(nil) scope = db.NewScope(nil)
conditions []string conditions []string
@ -128,10 +128,10 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
values = append(values, value) values = append(values, value)
} }
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").GetError()
} }
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database {
var ( var (
scope = db.NewScope(source) scope = db.NewScope(source)
tableName = handler.Table(db) tableName = handler.Table(db)
@ -174,7 +174,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...) Where(condString, toQueryValues(foreignFieldValues)...)
} else { } else {
db.Error = errors.New("wrong source type for join table handler") db.AddError(errors.New("wrong source type for join table handler"))
return db return db
} }
} }

View File

@ -22,7 +22,7 @@ type PersonAddress struct {
CreatedAt time.Time CreatedAt time.Time
} }
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db gorm.Database, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{ return db.Where(map[string]interface{}{
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(), "person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
"address_id": db.NewScope(associationValue).PrimaryKeyValue(), "address_id": db.NewScope(associationValue).PrimaryKeyValue(),
@ -30,14 +30,14 @@ func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, f
"person_id": foreignValue, "person_id": foreignValue,
"address_id": associationValue, "address_id": associationValue,
"deleted_at": gorm.Expr("NULL"), "deleted_at": gorm.Expr("NULL"),
}).FirstOrCreate(&PersonAddress{}).Error }).FirstOrCreate(&PersonAddress{}).GetError()
} }
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db gorm.Database, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error return db.Delete(&PersonAddress{}).GetError()
} }
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db gorm.Database, source interface{}) gorm.Database {
table := pa.Table(db) table := pa.Table(db)
return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
} }
@ -54,7 +54,7 @@ func TestJoinTable(t *testing.T) {
DB.Model(person).Association("Addresses").Delete(address1) DB.Model(person).Association("Addresses").Delete(address1)
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 1 {
t.Errorf("Should found one address") t.Errorf("Should found one address")
} }
@ -62,7 +62,7 @@ func TestJoinTable(t *testing.T) {
t.Errorf("Should found one address") t.Errorf("Should found one address")
} }
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 2 {
t.Errorf("Found two addresses with Unscoped") t.Errorf("Found two addresses with Unscoped")
} }

140
main.go
View File

@ -90,7 +90,7 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB) return s.db.(*sql.DB)
} }
func (s *DB) New() *DB { func (s *DB) New() Database {
clone := s.clone() clone := s.clone()
clone.search = nil clone.search = nil
clone.Value = nil clone.Value = nil
@ -120,7 +120,7 @@ func (s *DB) SetLogger(l logger) {
s.logger = l s.logger = l
} }
func (s *DB) LogMode(enable bool) *DB { func (s *DB) LogMode(enable bool) Database {
if enable { if enable {
s.logMode = 2 s.logMode = 2
} else { } else {
@ -134,47 +134,47 @@ func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = enable s.parent.singularTable = enable
} }
func (s *DB) Where(query interface{}, args ...interface{}) *DB { func (s *DB) Where(query interface{}, args ...interface{}) Database {
return s.clone().search.Where(query, args...).db return s.clone().search.Where(query, args...).db
} }
func (s *DB) Or(query interface{}, args ...interface{}) *DB { func (s *DB) Or(query interface{}, args ...interface{}) Database {
return s.clone().search.Or(query, args...).db return s.clone().search.Or(query, args...).db
} }
func (s *DB) Not(query interface{}, args ...interface{}) *DB { func (s *DB) Not(query interface{}, args ...interface{}) Database {
return s.clone().search.Not(query, args...).db return s.clone().search.Not(query, args...).db
} }
func (s *DB) Limit(value interface{}) *DB { func (s *DB) Limit(value interface{}) Database {
return s.clone().search.Limit(value).db return s.clone().search.Limit(value).db
} }
func (s *DB) Offset(value interface{}) *DB { func (s *DB) Offset(value interface{}) Database {
return s.clone().search.Offset(value).db return s.clone().search.Offset(value).db
} }
func (s *DB) Order(value string, reorder ...bool) *DB { func (s *DB) Order(value string, reorder ...bool) Database {
return s.clone().search.Order(value, reorder...).db return s.clone().search.Order(value, reorder...).db
} }
func (s *DB) Select(query interface{}, args ...interface{}) *DB { func (s *DB) Select(query interface{}, args ...interface{}) Database {
return s.clone().search.Select(query, args...).db return s.clone().search.Select(query, args...).db
} }
func (s *DB) Omit(columns ...string) *DB { func (s *DB) Omit(columns ...string) Database {
return s.clone().search.Omit(columns...).db return s.clone().search.Omit(columns...).db
} }
func (s *DB) Group(query string) *DB { func (s *DB) Group(query string) Database {
return s.clone().search.Group(query).db return s.clone().search.Group(query).db
} }
func (s *DB) Having(query string, values ...interface{}) *DB { func (s *DB) Having(query string, values ...interface{}) Database {
return s.clone().search.Having(query, values...).db return s.clone().search.Having(query, values...).db
} }
func (s *DB) Joins(query string) *DB { func (s *DB) Joins(query string) Database {
return s.clone().search.Joins(query).db return s.clone().search.Joins(query).db
} }
@ -185,37 +185,37 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
return s return s
} }
func (s *DB) Unscoped() *DB { func (s *DB) Unscoped() Database {
return s.clone().search.unscoped().db return s.clone().search.unscoped().db
} }
func (s *DB) Attrs(attrs ...interface{}) *DB { func (s *DB) Attrs(attrs ...interface{}) Database {
return s.clone().search.Attrs(attrs...).db return s.clone().search.Attrs(attrs...).db
} }
func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) Assign(attrs ...interface{}) Database {
return s.clone().search.Assign(attrs...).db return s.clone().search.Assign(attrs...).db
} }
func (s *DB) First(out interface{}, where ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) Database {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC"). return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
} }
func (s *DB) Last(out interface{}, where ...interface{}) *DB { func (s *DB) Last(out interface{}, where ...interface{}) Database {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search.Limit(1) newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC"). return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
} }
func (s *DB) Find(out interface{}, where ...interface{}) *DB { func (s *DB) Find(out interface{}, where ...interface{}) Database {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
} }
func (s *DB) Scan(dest interface{}) *DB { func (s *DB) Scan(dest interface{}) Database {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
} }
@ -227,21 +227,21 @@ func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows() return s.NewScope(s.Value).rows()
} }
func (s *DB) Pluck(column string, value interface{}) *DB { func (s *DB) Pluck(column string, value interface{}) Database {
return s.NewScope(s.Value).pluck(column, value).db return s.NewScope(s.Value).pluck(column, value).db
} }
func (s *DB) Count(value interface{}) *DB { func (s *DB) Count(value interface{}) Database {
return s.NewScope(s.Value).count(value).db return s.NewScope(s.Value).count(value).db
} }
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { func (s *DB) Related(value interface{}, foreignKeys ...string) Database {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
} }
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) Database {
c := s.clone() c := s.clone()
if result := c.First(out, where...); result.Error != nil { if result := c.First(out, where...); result.GetError() != nil {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
@ -252,35 +252,35 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
return c return c
} }
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) Database {
c := s.clone() c := s.clone()
if result := c.First(out, where...); result.Error != nil { if result := c.First(out, where...); result.GetError() != nil {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.GetError())
} else if len(c.search.assignAttrs) > 0 { } else if len(c.search.assignAttrs) > 0 {
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.GetError())
} }
return c return c
} }
func (s *DB) Update(attrs ...interface{}) *DB { func (s *DB) Update(attrs ...interface{}) Database {
return s.Updates(toSearchableMap(attrs...), true) return s.Updates(toSearchableMap(attrs...), true)
} }
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) Database {
return s.clone().NewScope(s.Value). return s.clone().NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values). InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callback.updates).db
} }
func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumn(attrs ...interface{}) Database {
return s.UpdateColumns(toSearchableMap(attrs...)) return s.UpdateColumns(toSearchableMap(attrs...))
} }
func (s *DB) UpdateColumns(values interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) Database {
return s.clone().NewScope(s.Value). return s.clone().NewScope(s.Value).
Set("gorm:update_column", true). Set("gorm:update_column", true).
Set("gorm:save_associations", false). Set("gorm:save_associations", false).
@ -288,7 +288,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB {
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callback.updates).db
} }
func (s *DB) Save(value interface{}) *DB { func (s *DB) Save(value interface{}) Database {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() { if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db return scope.callCallbacks(s.parent.callback.creates).db
@ -296,20 +296,20 @@ func (s *DB) Save(value interface{}) *DB {
return scope.callCallbacks(s.parent.callback.updates).db return scope.callCallbacks(s.parent.callback.updates).db
} }
func (s *DB) Create(value interface{}) *DB { func (s *DB) Create(value interface{}) Database {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
return scope.callCallbacks(s.parent.callback.creates).db return scope.callCallbacks(s.parent.callback.creates).db
} }
func (s *DB) Delete(value interface{}, where ...interface{}) *DB { func (s *DB) Delete(value interface{}, where ...interface{}) Database {
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
} }
func (s *DB) Raw(sql string, values ...interface{}) *DB { func (s *DB) Raw(sql string, values ...interface{}) Database {
return s.clone().search.Raw(true).Where(sql, values...).db return s.clone().search.Raw(true).Where(sql, values...).db
} }
func (s *DB) Exec(sql string, values ...interface{}) *DB { func (s *DB) Exec(sql string, values ...interface{}) Database {
scope := s.clone().NewScope(nil) scope := s.clone().NewScope(nil)
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
@ -317,24 +317,24 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
return scope.Exec().db return scope.Exec().db
} }
func (s *DB) Model(value interface{}) *DB { func (s *DB) Model(value interface{}) Database {
c := s.clone() c := s.clone()
c.Value = value c.Value = value
return c return c
} }
func (s *DB) Table(name string) *DB { func (s *DB) Table(name string) Database {
clone := s.clone() clone := s.clone()
clone.search.Table(name) clone.search.Table(name)
clone.Value = nil clone.Value = nil
return clone return clone
} }
func (s *DB) Debug() *DB { func (s *DB) Debug() Database {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
func (s *DB) Begin() *DB { func (s *DB) Begin() Database {
c := s.clone() c := s.clone()
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin() tx, err := db.Begin()
@ -346,7 +346,7 @@ func (s *DB) Begin() *DB {
return c return c
} }
func (s *DB) Commit() *DB { func (s *DB) Commit() Database {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Commit()) s.AddError(db.Commit())
} else { } else {
@ -355,7 +355,7 @@ func (s *DB) Commit() *DB {
return s return s
} }
func (s *DB) Rollback() *DB { func (s *DB) Rollback() Database {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Rollback()) s.AddError(db.Rollback())
} else { } else {
@ -373,16 +373,18 @@ func (s *DB) RecordNotFound() bool {
} }
// Migrations // Migrations
func (s *DB) CreateTable(values ...interface{}) *DB { func (s *DB) CreateTable(values ...interface{}) Database {
db := s.clone() var db Database
db = s.clone()
for _, value := range values { for _, value := range values {
db = db.NewScope(value).createTable().db db = db.NewScope(value).createTable().db
} }
return db return db
} }
func (s *DB) DropTable(values ...interface{}) *DB { func (s *DB) DropTable(values ...interface{}) Database {
db := s.clone() var db Database
db = s.clone()
for _, value := range values { for _, value := range values {
if tableName, ok := value.(string); ok { if tableName, ok := value.(string); ok {
db = db.Table(tableName) db = db.Table(tableName)
@ -393,8 +395,9 @@ func (s *DB) DropTable(values ...interface{}) *DB {
return db return db
} }
func (s *DB) DropTableIfExists(values ...interface{}) *DB { func (s *DB) DropTableIfExists(values ...interface{}) Database {
db := s.clone() var db Database
db = s.clone()
for _, value := range values { for _, value := range values {
if tableName, ok := value.(string); ok { if tableName, ok := value.(string); ok {
db = db.Table(tableName) db = db.Table(tableName)
@ -409,43 +412,44 @@ func (s *DB) HasTable(value interface{}) bool {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
tableName := scope.TableName() tableName := scope.TableName()
has := scope.Dialect().HasTable(scope, tableName) has := scope.Dialect().HasTable(scope, tableName)
s.AddError(scope.db.Error) s.AddError(scope.db.GetError())
return has return has
} }
func (s *DB) AutoMigrate(values ...interface{}) *DB { func (s *DB) AutoMigrate(values ...interface{}) Database {
db := s.clone() var db Database
db = s.clone()
for _, value := range values { for _, value := range values {
db = db.NewScope(value).NeedPtr().autoMigrate().db db = db.NewScope(value).NeedPtr().autoMigrate().db
} }
return db return db
} }
func (s *DB) ModifyColumn(column string, typ string) *DB { func (s *DB) ModifyColumn(column string, typ string) Database {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.modifyColumn(column, typ) scope.modifyColumn(column, typ)
return scope.db return scope.db
} }
func (s *DB) DropColumn(column string) *DB { func (s *DB) DropColumn(column string) Database {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.dropColumn(column) scope.dropColumn(column)
return scope.db return scope.db
} }
func (s *DB) AddIndex(indexName string, column ...string) *DB { func (s *DB) AddIndex(indexName string, column ...string) Database {
scope := s.Unscoped().NewScope(s.Value) scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, column...) scope.addIndex(false, indexName, column...)
return scope.db return scope.db
} }
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { func (s *DB) AddUniqueIndex(indexName string, column ...string) Database {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.addIndex(true, indexName, column...) scope.addIndex(true, indexName, column...)
return scope.db return scope.db
} }
func (s *DB) RemoveIndex(indexName string) *DB { func (s *DB) RemoveIndex(indexName string) Database {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.removeIndex(indexName) scope.removeIndex(indexName)
return scope.db return scope.db
@ -465,7 +469,7 @@ Add foreign key to the given scope
Example: Example:
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
*/ */
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database {
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate) scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db return scope.db
@ -492,16 +496,16 @@ func (s *DB) Association(column string) *Association {
return &Association{Error: err} return &Association{Error: err}
} }
func (s *DB) Preload(column string, conditions ...interface{}) *DB { func (s *DB) Preload(column string, conditions ...interface{}) Database {
return s.clone().search.Preload(column, conditions...).db return s.clone().search.Preload(column, conditions...).db
} }
// Set set value by name // Set set value by name
func (s *DB) Set(name string, value interface{}) *DB { func (s *DB) Set(name string, value interface{}) Database {
return s.clone().InstantSet(name, value) return s.clone().InstantSet(name, value)
} }
func (s *DB) InstantSet(name string, value interface{}) *DB { func (s *DB) InstantSet(name string, value interface{}) Database {
s.values[name] = value s.values[name] = value
return s return s
} }
@ -550,6 +554,18 @@ func (s *DB) AddError(err error) error {
return err return err
} }
func (s *DB) GetError() error {
return s.Error
}
func (s *DB) SetRowsAffected(num int64) {
s.RowsAffected = num
}
func (s *DB) GetRowsAffected() int64 {
return s.RowsAffected
}
func (s *DB) GetErrors() (errors []error) { func (s *DB) GetErrors() (errors []error) {
if errs, ok := s.Error.(errorsInterface); ok { if errs, ok := s.Error.(errorsInterface); ok {
return errs.GetErrors() return errs.GetErrors()

View File

@ -20,7 +20,7 @@ import (
) )
var ( var (
DB gorm.DB DB gorm.Database
t1, t2, t3, t4, t5 time.Time t1, t2, t3, t4, t5 time.Time
) )
@ -42,7 +42,11 @@ func init() {
runMigration() runMigration()
} }
func OpenTestConnection() (db gorm.DB, err error) { func OpenTestConnection() (*gorm.DB, error) {
var (
db gorm.DB
err error
)
switch os.Getenv("GORM_DIALECT") { switch os.Getenv("GORM_DIALECT") {
case "mysql": case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
@ -63,7 +67,8 @@ func OpenTestConnection() (db gorm.DB, err error) {
fmt.Println("testing sqlite3...") fmt.Println("testing sqlite3...")
db, err = gorm.Open("sqlite3", "/tmp/gorm.db") db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
} }
return
return &db, err
} }
func TestStringPrimaryKey(t *testing.T) { func TestStringPrimaryKey(t *testing.T) {
@ -74,22 +79,22 @@ func TestStringPrimaryKey(t *testing.T) {
DB.AutoMigrate(&UUIDStruct{}) DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"} data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { if err := DB.Save(&data).GetError(); err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated") t.Errorf("string primary key should not be populated")
} }
} }
func TestExceptionsWithInvalidSql(t *testing.T) { func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil {
t.Errorf("Should got error with invalid SQL") t.Errorf("Should got error with invalid SQL")
} }
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil {
t.Errorf("Should got error with invalid SQL") t.Errorf("Should got error with invalid SQL")
} }
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).GetError() == nil {
t.Errorf("Should got error with invalid SQL") t.Errorf("Should got error with invalid SQL")
} }
@ -99,7 +104,7 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
t.Errorf("Should find some users") t.Errorf("Should find some users")
} }
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).GetError() == nil {
t.Errorf("Should got error with invalid SQL") t.Errorf("Should got error with invalid SQL")
} }
@ -114,21 +119,21 @@ func TestSetTable(t *testing.T) {
DB.Create(getPreparedUser("pluck_user2", "pluck_user")) DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
DB.Create(getPreparedUser("pluck_user3", "pluck_user")) DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).GetError(); err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error()) t.Errorf("No errors should happen if set table for pluck", err.Error())
} }
var users []User var users []User
if DB.Table("users").Find(&[]User{}).Error != nil { if DB.Table("users").Find(&[]User{}).GetError() != nil {
t.Errorf("No errors should happen if set table for find") t.Errorf("No errors should happen if set table for find")
} }
if DB.Table("invalid_table").Find(&users).Error == nil { if DB.Table("invalid_table").Find(&users).GetError() == nil {
t.Errorf("Should got error when table is set to an invalid table") t.Errorf("Should got error when table is set to an invalid table")
} }
DB.Exec("drop table deleted_users;") DB.Exec("drop table deleted_users;")
if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { if DB.Table("deleted_users").CreateTable(&User{}).GetError() != nil {
t.Errorf("Create table with specified table") t.Errorf("Create table with specified table")
} }
@ -168,7 +173,7 @@ func TestHasTable(t *testing.T) {
if ok := DB.HasTable(&Foo{}); ok { if ok := DB.HasTable(&Foo{}); ok {
t.Errorf("Table should not exist, but does") t.Errorf("Table should not exist, but does")
} }
if err := DB.CreateTable(&Foo{}).Error; err != nil { if err := DB.CreateTable(&Foo{}).GetError(); err != nil {
t.Errorf("Table should be created") t.Errorf("Table should be created")
} }
if ok := DB.HasTable(&Foo{}); !ok { if ok := DB.HasTable(&Foo{}); !ok {
@ -240,7 +245,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: true}, AddedAt: NullTime{Time: time.Now(), Valid: true},
}).Error; err != nil { }).GetError(); err != nil {
t.Errorf("Not error should raise when test null value") t.Errorf("Not error should raise when test null value")
} }
@ -258,7 +263,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false}, AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err != nil { }).GetError(); err != nil {
t.Errorf("Not error should raise when test null value") t.Errorf("Not error should raise when test null value")
} }
@ -275,7 +280,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false}, AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err == nil { }).GetError(); err == nil {
t.Errorf("Can't save because of name can't be null") t.Errorf("Can't save because of name can't be null")
} }
} }
@ -287,7 +292,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
} }
var nv2 NullValue var nv2 NullValue
if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { if err := DB.Where(nv1).FirstOrCreate(&nv2).GetError(); err != nil {
t.Errorf("Should not raise any error, but got %v", err) t.Errorf("Should not raise any error, but got %v", err)
} }
@ -295,7 +300,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
t.Errorf("first or create with nullvalues") t.Errorf("first or create with nullvalues")
} }
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).GetError(); err != nil {
t.Errorf("Should not raise any error, but got %v", err) t.Errorf("Should not raise any error, but got %v", err)
} }
@ -307,11 +312,11 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
tx := DB.Begin() tx := DB.Begin()
u := User{Name: "transcation"} u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil { if err := tx.Save(&u).GetError(); err != nil {
t.Errorf("No error should raise") t.Errorf("No error should raise")
} }
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err != nil {
t.Errorf("Should find saved record") t.Errorf("Should find saved record")
} }
@ -321,23 +326,23 @@ func TestTransaction(t *testing.T) {
tx.Rollback() tx.Rollback()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err == nil {
t.Errorf("Should not find record after rollback") t.Errorf("Should not find record after rollback")
} }
tx2 := DB.Begin() tx2 := DB.Begin()
u2 := User{Name: "transcation-2"} u2 := User{Name: "transcation-2"}
if err := tx2.Save(&u2).Error; err != nil { if err := tx2.Save(&u2).GetError(); err != nil {
t.Errorf("No error should raise") t.Errorf("No error should raise")
} }
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := tx2.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil {
t.Errorf("Should find saved record") t.Errorf("Should find saved record")
} }
tx2.Commit() tx2.Commit()
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { if err := DB.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil {
t.Errorf("Should be able to find committed record") t.Errorf("Should be able to find committed record")
} }
} }
@ -436,7 +441,7 @@ func TestRaw(t *testing.T) {
} }
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).GetError() != gorm.RecordNotFound {
t.Error("Raw sql to update records") t.Error("Raw sql to update records")
} }
} }
@ -568,14 +573,14 @@ func TestHstore(t *testing.T) {
t.Skip() t.Skip()
} }
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").GetError(); err != nil {
fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
} }
DB.Exec("drop table details") DB.Exec("drop table details")
if err := DB.CreateTable(&Details{}).Error; err != nil { if err := DB.CreateTable(&Details{}).GetError(); err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
@ -589,7 +594,7 @@ func TestHstore(t *testing.T) {
DB.Save(&d) DB.Save(&d)
var d2 Details var d2 Details
if err := DB.First(&d2).Error; err != nil { if err := DB.First(&d2).GetError(); err != nil {
t.Errorf("Got error when tried to fetch details: %+v", err) t.Errorf("Got error when tried to fetch details: %+v", err)
} }
@ -647,7 +652,7 @@ func TestOpenExistingDB(t *testing.T) {
} }
var user User var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { if db.Where("name = ?", "jnfeinstein").First(&user).GetError() == gorm.RecordNotFound {
t.Errorf("Should have found existing record") t.Errorf("Should have found existing record")
} }
} }

View File

@ -7,7 +7,7 @@ import (
) )
func runMigration() { func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil { if err := DB.DropTableIfExists(&User{}).GetError(); err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err) fmt.Printf("Got error when try to delete table users, %+v\n", err)
} }
@ -20,13 +20,13 @@ func runMigration() {
DB.DropTable(value) DB.DropTable(value)
} }
if err := DB.AutoMigrate(values...).Error; err != nil { if err := DB.AutoMigrate(values...).GetError(); err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }
} }
func TestIndexes(t *testing.T) { func TestIndexes(t *testing.T) {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
@ -35,7 +35,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email") t.Errorf("Email should have index idx_email_email")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
@ -43,7 +43,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email should be deleted") t.Errorf("Email's index idx_email_email should be deleted")
} }
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
@ -51,7 +51,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
@ -59,7 +59,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err) t.Errorf("Got error when tried to create index: %+v", err)
} }
@ -67,7 +67,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email_and_user_id") t.Errorf("Email should have index idx_email_email_and_user_id")
} }
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() == nil {
t.Errorf("Should get to create duplicate record when having unique index") t.Errorf("Should get to create duplicate record when having unique index")
} }
@ -81,7 +81,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
} }
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err) t.Errorf("Got error when tried to remove index: %+v", err)
} }
@ -89,7 +89,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted") t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
} }
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index") t.Errorf("Should be able to create duplicated emails after remove unique index")
} }
} }
@ -110,7 +110,7 @@ func (b BigEmail) TableName() string {
func TestAutoMigration(t *testing.T) { func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{}) DB.AutoMigrate(&Address{})
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { if err := DB.Table("emails").AutoMigrate(&BigEmail{}).GetError(); err != nil {
t.Errorf("Auto Migrate should not raise any error") t.Errorf("Auto Migrate should not raise any error")
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
) )
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { var DefaultTableNameHandler = func(db Database, defaultTableName string) string {
return defaultTableName return defaultTableName
} }
@ -48,7 +48,7 @@ type ModelStruct struct {
defaultTableName string defaultTableName string
} }
func (s *ModelStruct) TableName(db *DB) string { func (s *ModelStruct) TableName(db Database) string {
return DefaultTableNameHandler(db, s.defaultTableName) return DefaultTableNameHandler(db, s.defaultTableName)
} }

View File

@ -20,12 +20,12 @@ func TestPointerFields(t *testing.T) {
var name = "pointer struct 1" var name = "pointer struct 1"
var num = 100 var num = 100
pointerStruct := PointerStruct{Name: &name, Num: &num} pointerStruct := PointerStruct{Name: &name, Num: &num}
if DB.Create(&pointerStruct).Error != nil { if DB.Create(&pointerStruct).GetError() != nil {
t.Errorf("Failed to save pointer struct") t.Errorf("Failed to save pointer struct")
} }
var pointerStructResult PointerStruct var pointerStructResult PointerStruct
if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).GetError(); err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
t.Errorf("Failed to query saved pointer struct") t.Errorf("Failed to query saved pointer struct")
} }
@ -38,47 +38,47 @@ func TestPointerFields(t *testing.T) {
} }
var nilPointerStruct = PointerStruct{} var nilPointerStruct = PointerStruct{}
if err := DB.Create(&nilPointerStruct).Error; err != nil { if err := DB.Create(&nilPointerStruct).GetError(); err != nil {
t.Errorf("Failed to save nil pointer struct", err) t.Errorf("Failed to save nil pointer struct", err)
} }
var pointerStruct2 PointerStruct var pointerStruct2 PointerStruct
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil {
t.Errorf("Failed to query saved nil pointer struct", err) t.Errorf("Failed to query saved nil pointer struct", err)
} }
var normalStruct2 NormalStruct var normalStruct2 NormalStruct
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil {
t.Errorf("Failed to query saved nil pointer struct", err) t.Errorf("Failed to query saved nil pointer struct", err)
} }
var partialNilPointerStruct1 = PointerStruct{Num: &num} var partialNilPointerStruct1 = PointerStruct{Num: &num}
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { if err := DB.Create(&partialNilPointerStruct1).GetError(); err != nil {
t.Errorf("Failed to save partial nil pointer struct", err) t.Errorf("Failed to save partial nil pointer struct", err)
} }
var pointerStruct3 PointerStruct var pointerStruct3 PointerStruct
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || *pointerStruct3.Num != num {
t.Errorf("Failed to query saved partial nil pointer struct", err) t.Errorf("Failed to query saved partial nil pointer struct", err)
} }
var normalStruct3 NormalStruct var normalStruct3 NormalStruct
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || normalStruct3.Num != num {
t.Errorf("Failed to query saved partial pointer struct", err) t.Errorf("Failed to query saved partial pointer struct", err)
} }
var partialNilPointerStruct2 = PointerStruct{Name: &name} var partialNilPointerStruct2 = PointerStruct{Name: &name}
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { if err := DB.Create(&partialNilPointerStruct2).GetError(); err != nil {
t.Errorf("Failed to save partial nil pointer struct", err) t.Errorf("Failed to save partial nil pointer struct", err)
} }
var pointerStruct4 PointerStruct var pointerStruct4 PointerStruct
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || *pointerStruct4.Name != name {
t.Errorf("Failed to query saved partial nil pointer struct", err) t.Errorf("Failed to query saved partial nil pointer struct", err)
} }
var normalStruct4 NormalStruct var normalStruct4 NormalStruct
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || normalStruct4.Name != name {
t.Errorf("Failed to query saved partial pointer struct", err) t.Errorf("Failed to query saved partial pointer struct", err)
} }
} }

View File

@ -94,7 +94,7 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
} }
func (postgres) RemoveIndex(scope *Scope, indexName string) { func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError())
} }
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -115,7 +115,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
@ -146,7 +146,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
@ -176,7 +176,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {

View File

@ -115,17 +115,17 @@ func TestNestedPreload1(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want).Error; err != nil { if err := DB.Create(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err) t.Error(err)
} }
} }
@ -159,7 +159,7 @@ func TestNestedPreload2(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -178,12 +178,12 @@ func TestNestedPreload2(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want).Error; err != nil { if err := DB.Create(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -213,7 +213,7 @@ func TestNestedPreload3(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -223,12 +223,12 @@ func TestNestedPreload3(t *testing.T) {
{Level1: Level1{Value: "value2"}}, {Level1: Level1{Value: "value2"}},
}, },
} }
if err := DB.Create(&want).Error; err != nil { if err := DB.Create(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -258,7 +258,7 @@ func TestNestedPreload4(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -270,12 +270,12 @@ func TestNestedPreload4(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want).Error; err != nil { if err := DB.Create(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -306,22 +306,22 @@ func TestNestedPreload5(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want := make([]Level3, 2) want := make([]Level3, 2)
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want[0]).Error; err != nil { if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
if err := DB.Create(&want[1]).Error; err != nil { if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level3 var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -351,7 +351,7 @@ func TestNestedPreload6(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -371,7 +371,7 @@ func TestNestedPreload6(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[0]).Error; err != nil { if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -390,12 +390,12 @@ func TestNestedPreload6(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[1]).Error; err != nil { if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level3 var got []Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -425,7 +425,7 @@ func TestNestedPreload7(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -436,7 +436,7 @@ func TestNestedPreload7(t *testing.T) {
{Level1: Level1{Value: "value2"}}, {Level1: Level1{Value: "value2"}},
}, },
} }
if err := DB.Create(&want[0]).Error; err != nil { if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -446,12 +446,12 @@ func TestNestedPreload7(t *testing.T) {
{Level1: Level1{Value: "value4"}}, {Level1: Level1{Value: "value4"}},
}, },
} }
if err := DB.Create(&want[1]).Error; err != nil { if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level3 var got []Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -481,7 +481,7 @@ func TestNestedPreload8(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -494,7 +494,7 @@ func TestNestedPreload8(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[0]).Error; err != nil { if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want[1] = Level3{ want[1] = Level3{
@ -505,12 +505,12 @@ func TestNestedPreload8(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[1]).Error; err != nil { if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level3 var got []Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -555,7 +555,7 @@ func TestNestedPreload9(t *testing.T) {
DB.DropTableIfExists(&Level2_1{}) DB.DropTableIfExists(&Level2_1{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level0{}) DB.DropTableIfExists(&Level0{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -580,7 +580,7 @@ func TestNestedPreload9(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[0]).Error; err != nil { if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want[1] = Level3{ want[1] = Level3{
@ -597,12 +597,12 @@ func TestNestedPreload9(t *testing.T) {
}, },
}, },
} }
if err := DB.Create(&want[1]).Error; err != nil { if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level3 var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -634,7 +634,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels") DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -642,7 +642,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
{Value: "ru", LanguageCode: "ru"}, {Value: "ru", LanguageCode: "ru"},
{Value: "en", LanguageCode: "en"}, {Value: "en", LanguageCode: "en"},
}} }}
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -650,12 +650,12 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
{Value: "zh", LanguageCode: "zh"}, {Value: "zh", LanguageCode: "zh"},
{Value: "de", LanguageCode: "de"}, {Value: "de", LanguageCode: "de"},
}} }}
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level2 var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -664,7 +664,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
} }
var got2 Level2 var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -673,7 +673,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
} }
var got3 []Level2 var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -682,7 +682,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
} }
var got4 []Level2 var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -697,7 +697,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
} }
if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
} }
@ -719,7 +719,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels") DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -727,7 +727,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
{Value: "ru"}, {Value: "ru"},
{Value: "en"}, {Value: "en"},
}} }}
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -735,12 +735,12 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
{Value: "zh"}, {Value: "zh"},
{Value: "de"}, {Value: "de"},
}} }}
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level2 var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -749,7 +749,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
} }
var got2 Level2 var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -758,7 +758,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
} }
var got3 []Level2 var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -767,7 +767,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
} }
var got4 []Level2 var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -810,7 +810,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels") DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -824,7 +824,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
}, },
}, },
} }
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -838,12 +838,12 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
}, },
}, },
} }
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -852,7 +852,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
} }
var got2 Level3 var got2 Level3
if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -861,7 +861,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
} }
var got3 []Level3 var got3 []Level3
if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -870,7 +870,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
} }
var got4 []Level3 var got4 []Level3
if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -913,7 +913,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
DB.DropTableIfExists("level1_level2") DB.DropTableIfExists("level1_level2")
DB.DropTableIfExists("level2_level3") DB.DropTableIfExists("level2_level3")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -936,12 +936,12 @@ func TestNestedManyToManyPreload(t *testing.T) {
}, },
} }
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -949,7 +949,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err) t.Error(err)
} }
} }
@ -978,7 +978,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists("level1_level2") DB.DropTableIfExists("level1_level2")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -993,12 +993,12 @@ func TestNestedManyToManyPreload2(t *testing.T) {
}, },
} }
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got Level3 var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -1006,7 +1006,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err) t.Error(err)
} }
} }
@ -1035,7 +1035,7 @@ func TestNilPointerSlice(t *testing.T) {
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
@ -1045,17 +1045,17 @@ func TestNilPointerSlice(t *testing.T) {
Value: "native", Value: "native",
}, },
}} }}
if err := DB.Save(&want).Error; err != nil { if err := DB.Save(&want).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
want2 := Level1{Value: "Tom", Level2: nil} want2 := Level1{Value: "Tom", Level2: nil}
if err := DB.Save(&want2).Error; err != nil { if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err) t.Error(err)
} }
var got []Level1 var got []Level1
if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).GetError(); err != nil {
t.Error(err) t.Error(err)
} }

View File

@ -31,7 +31,7 @@ func TestFirstAndLast(t *testing.T) {
t.Errorf("Find first record as slice") t.Errorf("Find first record as slice")
} }
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil { if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error when order with Join table") t.Errorf("Should not raise any error when order with Join table")
} }
} }
@ -242,15 +242,15 @@ func TestSearchWithEmptyChain(t *testing.T) {
user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3) DB.Save(&user1).Save(&user2).Save(&user3)
if DB.Where("").Where("").First(&User{}).Error != nil { if DB.Where("").Where("").First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty strings") t.Errorf("Should not raise any error if searching with empty strings")
} }
if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty struct") t.Errorf("Should not raise any error if searching with empty struct")
} }
if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty map") t.Errorf("Should not raise any error if searching with empty map")
} }
} }
@ -359,7 +359,7 @@ func TestCount(t *testing.T) {
var count, count1, count2 int64 var count, count1, count2 int64
var users []User var users []User
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).GetError(); err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
} }
@ -381,7 +381,7 @@ func TestNot(t *testing.T) {
DB := DB.Where("role = ?", "not") DB := DB.Where("role = ?", "not")
var users1, users2, users3, users4, users5, users6, users7, users8 []User var users1, users2, users3, users4, users5, users6, users7, users8 []User
if DB.Find(&users1).RowsAffected != 4 { if DB.Find(&users1).GetRowsAffected() != 4 {
t.Errorf("should find 4 not users") t.Errorf("should find 4 not users")
} }
DB.Not(users1[0].Id).Find(&users2) DB.Not(users1[0].Id).Find(&users2)
@ -598,7 +598,7 @@ func TestSelectWithArrayInput(t *testing.T) {
func TestCurrentDatabase(t *testing.T) { func TestCurrentDatabase(t *testing.T) {
databaseName := DB.CurrentDatabase() databaseName := DB.CurrentDatabase()
if err := DB.Error; err != nil { if err := DB.GetError(); err != nil {
t.Errorf("Problem getting current db name: %s", err) t.Errorf("Problem getting current db name: %s", err)
} }
if databaseName == "" { if databaseName == "" {

View File

@ -207,7 +207,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *DB): case func(s *DB):
newDB := scope.NewDB() newDB := scope.NewDB()
f(newDB) f(newDB)
scope.Err(newDB.Error) scope.Err(newDB.GetError())
case func() error: case func() error:
scope.Err(f()) scope.Err(f())
case func(s *Scope) error: case func(s *Scope) error:
@ -215,7 +215,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *DB) error: case func(s *DB) error:
newDB := scope.NewDB() newDB := scope.NewDB()
scope.Err(f(newDB)) scope.Err(f(newDB))
scope.Err(newDB.Error) scope.Err(newDB.GetError())
default: default:
scope.Err(fmt.Errorf("unsupported function %v", name)) scope.Err(fmt.Errorf("unsupported function %v", name))
} }
@ -262,7 +262,7 @@ type tabler interface {
} }
type dbTabler interface { type dbTabler interface {
TableName(*DB) string TableName(Database) string
} }
// TableName get table name // TableName get table name

View File

@ -444,17 +444,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil { if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).GetError())
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
query := toScope.db var query Database = toScope.db
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok { if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
} }
} }
scope.Err(query.Find(value).Error) scope.Err(query.Find(value).GetError())
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := toScope.db var query Database = toScope.db
for idx, foreignKey := range relationship.ForeignDBNames { for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
@ -464,16 +464,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
} }
scope.Err(query.Find(value).Error) scope.Err(query.Find(value).GetError())
} }
} else { } else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).GetError())
} }
return scope return scope
} else if toField != nil { } else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).GetError())
return scope return scope
} }
} }
@ -525,7 +525,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
} }
} }
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).GetError())
} }
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
} }

View File

@ -6,16 +6,16 @@ import (
) )
func NameIn1And2(d *gorm.DB) *gorm.DB { func NameIn1And2(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}).(*gorm.DB)
} }
func NameIn2And3(d *gorm.DB) *gorm.DB { func NameIn2And3(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}).(*gorm.DB)
} }
func NameIn(names []string) func(d *gorm.DB) *gorm.DB { func NameIn(names []string) func(d *gorm.DB) *gorm.DB{
return func(d *gorm.DB) *gorm.DB { return func(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", names) return d.Where("name in (?)", names).(*gorm.DB)
} }
} }

View File

@ -7,7 +7,7 @@ import (
) )
func TestScannableSlices(t *testing.T) { func TestScannableSlices(t *testing.T) {
if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { if err := DB.AutoMigrate(&RecordWithSlice{}).GetError(); err != nil {
t.Errorf("Should create table with slice values correctly: %s", err) t.Errorf("Should create table with slice values correctly: %s", err)
} }
@ -19,13 +19,13 @@ func TestScannableSlices(t *testing.T) {
}, },
} }
if err := DB.Save(&r1).Error; err != nil { if err := DB.Save(&r1).GetError(); err != nil {
t.Errorf("Should save record with slice values") t.Errorf("Should save record with slice values")
} }
var r2 RecordWithSlice var r2 RecordWithSlice
if err := DB.Find(&r2).Error; err != nil { if err := DB.Find(&r2).GetError(); err != nil {
t.Errorf("Should fetch record with slice values") t.Errorf("Should fetch record with slice values")
} }

View File

@ -62,7 +62,7 @@ func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool
} }
func (sqlite3) RemoveIndex(scope *Scope, indexName string) { func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError())
} }
func (sqlite3) CurrentDatabase(scope *Scope) (name string) { func (sqlite3) CurrentDatabase(scope *Scope) (name string) {

View File

@ -56,17 +56,17 @@ func TestUpdate(t *testing.T) {
t.Errorf("Product should not be changed to 789") t.Errorf("Product should not be changed to 789")
} }
if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil {
t.Error("No error should raise when update with CamelCase") t.Error("No error should raise when update with CamelCase")
} }
if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil {
t.Error("No error should raise when update_column with CamelCase") t.Error("No error should raise when update_column with CamelCase")
} }
var products []Product var products []Product
DB.Find(&products) DB.Find(&products)
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(products)) {
t.Error("RowsAffected should be correct when do batch update") t.Error("RowsAffected should be correct when do batch update")
} }
@ -95,7 +95,7 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
var animals []Animal var animals []Animal
DB.Find(&animals) DB.Find(&animals)
if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(animals)) {
t.Error("RowsAffected should be correct when do batch update") t.Error("RowsAffected should be correct when do batch update")
} }
@ -402,8 +402,8 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
newAge := int64(100) newAge := int64(100)
user.BillingAddress.Address1 = "second street" user.BillingAddress.Address1 = "second street"
db := DB.Model(user).UpdateColumns(User{Age: newAge}) db := DB.Model(user).UpdateColumns(User{Age: newAge})
if db.RowsAffected != 1 { if db.GetRowsAffected() != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.GetRowsAffected())
} }
// Verify that Age now=`newAge`. // Verify that Age now=`newAge`.