Add extract a interface based on gorm.DB methods

This commit is contained in:
Alvaro Viebrantz 2016-01-17 14:41:34 -03:00
parent 341d047aa7
commit cd40e372d3
30 changed files with 430 additions and 312 deletions

View File

@ -23,7 +23,7 @@ func (association *Association) setErr(err error) *Association {
func (association *Association) Find(value interface{}) *Association {
association.Scope.related(value, association.Column)
return association.setErr(association.Scope.db.Error)
return association.setErr(association.Scope.db.GetError())
}
func (association *Association) saveAssociations(values ...interface{}) *Association {
@ -42,7 +42,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
// value has to been saved for many2many
if relationship.Kind == "many_to_many" {
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
association.setErr(scope.NewDB().Save(reflectValue.Interface()).GetError())
}
}
@ -68,7 +68,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).GetError())
if setFieldBackToValue {
reflectValue.Elem().Set(field.Field)
@ -106,7 +106,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
newDB = scope.NewDB()
newDB Database = scope.NewDB()
)
// Append new values
@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).GetError())
}
} else {
// Relations
@ -173,7 +173,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
}
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError())
}
}
return association
@ -184,7 +184,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
relationship = association.Field.Relationship
scope = association.Scope
field = association.Field.Field
newDB = scope.NewDB()
newDB Database = scope.NewDB()
)
if len(values) == 0 {
@ -231,12 +231,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
// set foreign key to be null
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.GetError() == nil {
if results.GetRowsAffected() > 0 {
scope.updatedAttrsWithValues(foreignKeyMap, false)
}
} else {
association.setErr(results.Error)
association.setErr(results.GetError())
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations
@ -254,7 +254,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.Field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError())
}
}
@ -308,7 +308,7 @@ func (association *Association) Count() int {
if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := scope.DB()
var query Database = scope.DB()
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
@ -321,7 +321,7 @@ func (association *Association) Count() int {
}
query.Model(fieldValue).Count(&count)
} else if relationship.Kind == "belongs_to" {
query := scope.DB()
var query Database = scope.DB()
for idx, primaryKey := range relationship.AssociationForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)),

View File

@ -15,7 +15,7 @@ func TestBelongsTo(t *testing.T) {
MainCategory: Category{Name: "Main Category 1"},
}
if err := DB.Save(&post).Error; err != nil {
if err := DB.Save(&post).GetError(); err != nil {
t.Errorf("Got errors when save post", err.Error())
}
@ -183,7 +183,7 @@ func TestHasOne(t *testing.T) {
CreditCard: CreditCard{Number: "411111111111"},
}
if err := DB.Save(&user).Error; err != nil {
if err := DB.Save(&user).GetError(); err != nil {
t.Errorf("Got errors when save user", err.Error())
}
@ -330,7 +330,7 @@ func TestHasMany(t *testing.T) {
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
}
if err := DB.Save(&post).Error; err != nil {
if err := DB.Save(&post).GetError(); err != nil {
t.Errorf("Got errors when save post", err.Error())
}
@ -351,7 +351,7 @@ func TestHasMany(t *testing.T) {
}
// Query
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
if DB.First(&Comment{}, "content = ?", "Comment 1").GetError() != nil {
t.Errorf("Comment 1 should be saved")
}

View File

@ -78,7 +78,8 @@ func Create(scope *Scope) {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
rowsAffected, _ := result.RowsAffected()
scope.db.SetRowsAffected(rowsAffected)
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
}
@ -87,13 +88,14 @@ func Create(scope *Scope) {
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
rowsAffected, _ := results.RowsAffected()
scope.db.SetRowsAffected(rowsAffected)
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
scope.db.SetRowsAffected(1)
} else {
scope.Err(err)
}

View File

@ -45,7 +45,7 @@ func Query(scope *Scope) {
if !scope.HasError() {
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0
scope.db.SetRowsAffected(0)
if scope.Err(err) != nil {
return
@ -54,7 +54,8 @@ func Query(scope *Scope) {
columns, _ := rows.Columns()
for rows.Next() {
scope.db.RowsAffected++
rowsAffected := scope.db.GetRowsAffected()+1
scope.db.SetRowsAffected(rowsAffected)
anyRecordFound = true
elem := dest

View File

@ -18,7 +18,7 @@ func SaveBeforeAssociations(scope *Scope) {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
scope.Err(scope.NewDB().Save(value.Addr().Interface()).GetError())
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
@ -62,7 +62,7 @@ func SaveAfterAssociations(scope *Scope) {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(newDB.Save(elem).Error)
scope.Err(newDB.Save(elem).GetError())
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
@ -83,7 +83,7 @@ func SaveAfterAssociations(scope *Scope) {
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(scope.NewDB().Save(elem).Error)
scope.Err(scope.NewDB().Save(elem).GetError())
}
}
}

View File

@ -108,22 +108,22 @@ func TestRunCallbacks(t *testing.T) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
if DB.Where("Code = ?", "unique_code").First(&p).GetError() == nil {
t.Errorf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
if DB.Save(&p).GetError() == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
if DB.Where("code = ?", "Invalid").First(&Product{}).GetError() == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
if DB.Save(&Product{Code: "dont_save", Price: 100}).GetError() == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
@ -131,47 +131,47 @@ func TestCallbacksWithErrors(t *testing.T) {
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
if DB.Save(&p2).GetError() == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
if DB.Where("code = ?", "update_callback").First(&Product{}).GetError() != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
if DB.Where("code = ?", "dont_update").First(&Product{}).GetError() == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
if DB.Save(&p2).GetError() == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
if DB.Delete(&p3).GetError() == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
if DB.Where("Code = ?", "dont_delete").First(&p3).GetError() != nil {
t.Errorf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
if err := DB.First(&Product{}, "code = ?", "after_save_error").GetError(); err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil {
t.Errorf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}

View File

@ -96,7 +96,7 @@ func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).GetError())
}
// RawScanInt scans the first column of the first row into the `scan' int pointer.

View File

@ -15,7 +15,7 @@ func TestCreate(t *testing.T) {
t.Error("User should be new record before create")
}
if count := DB.Save(&user).RowsAffected; count != 1 {
if count := DB.Save(&user).GetRowsAffected(); count != 1 {
t.Error("There should be one record be affected when create record")
}
@ -63,7 +63,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) {
}
jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error
err := DB.Create(&jt).GetError()
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
@ -71,7 +71,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) {
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil {
if DB.Save(&animal).GetError() != nil {
t.Errorf("No error should happen when create a record without std primary key")
}
@ -86,7 +86,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
// Test create with default value not overrided
an := Animal{From: "nerdz"}
if DB.Save(&an).Error != nil {
if DB.Save(&an).GetError() != nil {
t.Errorf("No error should happen when create an record without std primary key")
}

View File

@ -38,7 +38,7 @@ func TestCustomizeColumn(t *testing.T) {
expected := "foo"
cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
if count := DB.Create(&cc).RowsAffected; count != 1 {
if count := DB.Create(&cc).GetRowsAffected(); count != 1 {
t.Error("There should be one record be affected when create record")
}
@ -61,7 +61,7 @@ func TestCustomizeColumn(t *testing.T) {
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).GetError(); err != nil {
t.Errorf("Should not raise error: %s", err)
}
}
@ -86,17 +86,17 @@ func TestManyToManyWithCustomizedColumn(t *testing.T) {
Accounts: []CustomizeAccount{account},
}
if err := DB.Create(&account).Error; err != nil {
if err := DB.Create(&account).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err)
}
if err := DB.Create(&person).Error; err != nil {
if err := DB.Create(&person).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err)
}
var person1 CustomizePerson
scope := DB.NewScope(nil)
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).GetError(); err != nil {
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
}
@ -131,7 +131,7 @@ func TestOneToOneWithCustomizedColumn(t *testing.T) {
DB.Create(&invitation)
var invitation2 CustomizeInvitation
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).GetError(); err != nil {
t.Errorf("no error should happen, but got %v", err)
}
@ -183,12 +183,12 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) {
},
}
if err := DB.Create(&discount).Error; err != nil {
if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
@ -197,7 +197,7 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) {
}
var coupon PromotionCoupon
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
@ -221,12 +221,12 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
},
}
if err := DB.Create(&discount).Error; err != nil {
if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
@ -235,7 +235,7 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
}
var rule PromotionRule
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
@ -256,12 +256,12 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
},
}
if err := DB.Create(&discount).Error; err != nil {
if err := DB.Create(&discount).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
var discount1 PromotionDiscount
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}
@ -270,7 +270,7 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
}
var benefit PromotionBenefit
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").GetError(); err != nil {
t.Errorf("no error should happen but got %v", err)
}

View File

@ -18,7 +18,7 @@ func TestDdlErrors(t *testing.T) {
}()
DB.HasTable("foobarbaz")
if DB.Error == nil {
if DB.GetError() == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
}
}

View File

@ -10,7 +10,7 @@ func TestDelete(t *testing.T) {
DB.Save(&user1)
DB.Save(&user2)
if err := DB.Delete(&user1).Error; err != nil {
if err := DB.Delete(&user1).GetError(); err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err)
}
@ -28,13 +28,13 @@ func TestInlineDelete(t *testing.T) {
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&User{}, user1.Id).Error != nil {
if DB.Delete(&User{}, user1.Id).GetError() != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
if err := DB.Delete(&User{}, "name = ?", user2.Name).GetError(); err != nil {
t.Errorf("No error should happen when delete a record, err=%s", err)
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
@ -53,11 +53,11 @@ func TestSoftDelete(t *testing.T) {
DB.Save(&user)
DB.Delete(&user)
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
if DB.First(&User{}, "name = ?", user.Name).GetError() == nil {
t.Errorf("Can't find a soft deleted record")
}
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).GetError(); err != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
}

View File

@ -22,7 +22,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
if err := DB.First(&news, "title = ?", "hn_news").GetError(); err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
@ -30,7 +30,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
if err := DB.First(&egNews, "title = ?", "engadget_news").GetError(); err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")

View File

@ -1,6 +1,8 @@
package gorm
import "database/sql"
import (
"database/sql"
)
type sqlCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
@ -17,3 +19,95 @@ type sqlTx interface {
Commit() error
Rollback() error
}
type Database interface {
Close() error
DB() *sql.DB
New() Database
NewScope(value interface{}) *Scope
CommonDB() sqlCommon
Callback() *callback
SetLogger(l logger)
LogMode(enable bool) Database
SingularTable(enable bool)
Where(query interface{}, args ...interface{}) Database
Or(query interface{}, args ...interface{}) Database
Not(query interface{}, args ...interface{}) Database
Limit(value interface{}) Database
Offset(value interface{}) Database
Order(value string, reorder ...bool) Database
Select(query interface{}, args ...interface{}) Database
Omit(columns ...string) Database
Group(query string) Database
Having(query string, values ...interface{}) Database
Joins(query string) Database
//Scopes(funcs ...func(Database) Database) Database
Scopes(funcs ...func(*DB) *DB) *DB
Unscoped() Database
Attrs(attrs ...interface{}) Database
Assign(attrs ...interface{}) Database
First(out interface{}, where ...interface{}) Database
Last(out interface{}, where ...interface{}) Database
Find(out interface{}, where ...interface{}) Database
Scan(dest interface{}) Database
Row() *sql.Row
Rows() (*sql.Rows, error)
Pluck(column string, value interface{}) Database
Count(value interface{}) Database
Related(value interface{}, foreignKeys ...string) Database
FirstOrInit(out interface{}, where ...interface{}) Database
FirstOrCreate(out interface{}, where ...interface{}) Database
Update(attrs ...interface{}) Database
Updates(values interface{}, ignoreProtectedAttrs ...bool) Database
UpdateColumn(attrs ...interface{}) Database
UpdateColumns(values interface{}) Database
Save(value interface{}) Database
Create(value interface{}) Database
Delete(value interface{}, where ...interface{}) Database
Raw(sql string, values ...interface{}) Database
Exec(sql string, values ...interface{}) Database
Model(value interface{}) Database
Table(name string) Database
Debug() Database
Begin() Database
Commit() Database
Rollback() Database
NewRecord(value interface{}) bool
RecordNotFound() bool
CreateTable(values ...interface{}) Database
DropTable(values ...interface{}) Database
DropTableIfExists(values ...interface{}) Database
HasTable(value interface{}) bool
AutoMigrate(values ...interface{}) Database
ModifyColumn(column string, typ string) Database
DropColumn(column string) Database
AddIndex(indexName string, column ...string) Database
AddUniqueIndex(indexName string, column ...string) Database
RemoveIndex(indexName string) Database
CurrentDatabase() string
AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database
Association(column string) *Association
Preload(column string, conditions ...interface{}) Database
Set(name string, value interface{}) Database
InstantSet(name string, value interface{}) Database
Get(name string) (value interface{}, ok bool)
SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface)
AddError(err error) error
GetError() error
GetErrors() (errors []error)
GetRowsAffected() int64
SetRowsAffected(num int64)
}

View File

@ -9,10 +9,10 @@ import (
type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
Table(db Database) string
Add(handler JoinTableHandlerInterface, db Database, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database
SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey
}
@ -61,11 +61,11 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
}
}
func (s JoinTableHandler) Table(db *DB) string {
func (s JoinTableHandler) Table(db Database) string {
return s.TableName
}
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
func (s JoinTableHandler) GetSearchMap(db Database, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
@ -85,7 +85,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
return values
}
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db Database, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
@ -113,10 +113,10 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
return db.Exec(sql, values...).GetError()
}
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
@ -128,10 +128,10 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour
values = append(values, value)
}
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").GetError()
}
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
@ -174,7 +174,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...)
} else {
db.Error = errors.New("wrong source type for join table handler")
db.AddError(errors.New("wrong source type for join table handler"))
return db
}
}

View File

@ -22,7 +22,7 @@ type PersonAddress struct {
CreatedAt time.Time
}
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db gorm.Database, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
@ -30,14 +30,14 @@ func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, f
"person_id": foreignValue,
"address_id": associationValue,
"deleted_at": gorm.Expr("NULL"),
}).FirstOrCreate(&PersonAddress{}).Error
}).FirstOrCreate(&PersonAddress{}).GetError()
}
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db gorm.Database, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).GetError()
}
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db gorm.Database, source interface{}) gorm.Database {
table := pa.Table(db)
return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
}
@ -54,7 +54,7 @@ func TestJoinTable(t *testing.T) {
DB.Model(person).Association("Addresses").Delete(address1)
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 1 {
t.Errorf("Should found one address")
}
@ -62,7 +62,7 @@ func TestJoinTable(t *testing.T) {
t.Errorf("Should found one address")
}
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 2 {
t.Errorf("Found two addresses with Unscoped")
}

140
main.go
View File

@ -90,7 +90,7 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
func (s *DB) New() *DB {
func (s *DB) New() Database {
clone := s.clone()
clone.search = nil
clone.Value = nil
@ -120,7 +120,7 @@ func (s *DB) SetLogger(l logger) {
s.logger = l
}
func (s *DB) LogMode(enable bool) *DB {
func (s *DB) LogMode(enable bool) Database {
if enable {
s.logMode = 2
} else {
@ -134,47 +134,47 @@ func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = enable
}
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
func (s *DB) Where(query interface{}, args ...interface{}) Database {
return s.clone().search.Where(query, args...).db
}
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
func (s *DB) Or(query interface{}, args ...interface{}) Database {
return s.clone().search.Or(query, args...).db
}
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
func (s *DB) Not(query interface{}, args ...interface{}) Database {
return s.clone().search.Not(query, args...).db
}
func (s *DB) Limit(value interface{}) *DB {
func (s *DB) Limit(value interface{}) Database {
return s.clone().search.Limit(value).db
}
func (s *DB) Offset(value interface{}) *DB {
func (s *DB) Offset(value interface{}) Database {
return s.clone().search.Offset(value).db
}
func (s *DB) Order(value string, reorder ...bool) *DB {
func (s *DB) Order(value string, reorder ...bool) Database {
return s.clone().search.Order(value, reorder...).db
}
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
func (s *DB) Select(query interface{}, args ...interface{}) Database {
return s.clone().search.Select(query, args...).db
}
func (s *DB) Omit(columns ...string) *DB {
func (s *DB) Omit(columns ...string) Database {
return s.clone().search.Omit(columns...).db
}
func (s *DB) Group(query string) *DB {
func (s *DB) Group(query string) Database {
return s.clone().search.Group(query).db
}
func (s *DB) Having(query string, values ...interface{}) *DB {
func (s *DB) Having(query string, values ...interface{}) Database {
return s.clone().search.Having(query, values...).db
}
func (s *DB) Joins(query string) *DB {
func (s *DB) Joins(query string) Database {
return s.clone().search.Joins(query).db
}
@ -185,37 +185,37 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
return s
}
func (s *DB) Unscoped() *DB {
func (s *DB) Unscoped() Database {
return s.clone().search.unscoped().db
}
func (s *DB) Attrs(attrs ...interface{}) *DB {
func (s *DB) Attrs(attrs ...interface{}) Database {
return s.clone().search.Attrs(attrs...).db
}
func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) Assign(attrs ...interface{}) Database {
return s.clone().search.Assign(attrs...).db
}
func (s *DB) First(out interface{}, where ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) Database {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
func (s *DB) Last(out interface{}, where ...interface{}) Database {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
func (s *DB) Find(out interface{}, where ...interface{}) Database {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Scan(dest interface{}) *DB {
func (s *DB) Scan(dest interface{}) Database {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
}
@ -227,21 +227,21 @@ func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows()
}
func (s *DB) Pluck(column string, value interface{}) *DB {
func (s *DB) Pluck(column string, value interface{}) Database {
return s.NewScope(s.Value).pluck(column, value).db
}
func (s *DB) Count(value interface{}) *DB {
func (s *DB) Count(value interface{}) Database {
return s.NewScope(s.Value).count(value).db
}
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
func (s *DB) Related(value interface{}, foreignKeys ...string) Database {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) Database {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if result := c.First(out, where...); result.GetError() != nil {
if !result.RecordNotFound() {
return result
}
@ -252,35 +252,35 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
return c
}
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) Database {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if result := c.First(out, where...); result.GetError() != nil {
if !result.RecordNotFound() {
return result
}
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.GetError())
} else if len(c.search.assignAttrs) > 0 {
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.GetError())
}
return c
}
func (s *DB) Update(attrs ...interface{}) *DB {
func (s *DB) Update(attrs ...interface{}) Database {
return s.Updates(toSearchableMap(attrs...), true)
}
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) Database {
return s.clone().NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db
}
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
func (s *DB) UpdateColumn(attrs ...interface{}) Database {
return s.UpdateColumns(toSearchableMap(attrs...))
}
func (s *DB) UpdateColumns(values interface{}) *DB {
func (s *DB) UpdateColumns(values interface{}) Database {
return s.clone().NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
@ -288,7 +288,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB {
callCallbacks(s.parent.callback.updates).db
}
func (s *DB) Save(value interface{}) *DB {
func (s *DB) Save(value interface{}) Database {
scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db
@ -296,20 +296,20 @@ func (s *DB) Save(value interface{}) *DB {
return scope.callCallbacks(s.parent.callback.updates).db
}
func (s *DB) Create(value interface{}) *DB {
func (s *DB) Create(value interface{}) Database {
scope := s.clone().NewScope(value)
return scope.callCallbacks(s.parent.callback.creates).db
}
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
func (s *DB) Delete(value interface{}, where ...interface{}) Database {
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
}
func (s *DB) Raw(sql string, values ...interface{}) *DB {
func (s *DB) Raw(sql string, values ...interface{}) Database {
return s.clone().search.Raw(true).Where(sql, values...).db
}
func (s *DB) Exec(sql string, values ...interface{}) *DB {
func (s *DB) Exec(sql string, values ...interface{}) Database {
scope := s.clone().NewScope(nil)
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
@ -317,24 +317,24 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB {
return scope.Exec().db
}
func (s *DB) Model(value interface{}) *DB {
func (s *DB) Model(value interface{}) Database {
c := s.clone()
c.Value = value
return c
}
func (s *DB) Table(name string) *DB {
func (s *DB) Table(name string) Database {
clone := s.clone()
clone.search.Table(name)
clone.Value = nil
return clone
}
func (s *DB) Debug() *DB {
func (s *DB) Debug() Database {
return s.clone().LogMode(true)
}
func (s *DB) Begin() *DB {
func (s *DB) Begin() Database {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin()
@ -346,7 +346,7 @@ func (s *DB) Begin() *DB {
return c
}
func (s *DB) Commit() *DB {
func (s *DB) Commit() Database {
if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Commit())
} else {
@ -355,7 +355,7 @@ func (s *DB) Commit() *DB {
return s
}
func (s *DB) Rollback() *DB {
func (s *DB) Rollback() Database {
if db, ok := s.db.(sqlTx); ok {
s.AddError(db.Rollback())
} else {
@ -373,16 +373,18 @@ func (s *DB) RecordNotFound() bool {
}
// Migrations
func (s *DB) CreateTable(values ...interface{}) *DB {
db := s.clone()
func (s *DB) CreateTable(values ...interface{}) Database {
var db Database
db = s.clone()
for _, value := range values {
db = db.NewScope(value).createTable().db
}
return db
}
func (s *DB) DropTable(values ...interface{}) *DB {
db := s.clone()
func (s *DB) DropTable(values ...interface{}) Database {
var db Database
db = s.clone()
for _, value := range values {
if tableName, ok := value.(string); ok {
db = db.Table(tableName)
@ -393,8 +395,9 @@ func (s *DB) DropTable(values ...interface{}) *DB {
return db
}
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
db := s.clone()
func (s *DB) DropTableIfExists(values ...interface{}) Database {
var db Database
db = s.clone()
for _, value := range values {
if tableName, ok := value.(string); ok {
db = db.Table(tableName)
@ -409,43 +412,44 @@ func (s *DB) HasTable(value interface{}) bool {
scope := s.clone().NewScope(value)
tableName := scope.TableName()
has := scope.Dialect().HasTable(scope, tableName)
s.AddError(scope.db.Error)
s.AddError(scope.db.GetError())
return has
}
func (s *DB) AutoMigrate(values ...interface{}) *DB {
db := s.clone()
func (s *DB) AutoMigrate(values ...interface{}) Database {
var db Database
db = s.clone()
for _, value := range values {
db = db.NewScope(value).NeedPtr().autoMigrate().db
}
return db
}
func (s *DB) ModifyColumn(column string, typ string) *DB {
func (s *DB) ModifyColumn(column string, typ string) Database {
scope := s.clone().NewScope(s.Value)
scope.modifyColumn(column, typ)
return scope.db
}
func (s *DB) DropColumn(column string) *DB {
func (s *DB) DropColumn(column string) Database {
scope := s.clone().NewScope(s.Value)
scope.dropColumn(column)
return scope.db
}
func (s *DB) AddIndex(indexName string, column ...string) *DB {
func (s *DB) AddIndex(indexName string, column ...string) Database {
scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, column...)
return scope.db
}
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
func (s *DB) AddUniqueIndex(indexName string, column ...string) Database {
scope := s.clone().NewScope(s.Value)
scope.addIndex(true, indexName, column...)
return scope.db
}
func (s *DB) RemoveIndex(indexName string) *DB {
func (s *DB) RemoveIndex(indexName string) Database {
scope := s.clone().NewScope(s.Value)
scope.removeIndex(indexName)
return scope.db
@ -465,7 +469,7 @@ Add foreign key to the given scope
Example:
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
*/
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database {
scope := s.clone().NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
@ -492,16 +496,16 @@ func (s *DB) Association(column string) *Association {
return &Association{Error: err}
}
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
func (s *DB) Preload(column string, conditions ...interface{}) Database {
return s.clone().search.Preload(column, conditions...).db
}
// Set set value by name
func (s *DB) Set(name string, value interface{}) *DB {
func (s *DB) Set(name string, value interface{}) Database {
return s.clone().InstantSet(name, value)
}
func (s *DB) InstantSet(name string, value interface{}) *DB {
func (s *DB) InstantSet(name string, value interface{}) Database {
s.values[name] = value
return s
}
@ -550,6 +554,18 @@ func (s *DB) AddError(err error) error {
return err
}
func (s *DB) GetError() error {
return s.Error
}
func (s *DB) SetRowsAffected(num int64) {
s.RowsAffected = num
}
func (s *DB) GetRowsAffected() int64 {
return s.RowsAffected
}
func (s *DB) GetErrors() (errors []error) {
if errs, ok := s.Error.(errorsInterface); ok {
return errs.GetErrors()

View File

@ -20,7 +20,7 @@ import (
)
var (
DB gorm.DB
DB gorm.Database
t1, t2, t3, t4, t5 time.Time
)
@ -42,7 +42,11 @@ func init() {
runMigration()
}
func OpenTestConnection() (db gorm.DB, err error) {
func OpenTestConnection() (*gorm.DB, error) {
var (
db gorm.DB
err error
)
switch os.Getenv("GORM_DIALECT") {
case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
@ -63,7 +67,8 @@ func OpenTestConnection() (db gorm.DB, err error) {
fmt.Println("testing sqlite3...")
db, err = gorm.Open("sqlite3", "/tmp/gorm.db")
}
return
return &db, err
}
func TestStringPrimaryKey(t *testing.T) {
@ -74,22 +79,22 @@ func TestStringPrimaryKey(t *testing.T) {
DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
if err := DB.Save(&data).GetError(); err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated")
}
}
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).GetError() == nil {
t.Errorf("Should got error with invalid SQL")
}
@ -99,7 +104,7 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
t.Errorf("Should find some users")
}
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).GetError() == nil {
t.Errorf("Should got error with invalid SQL")
}
@ -114,21 +119,21 @@ func TestSetTable(t *testing.T) {
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).GetError(); err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error())
}
var users []User
if DB.Table("users").Find(&[]User{}).Error != nil {
if DB.Table("users").Find(&[]User{}).GetError() != nil {
t.Errorf("No errors should happen if set table for find")
}
if DB.Table("invalid_table").Find(&users).Error == nil {
if DB.Table("invalid_table").Find(&users).GetError() == nil {
t.Errorf("Should got error when table is set to an invalid table")
}
DB.Exec("drop table deleted_users;")
if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
if DB.Table("deleted_users").CreateTable(&User{}).GetError() != nil {
t.Errorf("Create table with specified table")
}
@ -168,7 +173,7 @@ func TestHasTable(t *testing.T) {
if ok := DB.HasTable(&Foo{}); ok {
t.Errorf("Table should not exist, but does")
}
if err := DB.CreateTable(&Foo{}).Error; err != nil {
if err := DB.CreateTable(&Foo{}).GetError(); err != nil {
t.Errorf("Table should be created")
}
if ok := DB.HasTable(&Foo{}); !ok {
@ -240,7 +245,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: true},
}).Error; err != nil {
}).GetError(); err != nil {
t.Errorf("Not error should raise when test null value")
}
@ -258,7 +263,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err != nil {
}).GetError(); err != nil {
t.Errorf("Not error should raise when test null value")
}
@ -275,7 +280,7 @@ func TestNullValues(t *testing.T) {
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err == nil {
}).GetError(); err == nil {
t.Errorf("Can't save because of name can't be null")
}
}
@ -287,7 +292,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
}
var nv2 NullValue
if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil {
if err := DB.Where(nv1).FirstOrCreate(&nv2).GetError(); err != nil {
t.Errorf("Should not raise any error, but got %v", err)
}
@ -295,7 +300,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
t.Errorf("first or create with nullvalues")
}
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).GetError(); err != nil {
t.Errorf("Should not raise any error, but got %v", err)
}
@ -307,11 +312,11 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) {
func TestTransaction(t *testing.T) {
tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
if err := tx.Save(&u).GetError(); err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err != nil {
t.Errorf("Should find saved record")
}
@ -321,23 +326,23 @@ func TestTransaction(t *testing.T) {
tx.Rollback()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err == nil {
t.Errorf("Should not find record after rollback")
}
tx2 := DB.Begin()
u2 := User{Name: "transcation-2"}
if err := tx2.Save(&u2).Error; err != nil {
if err := tx2.Save(&u2).GetError(); err != nil {
t.Errorf("No error should raise")
}
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
if err := tx2.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil {
t.Errorf("Should find saved record")
}
tx2.Commit()
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
if err := DB.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil {
t.Errorf("Should be able to find committed record")
}
}
@ -436,7 +441,7 @@ func TestRaw(t *testing.T) {
}
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).GetError() != gorm.RecordNotFound {
t.Error("Raw sql to update records")
}
}
@ -568,14 +573,14 @@ func TestHstore(t *testing.T) {
t.Skip()
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").GetError(); err != nil {
fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
}
DB.Exec("drop table details")
if err := DB.CreateTable(&Details{}).Error; err != nil {
if err := DB.CreateTable(&Details{}).GetError(); err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
@ -589,7 +594,7 @@ func TestHstore(t *testing.T) {
DB.Save(&d)
var d2 Details
if err := DB.First(&d2).Error; err != nil {
if err := DB.First(&d2).GetError(); err != nil {
t.Errorf("Got error when tried to fetch details: %+v", err)
}
@ -647,7 +652,7 @@ func TestOpenExistingDB(t *testing.T) {
}
var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
if db.Where("name = ?", "jnfeinstein").First(&user).GetError() == gorm.RecordNotFound {
t.Errorf("Should have found existing record")
}
}

View File

@ -7,7 +7,7 @@ import (
)
func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
if err := DB.DropTableIfExists(&User{}).GetError(); err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err)
}
@ -20,13 +20,13 @@ func runMigration() {
DB.DropTable(value)
}
if err := DB.AutoMigrate(values...).Error; err != nil {
if err := DB.AutoMigrate(values...).GetError(); err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
}
func TestIndexes(t *testing.T) {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
@ -35,7 +35,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
@ -43,7 +43,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email should be deleted")
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
@ -51,7 +51,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
@ -59,7 +59,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
@ -67,7 +67,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() == nil {
t.Errorf("Should get to create duplicate record when having unique index")
}
@ -81,7 +81,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
@ -89,7 +89,7 @@ func TestIndexes(t *testing.T) {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index")
}
}
@ -110,7 +110,7 @@ func (b BigEmail) TableName() string {
func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{})
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil {
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).GetError(); err != nil {
t.Errorf("Auto Migrate should not raise any error")
}

View File

@ -14,7 +14,7 @@ import (
"github.com/jinzhu/inflection"
)
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
var DefaultTableNameHandler = func(db Database, defaultTableName string) string {
return defaultTableName
}
@ -48,7 +48,7 @@ type ModelStruct struct {
defaultTableName string
}
func (s *ModelStruct) TableName(db *DB) string {
func (s *ModelStruct) TableName(db Database) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}

View File

@ -20,12 +20,12 @@ func TestPointerFields(t *testing.T) {
var name = "pointer struct 1"
var num = 100
pointerStruct := PointerStruct{Name: &name, Num: &num}
if DB.Create(&pointerStruct).Error != nil {
if DB.Create(&pointerStruct).GetError() != nil {
t.Errorf("Failed to save pointer struct")
}
var pointerStructResult PointerStruct
if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).GetError(); err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
t.Errorf("Failed to query saved pointer struct")
}
@ -38,47 +38,47 @@ func TestPointerFields(t *testing.T) {
}
var nilPointerStruct = PointerStruct{}
if err := DB.Create(&nilPointerStruct).Error; err != nil {
if err := DB.Create(&nilPointerStruct).GetError(); err != nil {
t.Errorf("Failed to save nil pointer struct", err)
}
var pointerStruct2 PointerStruct
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil {
t.Errorf("Failed to query saved nil pointer struct", err)
}
var normalStruct2 NormalStruct
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil {
t.Errorf("Failed to query saved nil pointer struct", err)
}
var partialNilPointerStruct1 = PointerStruct{Num: &num}
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
if err := DB.Create(&partialNilPointerStruct1).GetError(); err != nil {
t.Errorf("Failed to save partial nil pointer struct", err)
}
var pointerStruct3 PointerStruct
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || *pointerStruct3.Num != num {
t.Errorf("Failed to query saved partial nil pointer struct", err)
}
var normalStruct3 NormalStruct
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || normalStruct3.Num != num {
t.Errorf("Failed to query saved partial pointer struct", err)
}
var partialNilPointerStruct2 = PointerStruct{Name: &name}
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
if err := DB.Create(&partialNilPointerStruct2).GetError(); err != nil {
t.Errorf("Failed to save partial nil pointer struct", err)
}
var pointerStruct4 PointerStruct
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || *pointerStruct4.Name != name {
t.Errorf("Failed to query saved partial nil pointer struct", err)
}
var normalStruct4 NormalStruct
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || normalStruct4.Name != name {
t.Errorf("Failed to query saved partial pointer struct", err)
}
}

View File

@ -94,7 +94,7 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
}
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError())
}
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -115,7 +115,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
@ -146,7 +146,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice {
@ -176,7 +176,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
}
results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError())
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {

View File

@ -115,17 +115,17 @@ func TestNestedPreload1(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want).Error; err != nil {
if err := DB.Create(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err)
}
}
@ -159,7 +159,7 @@ func TestNestedPreload2(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -178,12 +178,12 @@ func TestNestedPreload2(t *testing.T) {
},
},
}
if err := DB.Create(&want).Error; err != nil {
if err := DB.Create(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -213,7 +213,7 @@ func TestNestedPreload3(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -223,12 +223,12 @@ func TestNestedPreload3(t *testing.T) {
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want).Error; err != nil {
if err := DB.Create(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -258,7 +258,7 @@ func TestNestedPreload4(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -270,12 +270,12 @@ func TestNestedPreload4(t *testing.T) {
},
},
}
if err := DB.Create(&want).Error; err != nil {
if err := DB.Create(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -306,22 +306,22 @@ func TestNestedPreload5(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
want := make([]Level3, 2)
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want[0]).Error; err != nil {
if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err)
}
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
if err := DB.Create(&want[1]).Error; err != nil {
if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -351,7 +351,7 @@ func TestNestedPreload6(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -371,7 +371,7 @@ func TestNestedPreload6(t *testing.T) {
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err)
}
@ -390,12 +390,12 @@ func TestNestedPreload6(t *testing.T) {
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -425,7 +425,7 @@ func TestNestedPreload7(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -436,7 +436,7 @@ func TestNestedPreload7(t *testing.T) {
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err)
}
@ -446,12 +446,12 @@ func TestNestedPreload7(t *testing.T) {
{Level1: Level1{Value: "value4"}},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -481,7 +481,7 @@ func TestNestedPreload8(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -494,7 +494,7 @@ func TestNestedPreload8(t *testing.T) {
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err)
}
want[1] = Level3{
@ -505,12 +505,12 @@ func TestNestedPreload8(t *testing.T) {
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err)
}
var got []Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -555,7 +555,7 @@ func TestNestedPreload9(t *testing.T) {
DB.DropTableIfExists(&Level2_1{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level0{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).GetError(); err != nil {
t.Error(err)
}
@ -580,7 +580,7 @@ func TestNestedPreload9(t *testing.T) {
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
if err := DB.Create(&want[0]).GetError(); err != nil {
t.Error(err)
}
want[1] = Level3{
@ -597,12 +597,12 @@ func TestNestedPreload9(t *testing.T) {
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
if err := DB.Create(&want[1]).GetError(); err != nil {
t.Error(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).GetError(); err != nil {
t.Error(err)
}
@ -634,7 +634,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -642,7 +642,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
{Value: "ru", LanguageCode: "ru"},
{Value: "en", LanguageCode: "en"},
}}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
@ -650,12 +650,12 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
{Value: "zh", LanguageCode: "zh"},
{Value: "de", LanguageCode: "de"},
}}
if err := DB.Save(&want2).Error; err != nil {
if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err)
}
var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err)
}
@ -664,7 +664,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
}
var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err)
}
@ -673,7 +673,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
}
var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -682,7 +682,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
}
var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -697,7 +697,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
}
if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil {
if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).GetError(); err != nil {
t.Error(err)
}
}
@ -719,7 +719,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -727,7 +727,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
{Value: "ru"},
{Value: "en"},
}}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
@ -735,12 +735,12 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
{Value: "zh"},
{Value: "de"},
}}
if err := DB.Save(&want2).Error; err != nil {
if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err)
}
var got Level2
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err)
}
@ -749,7 +749,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
}
var got2 Level2
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err)
}
@ -758,7 +758,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
}
var got3 []Level2
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -767,7 +767,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
}
var got4 []Level2
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -810,7 +810,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -824,7 +824,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
},
},
}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
@ -838,12 +838,12 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
},
},
}
if err := DB.Save(&want2).Error; err != nil {
if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil {
t.Error(err)
}
@ -852,7 +852,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
}
var got2 Level3
if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil {
t.Error(err)
}
@ -861,7 +861,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
}
var got3 []Level3
if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -870,7 +870,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) {
}
var got4 []Level3
if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil {
t.Error(err)
}
@ -913,7 +913,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
DB.DropTableIfExists("level1_level2")
DB.DropTableIfExists("level2_level3")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -936,12 +936,12 @@ func TestNestedManyToManyPreload(t *testing.T) {
},
}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil {
t.Error(err)
}
@ -949,7 +949,7 @@ func TestNestedManyToManyPreload(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err)
}
}
@ -978,7 +978,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists("level1_level2")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -993,12 +993,12 @@ func TestNestedManyToManyPreload2(t *testing.T) {
},
}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil {
t.Error(err)
}
@ -1006,7 +1006,7 @@ func TestNestedManyToManyPreload2(t *testing.T) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound {
t.Error(err)
}
}
@ -1035,7 +1035,7 @@ func TestNilPointerSlice(t *testing.T) {
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil {
t.Error(err)
}
@ -1045,17 +1045,17 @@ func TestNilPointerSlice(t *testing.T) {
Value: "native",
},
}}
if err := DB.Save(&want).Error; err != nil {
if err := DB.Save(&want).GetError(); err != nil {
t.Error(err)
}
want2 := Level1{Value: "Tom", Level2: nil}
if err := DB.Save(&want2).Error; err != nil {
if err := DB.Save(&want2).GetError(); err != nil {
t.Error(err)
}
var got []Level1
if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil {
if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).GetError(); err != nil {
t.Error(err)
}

View File

@ -31,7 +31,7 @@ func TestFirstAndLast(t *testing.T) {
t.Errorf("Find first record as slice")
}
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error when order with Join table")
}
}
@ -242,15 +242,15 @@ func TestSearchWithEmptyChain(t *testing.T) {
user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
if DB.Where("").Where("").First(&User{}).Error != nil {
if DB.Where("").Where("").First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty strings")
}
if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty struct")
}
if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil {
t.Errorf("Should not raise any error if searching with empty map")
}
}
@ -359,7 +359,7 @@ func TestCount(t *testing.T) {
var count, count1, count2 int64
var users []User
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).GetError(); err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
@ -381,7 +381,7 @@ func TestNot(t *testing.T) {
DB := DB.Where("role = ?", "not")
var users1, users2, users3, users4, users5, users6, users7, users8 []User
if DB.Find(&users1).RowsAffected != 4 {
if DB.Find(&users1).GetRowsAffected() != 4 {
t.Errorf("should find 4 not users")
}
DB.Not(users1[0].Id).Find(&users2)
@ -598,7 +598,7 @@ func TestSelectWithArrayInput(t *testing.T) {
func TestCurrentDatabase(t *testing.T) {
databaseName := DB.CurrentDatabase()
if err := DB.Error; err != nil {
if err := DB.GetError(); err != nil {
t.Errorf("Problem getting current db name: %s", err)
}
if databaseName == "" {

View File

@ -207,7 +207,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *DB):
newDB := scope.NewDB()
f(newDB)
scope.Err(newDB.Error)
scope.Err(newDB.GetError())
case func() error:
scope.Err(f())
case func(s *Scope) error:
@ -215,7 +215,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *DB) error:
newDB := scope.NewDB()
scope.Err(f(newDB))
scope.Err(newDB.Error)
scope.Err(newDB.GetError())
default:
scope.Err(fmt.Errorf("unsupported function %v", name))
}
@ -262,7 +262,7 @@ type tabler interface {
}
type dbTabler interface {
TableName(*DB) string
TableName(Database) string
}
// TableName get table name

View File

@ -444,17 +444,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).GetError())
} else if relationship.Kind == "belongs_to" {
query := toScope.db
var query Database = toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
}
}
scope.Err(query.Find(value).Error)
scope.Err(query.Find(value).GetError())
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
query := toScope.db
var query Database = toScope.db
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
@ -464,16 +464,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
scope.Err(query.Find(value).Error)
scope.Err(query.Find(value).GetError())
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).GetError())
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).GetError())
return scope
}
}
@ -525,7 +525,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).GetError())
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}

View File

@ -6,16 +6,16 @@ import (
)
func NameIn1And2(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"})
return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}).(*gorm.DB)
}
func NameIn2And3(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"})
return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}).(*gorm.DB)
}
func NameIn(names []string) func(d *gorm.DB) *gorm.DB{
return func(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", names)
return d.Where("name in (?)", names).(*gorm.DB)
}
}

View File

@ -7,7 +7,7 @@ import (
)
func TestScannableSlices(t *testing.T) {
if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
if err := DB.AutoMigrate(&RecordWithSlice{}).GetError(); err != nil {
t.Errorf("Should create table with slice values correctly: %s", err)
}
@ -19,13 +19,13 @@ func TestScannableSlices(t *testing.T) {
},
}
if err := DB.Save(&r1).Error; err != nil {
if err := DB.Save(&r1).GetError(); err != nil {
t.Errorf("Should save record with slice values")
}
var r2 RecordWithSlice
if err := DB.Find(&r2).Error; err != nil {
if err := DB.Find(&r2).GetError(); err != nil {
t.Errorf("Should fetch record with slice values")
}

View File

@ -62,7 +62,7 @@ func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError())
}
func (sqlite3) CurrentDatabase(scope *Scope) (name string) {

View File

@ -56,17 +56,17 @@ func TestUpdate(t *testing.T) {
t.Errorf("Product should not be changed to 789")
}
if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil {
t.Error("No error should raise when update with CamelCase")
}
if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil {
t.Error("No error should raise when update_column with CamelCase")
}
var products []Product
DB.Find(&products)
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(products)) {
t.Error("RowsAffected should be correct when do batch update")
}
@ -95,7 +95,7 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
var animals []Animal
DB.Find(&animals)
if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) {
if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(animals)) {
t.Error("RowsAffected should be correct when do batch update")
}
@ -402,8 +402,8 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) {
newAge := int64(100)
user.BillingAddress.Address1 = "second street"
db := DB.Model(user).UpdateColumns(User{Age: newAge})
if db.RowsAffected != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
if db.GetRowsAffected() != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.GetRowsAffected())
}
// Verify that Age now=`newAge`.