diff --git a/association.go b/association.go index 30ea36b2..89e63134 100644 --- a/association.go +++ b/association.go @@ -23,7 +23,7 @@ func (association *Association) setErr(err error) *Association { func (association *Association) Find(value interface{}) *Association { association.Scope.related(value, association.Column) - return association.setErr(association.Scope.db.Error) + return association.setErr(association.Scope.db.GetError()) } func (association *Association) saveAssociations(values ...interface{}) *Association { @@ -42,7 +42,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa // value has to been saved for many2many if relationship.Kind == "many_to_many" { if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + association.setErr(scope.NewDB().Save(reflectValue.Interface()).GetError()) } } @@ -68,7 +68,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa if relationship.Kind == "many_to_many" { association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) + association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).GetError()) if setFieldBackToValue { reflectValue.Elem().Set(field.Field) @@ -104,9 +104,9 @@ func (association *Association) Append(values ...interface{}) *Association { func (association *Association) Replace(values ...interface{}) *Association { var ( relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field - newDB = scope.NewDB() + scope = association.Scope + field = association.Field.Field + newDB Database = scope.NewDB() ) // Append new values @@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) *Association { for _, foreignKey := range relationship.ForeignDBNames { foreignKeyMap[foreignKey] = nil } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) + association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).GetError()) } } else { // Relations @@ -173,7 +173,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError()) } } return association @@ -182,9 +182,9 @@ func (association *Association) Replace(values ...interface{}) *Association { func (association *Association) Delete(values ...interface{}) *Association { var ( relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field - newDB = scope.NewDB() + scope = association.Scope + field = association.Field.Field + newDB Database = scope.NewDB() ) if len(values) == 0 { @@ -231,12 +231,12 @@ func (association *Association) Delete(values ...interface{}) *Association { // set foreign key to be null modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { + if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.GetError() == nil { + if results.GetRowsAffected() > 0 { scope.updatedAttrsWithValues(foreignKeyMap, false) } } else { - association.setErr(results.Error) + association.setErr(results.GetError()) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // find all relations @@ -254,7 +254,7 @@ func (association *Association) Delete(values ...interface{}) *Association { // set matched relation's foreign key to be null fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).GetError()) } } @@ -298,17 +298,17 @@ func (association *Association) Clear() *Association { func (association *Association) Count() int { var ( - count = 0 + count = 0 relationship = association.Field.Relationship - scope = association.Scope - fieldValue = association.Field.Field.Interface() - newScope = scope.New(fieldValue) + scope = association.Scope + fieldValue = association.Field.Field.Interface() + newScope = scope.New(fieldValue) ) if relationship.Kind == "many_to_many" { relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := scope.DB() + var query Database = scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), @@ -321,7 +321,7 @@ func (association *Association) Count() int { } query.Model(fieldValue).Count(&count) } else if relationship.Kind == "belongs_to" { - query := scope.DB() + var query Database = scope.DB() for idx, primaryKey := range relationship.AssociationForeignDBNames { if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), diff --git a/association_test.go b/association_test.go index ab3abd91..b7d8e7f3 100644 --- a/association_test.go +++ b/association_test.go @@ -15,7 +15,7 @@ func TestBelongsTo(t *testing.T) { MainCategory: Category{Name: "Main Category 1"}, } - if err := DB.Save(&post).Error; err != nil { + if err := DB.Save(&post).GetError(); err != nil { t.Errorf("Got errors when save post", err.Error()) } @@ -183,7 +183,7 @@ func TestHasOne(t *testing.T) { CreditCard: CreditCard{Number: "411111111111"}, } - if err := DB.Save(&user).Error; err != nil { + if err := DB.Save(&user).GetError(); err != nil { t.Errorf("Got errors when save user", err.Error()) } @@ -330,7 +330,7 @@ func TestHasMany(t *testing.T) { Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, } - if err := DB.Save(&post).Error; err != nil { + if err := DB.Save(&post).GetError(); err != nil { t.Errorf("Got errors when save post", err.Error()) } @@ -351,7 +351,7 @@ func TestHasMany(t *testing.T) { } // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { + if DB.First(&Comment{}, "content = ?", "Comment 1").GetError() != nil { t.Errorf("Comment 1 should be saved") } diff --git a/callback_create.go b/callback_create.go index d13a71be..9d939d72 100644 --- a/callback_create.go +++ b/callback_create.go @@ -78,7 +78,8 @@ func Create(scope *Scope) { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { id, err := result.LastInsertId() if scope.Err(err) == nil { - scope.db.RowsAffected, _ = result.RowsAffected() + rowsAffected, _ := result.RowsAffected() + scope.db.SetRowsAffected(rowsAffected) if primaryField != nil && primaryField.IsBlank { scope.Err(scope.SetColumn(primaryField, id)) } @@ -87,13 +88,14 @@ func Create(scope *Scope) { } else { if primaryField == nil { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { - scope.db.RowsAffected, _ = results.RowsAffected() + rowsAffected, _ := results.RowsAffected() + scope.db.SetRowsAffected(rowsAffected) } else { scope.Err(err) } } else { if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { - scope.db.RowsAffected = 1 + scope.db.SetRowsAffected(1) } else { scope.Err(err) } diff --git a/callback_query.go b/callback_query.go index 5473f232..ec42a912 100644 --- a/callback_query.go +++ b/callback_query.go @@ -45,7 +45,7 @@ func Query(scope *Scope) { if !scope.HasError() { rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) - scope.db.RowsAffected = 0 + scope.db.SetRowsAffected(0) if scope.Err(err) != nil { return @@ -54,7 +54,8 @@ func Query(scope *Scope) { columns, _ := rows.Columns() for rows.Next() { - scope.db.RowsAffected++ + rowsAffected := scope.db.GetRowsAffected()+1 + scope.db.SetRowsAffected(rowsAffected) anyRecordFound = true elem := dest diff --git a/callback_shared.go b/callback_shared.go index 547059e3..a4dfb1bc 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -18,7 +18,7 @@ func SaveBeforeAssociations(scope *Scope) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field - scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) + scope.Err(scope.NewDB().Save(value.Addr().Interface()).GetError()) if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] @@ -62,7 +62,7 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) } - scope.Err(newDB.Save(elem).Error) + scope.Err(newDB.Save(elem).GetError()) if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) @@ -83,7 +83,7 @@ func SaveAfterAssociations(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) } - scope.Err(scope.NewDB().Save(elem).Error) + scope.Err(scope.NewDB().Save(elem).GetError()) } } } diff --git a/callbacks_test.go b/callbacks_test.go index a58913d7..0ad91b82 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -108,22 +108,22 @@ func TestRunCallbacks(t *testing.T) { t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) } - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + if DB.Where("Code = ?", "unique_code").First(&p).GetError() == nil { t.Errorf("Can't find a deleted record") } } func TestCallbacksWithErrors(t *testing.T) { p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { + if DB.Save(&p).GetError() == nil { t.Errorf("An error from before create callbacks happened when create with invalid value") } - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + if DB.Where("code = ?", "Invalid").First(&Product{}).GetError() == nil { t.Errorf("Should not save record that have errors") } - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + if DB.Save(&Product{Code: "dont_save", Price: 100}).GetError() == nil { t.Errorf("An error from after create callbacks happened when create with invalid value") } @@ -131,47 +131,47 @@ func TestCallbacksWithErrors(t *testing.T) { DB.Save(&p2) p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { + if DB.Save(&p2).GetError() == nil { t.Errorf("An error from before update callbacks happened when update with invalid value") } - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + if DB.Where("code = ?", "update_callback").First(&Product{}).GetError() != nil { t.Errorf("Record Should not be updated due to errors happened in before update callback") } - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + if DB.Where("code = ?", "dont_update").First(&Product{}).GetError() == nil { t.Errorf("Record Should not be updated due to errors happened in before update callback") } p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { + if DB.Save(&p2).GetError() == nil { t.Errorf("An error from before save callbacks happened when update with invalid value") } p3 := Product{Code: "dont_delete", Price: 100} DB.Save(&p3) - if DB.Delete(&p3).Error == nil { + if DB.Delete(&p3).GetError() == nil { t.Errorf("An error from before delete callbacks happened when delete") } - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + if DB.Where("Code = ?", "dont_delete").First(&p3).GetError() != nil { t.Errorf("An error from before delete callbacks happened") } p4 := Product{Code: "after_save_error", Price: 100} DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + if err := DB.First(&Product{}, "code = ?", "after_save_error").GetError(); err == nil { t.Errorf("Record should be reverted if get an error in after save callback") } p5 := Product{Code: "after_delete_error", Price: 100} DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil { t.Errorf("Record should be found") } DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + if err := DB.First(&Product{}, "code = ?", "after_delete_error").GetError(); err != nil { t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") } } diff --git a/common_dialect.go b/common_dialect.go index 7f08b04f..cfce362d 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -96,7 +96,7 @@ func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string } func (commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).GetError()) } // RawScanInt scans the first column of the first row into the `scan' int pointer. diff --git a/create_test.go b/create_test.go index 4d1f623d..7a312f9c 100644 --- a/create_test.go +++ b/create_test.go @@ -15,7 +15,7 @@ func TestCreate(t *testing.T) { t.Error("User should be new record before create") } - if count := DB.Save(&user).RowsAffected; count != 1 { + if count := DB.Save(&user).GetRowsAffected(); count != 1 { t.Error("There should be one record be affected when create record") } @@ -63,7 +63,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) { } jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error + err := DB.Create(&jt).GetError() if err != nil { t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) } @@ -71,7 +71,7 @@ func TestCreateWithNoGORMPrimayKey(t *testing.T) { func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { + if DB.Save(&animal).GetError() != nil { t.Errorf("No error should happen when create a record without std primary key") } @@ -86,7 +86,7 @@ func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { // Test create with default value not overrided an := Animal{From: "nerdz"} - if DB.Save(&an).Error != nil { + if DB.Save(&an).GetError() != nil { t.Errorf("No error should happen when create an record without std primary key") } diff --git a/customize_column_test.go b/customize_column_test.go index 93bab2e1..1608cd66 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -38,7 +38,7 @@ func TestCustomizeColumn(t *testing.T) { expected := "foo" cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} - if count := DB.Create(&cc).RowsAffected; count != 1 { + if count := DB.Create(&cc).GetRowsAffected(); count != 1 { t.Error("There should be one record be affected when create record") } @@ -61,7 +61,7 @@ func TestCustomizeColumn(t *testing.T) { func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).GetError(); err != nil { t.Errorf("Should not raise error: %s", err) } } @@ -86,17 +86,17 @@ func TestManyToManyWithCustomizedColumn(t *testing.T) { Accounts: []CustomizeAccount{account}, } - if err := DB.Create(&account).Error; err != nil { + if err := DB.Create(&account).GetError(); err != nil { t.Errorf("no error should happen, but got %v", err) } - if err := DB.Create(&person).Error; err != nil { + if err := DB.Create(&person).GetError(); err != nil { t.Errorf("no error should happen, but got %v", err) } var person1 CustomizePerson scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { + if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).GetError(); err != nil { t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) } @@ -131,7 +131,7 @@ func TestOneToOneWithCustomizedColumn(t *testing.T) { DB.Create(&invitation) var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { + if err := DB.Preload("Person").Find(&invitation2, invitation.ID).GetError(); err != nil { t.Errorf("no error should happen, but got %v", err) } @@ -183,12 +183,12 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) { }, } - if err := DB.Create(&discount).Error; err != nil { + if err := DB.Create(&discount).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { + if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } @@ -197,7 +197,7 @@ func TestOneToManyWithCustomizedColumn(t *testing.T) { } var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { + if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } @@ -221,12 +221,12 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) { }, } - if err := DB.Create(&discount).Error; err != nil { + if err := DB.Create(&discount).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { + if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } @@ -235,7 +235,7 @@ func TestHasOneWithPartialCustomizedColumn(t *testing.T) { } var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { + if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } @@ -256,12 +256,12 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { }, } - if err := DB.Create(&discount).Error; err != nil { + if err := DB.Create(&discount).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { + if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } @@ -270,7 +270,7 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { } var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { + if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").GetError(); err != nil { t.Errorf("no error should happen but got %v", err) } diff --git a/ddl_errors_test.go b/ddl_errors_test.go index aca59553..083be93e 100644 --- a/ddl_errors_test.go +++ b/ddl_errors_test.go @@ -18,7 +18,7 @@ func TestDdlErrors(t *testing.T) { }() DB.HasTable("foobarbaz") - if DB.Error == nil { + if DB.GetError() == nil { t.Errorf("Expected operation on closed db to produce an error, but err was nil") } } diff --git a/delete_test.go b/delete_test.go index e0c71660..312f092d 100644 --- a/delete_test.go +++ b/delete_test.go @@ -10,7 +10,7 @@ func TestDelete(t *testing.T) { DB.Save(&user1) DB.Save(&user2) - if err := DB.Delete(&user1).Error; err != nil { + if err := DB.Delete(&user1).GetError(); err != nil { t.Errorf("No error should happen when delete a record, err=%s", err) } @@ -28,13 +28,13 @@ func TestInlineDelete(t *testing.T) { DB.Save(&user1) DB.Save(&user2) - if DB.Delete(&User{}, user1.Id).Error != nil { + if DB.Delete(&User{}, user1.Id).GetError() != nil { t.Errorf("No error should happen when delete a record") } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { t.Errorf("User can't be found after delete") } - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + if err := DB.Delete(&User{}, "name = ?", user2.Name).GetError(); err != nil { t.Errorf("No error should happen when delete a record, err=%s", err) } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { t.Errorf("User can't be found after delete") @@ -53,11 +53,11 @@ func TestSoftDelete(t *testing.T) { DB.Save(&user) DB.Delete(&user) - if DB.First(&User{}, "name = ?", user.Name).Error == nil { + if DB.First(&User{}, "name = ?", user.Name).GetError() == nil { t.Errorf("Can't find a soft deleted record") } - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).GetError(); err != nil { t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) } diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 7be75d99..a7fb41cf 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -22,7 +22,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) { DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + if err := DB.First(&news, "title = ?", "hn_news").GetError(); err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if news.Title != "hn_news" { t.Errorf("embedded struct's value should be scanned correctly") @@ -30,7 +30,7 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) { DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + if err := DB.First(&egNews, "title = ?", "engadget_news").GetError(); err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if egNews.BasePost.Title != "engadget_news" { t.Errorf("embedded struct's value should be scanned correctly") diff --git a/interface.go b/interface.go index 7b02aa66..26c962ac 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,8 @@ package gorm -import "database/sql" +import ( + "database/sql" +) type sqlCommon interface { Exec(query string, args ...interface{}) (sql.Result, error) @@ -17,3 +19,95 @@ type sqlTx interface { Commit() error Rollback() error } + +type Database interface { + Close() error + DB() *sql.DB + New() Database + NewScope(value interface{}) *Scope + CommonDB() sqlCommon + Callback() *callback + SetLogger(l logger) + LogMode(enable bool) Database + SingularTable(enable bool) + + Where(query interface{}, args ...interface{}) Database + Or(query interface{}, args ...interface{}) Database + Not(query interface{}, args ...interface{}) Database + Limit(value interface{}) Database + Offset(value interface{}) Database + Order(value string, reorder ...bool) Database + Select(query interface{}, args ...interface{}) Database + Omit(columns ...string) Database + Group(query string) Database + Having(query string, values ...interface{}) Database + Joins(query string) Database + + //Scopes(funcs ...func(Database) Database) Database + Scopes(funcs ...func(*DB) *DB) *DB + Unscoped() Database + + Attrs(attrs ...interface{}) Database + Assign(attrs ...interface{}) Database + First(out interface{}, where ...interface{}) Database + Last(out interface{}, where ...interface{}) Database + Find(out interface{}, where ...interface{}) Database + Scan(dest interface{}) Database + Row() *sql.Row + Rows() (*sql.Rows, error) + Pluck(column string, value interface{}) Database + Count(value interface{}) Database + + Related(value interface{}, foreignKeys ...string) Database + + FirstOrInit(out interface{}, where ...interface{}) Database + FirstOrCreate(out interface{}, where ...interface{}) Database + Update(attrs ...interface{}) Database + Updates(values interface{}, ignoreProtectedAttrs ...bool) Database + UpdateColumn(attrs ...interface{}) Database + UpdateColumns(values interface{}) Database + Save(value interface{}) Database + Create(value interface{}) Database + Delete(value interface{}, where ...interface{}) Database + + Raw(sql string, values ...interface{}) Database + Exec(sql string, values ...interface{}) Database + Model(value interface{}) Database + Table(name string) Database + Debug() Database + + Begin() Database + Commit() Database + Rollback() Database + + NewRecord(value interface{}) bool + RecordNotFound() bool + + CreateTable(values ...interface{}) Database + DropTable(values ...interface{}) Database + DropTableIfExists(values ...interface{}) Database + HasTable(value interface{}) bool + AutoMigrate(values ...interface{}) Database + ModifyColumn(column string, typ string) Database + DropColumn(column string) Database + AddIndex(indexName string, column ...string) Database + AddUniqueIndex(indexName string, column ...string) Database + RemoveIndex(indexName string) Database + CurrentDatabase() string + AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database + + Association(column string) *Association + Preload(column string, conditions ...interface{}) Database + Set(name string, value interface{}) Database + InstantSet(name string, value interface{}) Database + Get(name string) (value interface{}, ok bool) + SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) + + AddError(err error) error + GetError() error + GetErrors() (errors []error) + + GetRowsAffected() int64 + SetRowsAffected(num int64) +} + diff --git a/join_table_handler.go b/join_table_handler.go index 006701a6..cdb25cbe 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -9,10 +9,10 @@ import ( type JoinTableHandlerInterface interface { Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - Table(db *DB) string - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + Table(db Database) string + Add(handler JoinTableHandlerInterface, db Database, source interface{}, destination interface{}) error + Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error + JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database SourceForeignKeys() []JoinTableForeignKey DestinationForeignKeys() []JoinTableForeignKey } @@ -61,11 +61,11 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } } -func (s JoinTableHandler) Table(db *DB) string { +func (s JoinTableHandler) Table(db Database) string { return s.TableName } -func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { +func (s JoinTableHandler) GetSearchMap(db Database, sources ...interface{}) map[string]interface{} { values := map[string]interface{}{} for _, source := range sources { @@ -85,7 +85,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin return values } -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db Database, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") searchMap := s.GetSearchMap(db, source1, source2) @@ -113,10 +113,10 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 strings.Join(conditions, " AND "), ) - return db.Exec(sql, values...).Error + return db.Exec(sql, values...).GetError() } -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { +func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db Database, sources ...interface{}) error { var ( scope = db.NewScope(nil) conditions []string @@ -128,10 +128,10 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour values = append(values, value) } - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error + return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").GetError() } -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { +func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db Database, source interface{}) Database { var ( scope = db.NewScope(source) tableName = handler.Table(db) @@ -174,7 +174,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). Where(condString, toQueryValues(foreignFieldValues)...) } else { - db.Error = errors.New("wrong source type for join table handler") + db.AddError(errors.New("wrong source type for join table handler")) return db } } diff --git a/join_table_test.go b/join_table_test.go index 70e792ed..6e4ec005 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -22,7 +22,7 @@ type PersonAddress struct { CreatedAt time.Time } -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db gorm.Database, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), "address_id": db.NewScope(associationValue).PrimaryKeyValue(), @@ -30,14 +30,14 @@ func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, f "person_id": foreignValue, "address_id": associationValue, "deleted_at": gorm.Expr("NULL"), - }).FirstOrCreate(&PersonAddress{}).Error + }).FirstOrCreate(&PersonAddress{}).GetError() } -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error +func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db gorm.Database, sources ...interface{}) error { + return db.Delete(&PersonAddress{}).GetError() } -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { +func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db gorm.Database, source interface{}) gorm.Database { table := pa.Table(db) return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } @@ -54,7 +54,7 @@ func TestJoinTable(t *testing.T) { DB.Model(person).Association("Addresses").Delete(address1) - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 1 { t.Errorf("Should found one address") } @@ -62,7 +62,7 @@ func TestJoinTable(t *testing.T) { t.Errorf("Should found one address") } - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).GetRowsAffected() != 2 { t.Errorf("Found two addresses with Unscoped") } diff --git a/main.go b/main.go index ff707f3f..606d6b6f 100644 --- a/main.go +++ b/main.go @@ -90,7 +90,7 @@ func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } -func (s *DB) New() *DB { +func (s *DB) New() Database { clone := s.clone() clone.search = nil clone.Value = nil @@ -120,7 +120,7 @@ func (s *DB) SetLogger(l logger) { s.logger = l } -func (s *DB) LogMode(enable bool) *DB { +func (s *DB) LogMode(enable bool) Database { if enable { s.logMode = 2 } else { @@ -134,47 +134,47 @@ func (s *DB) SingularTable(enable bool) { s.parent.singularTable = enable } -func (s *DB) Where(query interface{}, args ...interface{}) *DB { +func (s *DB) Where(query interface{}, args ...interface{}) Database { return s.clone().search.Where(query, args...).db } -func (s *DB) Or(query interface{}, args ...interface{}) *DB { +func (s *DB) Or(query interface{}, args ...interface{}) Database { return s.clone().search.Or(query, args...).db } -func (s *DB) Not(query interface{}, args ...interface{}) *DB { +func (s *DB) Not(query interface{}, args ...interface{}) Database { return s.clone().search.Not(query, args...).db } -func (s *DB) Limit(value interface{}) *DB { +func (s *DB) Limit(value interface{}) Database { return s.clone().search.Limit(value).db } -func (s *DB) Offset(value interface{}) *DB { +func (s *DB) Offset(value interface{}) Database { return s.clone().search.Offset(value).db } -func (s *DB) Order(value string, reorder ...bool) *DB { +func (s *DB) Order(value string, reorder ...bool) Database { return s.clone().search.Order(value, reorder...).db } -func (s *DB) Select(query interface{}, args ...interface{}) *DB { +func (s *DB) Select(query interface{}, args ...interface{}) Database { return s.clone().search.Select(query, args...).db } -func (s *DB) Omit(columns ...string) *DB { +func (s *DB) Omit(columns ...string) Database { return s.clone().search.Omit(columns...).db } -func (s *DB) Group(query string) *DB { +func (s *DB) Group(query string) Database { return s.clone().search.Group(query).db } -func (s *DB) Having(query string, values ...interface{}) *DB { +func (s *DB) Having(query string, values ...interface{}) Database { return s.clone().search.Having(query, values...).db } -func (s *DB) Joins(query string) *DB { +func (s *DB) Joins(query string) Database { return s.clone().search.Joins(query).db } @@ -185,37 +185,37 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { return s } -func (s *DB) Unscoped() *DB { +func (s *DB) Unscoped() Database { return s.clone().search.unscoped().db } -func (s *DB) Attrs(attrs ...interface{}) *DB { +func (s *DB) Attrs(attrs ...interface{}) Database { return s.clone().search.Attrs(attrs...).db } -func (s *DB) Assign(attrs ...interface{}) *DB { +func (s *DB) Assign(attrs ...interface{}) Database { return s.clone().search.Assign(attrs...).db } -func (s *DB) First(out interface{}, where ...interface{}) *DB { +func (s *DB) First(out interface{}, where ...interface{}) Database { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } -func (s *DB) Last(out interface{}, where ...interface{}) *DB { +func (s *DB) Last(out interface{}, where ...interface{}) Database { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } -func (s *DB) Find(out interface{}, where ...interface{}) *DB { +func (s *DB) Find(out interface{}, where ...interface{}) Database { return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } -func (s *DB) Scan(dest interface{}) *DB { +func (s *DB) Scan(dest interface{}) Database { return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db } @@ -227,21 +227,21 @@ func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } -func (s *DB) Pluck(column string, value interface{}) *DB { +func (s *DB) Pluck(column string, value interface{}) Database { return s.NewScope(s.Value).pluck(column, value).db } -func (s *DB) Count(value interface{}) *DB { +func (s *DB) Count(value interface{}) Database { return s.NewScope(s.Value).count(value).db } -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { +func (s *DB) Related(value interface{}, foreignKeys ...string) Database { return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { +func (s *DB) FirstOrInit(out interface{}, where ...interface{}) Database { c := s.clone() - if result := c.First(out, where...); result.Error != nil { + if result := c.First(out, where...); result.GetError() != nil { if !result.RecordNotFound() { return result } @@ -252,35 +252,35 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { return c } -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { +func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) Database { c := s.clone() - if result := c.First(out, where...); result.Error != nil { + if result := c.First(out, where...); result.GetError() != nil { if !result.RecordNotFound() { return result } - c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) + c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.GetError()) } else if len(c.search.assignAttrs) > 0 { - c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) + c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.GetError()) } return c } -func (s *DB) Update(attrs ...interface{}) *DB { +func (s *DB) Update(attrs ...interface{}) Database { return s.Updates(toSearchableMap(attrs...), true) } -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { +func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) Database { return s.clone().NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callback.updates).db } -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { +func (s *DB) UpdateColumn(attrs ...interface{}) Database { return s.UpdateColumns(toSearchableMap(attrs...)) } -func (s *DB) UpdateColumns(values interface{}) *DB { +func (s *DB) UpdateColumns(values interface{}) Database { return s.clone().NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). @@ -288,7 +288,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { callCallbacks(s.parent.callback.updates).db } -func (s *DB) Save(value interface{}) *DB { +func (s *DB) Save(value interface{}) Database { scope := s.clone().NewScope(value) if scope.PrimaryKeyZero() { return scope.callCallbacks(s.parent.callback.creates).db @@ -296,20 +296,20 @@ func (s *DB) Save(value interface{}) *DB { return scope.callCallbacks(s.parent.callback.updates).db } -func (s *DB) Create(value interface{}) *DB { +func (s *DB) Create(value interface{}) Database { scope := s.clone().NewScope(value) return scope.callCallbacks(s.parent.callback.creates).db } -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { +func (s *DB) Delete(value interface{}, where ...interface{}) Database { return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db } -func (s *DB) Raw(sql string, values ...interface{}) *DB { +func (s *DB) Raw(sql string, values ...interface{}) Database { return s.clone().search.Raw(true).Where(sql, values...).db } -func (s *DB) Exec(sql string, values ...interface{}) *DB { +func (s *DB) Exec(sql string, values ...interface{}) Database { scope := s.clone().NewScope(nil) generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") @@ -317,24 +317,24 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB { return scope.Exec().db } -func (s *DB) Model(value interface{}) *DB { +func (s *DB) Model(value interface{}) Database { c := s.clone() c.Value = value return c } -func (s *DB) Table(name string) *DB { +func (s *DB) Table(name string) Database { clone := s.clone() clone.search.Table(name) clone.Value = nil return clone } -func (s *DB) Debug() *DB { +func (s *DB) Debug() Database { return s.clone().LogMode(true) } -func (s *DB) Begin() *DB { +func (s *DB) Begin() Database { c := s.clone() if db, ok := c.db.(sqlDb); ok { tx, err := db.Begin() @@ -346,7 +346,7 @@ func (s *DB) Begin() *DB { return c } -func (s *DB) Commit() *DB { +func (s *DB) Commit() Database { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Commit()) } else { @@ -355,7 +355,7 @@ func (s *DB) Commit() *DB { return s } -func (s *DB) Rollback() *DB { +func (s *DB) Rollback() Database { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Rollback()) } else { @@ -373,16 +373,18 @@ func (s *DB) RecordNotFound() bool { } // Migrations -func (s *DB) CreateTable(values ...interface{}) *DB { - db := s.clone() +func (s *DB) CreateTable(values ...interface{}) Database { + var db Database + db = s.clone() for _, value := range values { db = db.NewScope(value).createTable().db } return db } -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() +func (s *DB) DropTable(values ...interface{}) Database { + var db Database + db = s.clone() for _, value := range values { if tableName, ok := value.(string); ok { db = db.Table(tableName) @@ -393,8 +395,9 @@ func (s *DB) DropTable(values ...interface{}) *DB { return db } -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() +func (s *DB) DropTableIfExists(values ...interface{}) Database { + var db Database + db = s.clone() for _, value := range values { if tableName, ok := value.(string); ok { db = db.Table(tableName) @@ -409,43 +412,44 @@ func (s *DB) HasTable(value interface{}) bool { scope := s.clone().NewScope(value) tableName := scope.TableName() has := scope.Dialect().HasTable(scope, tableName) - s.AddError(scope.db.Error) + s.AddError(scope.db.GetError()) return has } -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.clone() +func (s *DB) AutoMigrate(values ...interface{}) Database { + var db Database + db = s.clone() for _, value := range values { db = db.NewScope(value).NeedPtr().autoMigrate().db } return db } -func (s *DB) ModifyColumn(column string, typ string) *DB { +func (s *DB) ModifyColumn(column string, typ string) Database { scope := s.clone().NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } -func (s *DB) DropColumn(column string) *DB { +func (s *DB) DropColumn(column string) Database { scope := s.clone().NewScope(s.Value) scope.dropColumn(column) return scope.db } -func (s *DB) AddIndex(indexName string, column ...string) *DB { +func (s *DB) AddIndex(indexName string, column ...string) Database { scope := s.Unscoped().NewScope(s.Value) scope.addIndex(false, indexName, column...) return scope.db } -func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { +func (s *DB) AddUniqueIndex(indexName string, column ...string) Database { scope := s.clone().NewScope(s.Value) scope.addIndex(true, indexName, column...) return scope.db } -func (s *DB) RemoveIndex(indexName string) *DB { +func (s *DB) RemoveIndex(indexName string) Database { scope := s.clone().NewScope(s.Value) scope.removeIndex(indexName) return scope.db @@ -465,7 +469,7 @@ Add foreign key to the given scope Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") */ -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { +func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) Database { scope := s.clone().NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db @@ -492,16 +496,16 @@ func (s *DB) Association(column string) *Association { return &Association{Error: err} } -func (s *DB) Preload(column string, conditions ...interface{}) *DB { +func (s *DB) Preload(column string, conditions ...interface{}) Database { return s.clone().search.Preload(column, conditions...).db } // Set set value by name -func (s *DB) Set(name string, value interface{}) *DB { +func (s *DB) Set(name string, value interface{}) Database { return s.clone().InstantSet(name, value) } -func (s *DB) InstantSet(name string, value interface{}) *DB { +func (s *DB) InstantSet(name string, value interface{}) Database { s.values[name] = value return s } @@ -550,6 +554,18 @@ func (s *DB) AddError(err error) error { return err } +func (s *DB) GetError() error { + return s.Error +} + +func (s *DB) SetRowsAffected(num int64) { + s.RowsAffected = num +} + +func (s *DB) GetRowsAffected() int64 { + return s.RowsAffected +} + func (s *DB) GetErrors() (errors []error) { if errs, ok := s.Error.(errorsInterface); ok { return errs.GetErrors() diff --git a/main_test.go b/main_test.go index 65467d73..261e568a 100644 --- a/main_test.go +++ b/main_test.go @@ -20,7 +20,7 @@ import ( ) var ( - DB gorm.DB + DB gorm.Database t1, t2, t3, t4, t5 time.Time ) @@ -42,7 +42,11 @@ func init() { runMigration() } -func OpenTestConnection() (db gorm.DB, err error) { +func OpenTestConnection() (*gorm.DB, error) { + var ( + db gorm.DB + err error + ) switch os.Getenv("GORM_DIALECT") { case "mysql": // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; @@ -63,7 +67,8 @@ func OpenTestConnection() (db gorm.DB, err error) { fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", "/tmp/gorm.db") } - return + + return &db, err } func TestStringPrimaryKey(t *testing.T) { @@ -74,22 +79,22 @@ func TestStringPrimaryKey(t *testing.T) { DB.AutoMigrate(&UUIDStruct{}) data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { + if err := DB.Save(&data).GetError(); err != nil || data.ID != "uuid" { t.Errorf("string primary key should not be populated") } } func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil { t.Errorf("Should got error with invalid SQL") } - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).GetError() == nil { t.Errorf("Should got error with invalid SQL") } - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).GetError() == nil { t.Errorf("Should got error with invalid SQL") } @@ -99,7 +104,7 @@ func TestExceptionsWithInvalidSql(t *testing.T) { t.Errorf("Should find some users") } - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).GetError() == nil { t.Errorf("Should got error with invalid SQL") } @@ -114,21 +119,21 @@ func TestSetTable(t *testing.T) { DB.Create(getPreparedUser("pluck_user2", "pluck_user")) DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { + if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).GetError(); err != nil { t.Errorf("No errors should happen if set table for pluck", err.Error()) } var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { + if DB.Table("users").Find(&[]User{}).GetError() != nil { t.Errorf("No errors should happen if set table for find") } - if DB.Table("invalid_table").Find(&users).Error == nil { + if DB.Table("invalid_table").Find(&users).GetError() == nil { t.Errorf("Should got error when table is set to an invalid table") } DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { + if DB.Table("deleted_users").CreateTable(&User{}).GetError() != nil { t.Errorf("Create table with specified table") } @@ -168,7 +173,7 @@ func TestHasTable(t *testing.T) { if ok := DB.HasTable(&Foo{}); ok { t.Errorf("Table should not exist, but does") } - if err := DB.CreateTable(&Foo{}).Error; err != nil { + if err := DB.CreateTable(&Foo{}).GetError(); err != nil { t.Errorf("Table should be created") } if ok := DB.HasTable(&Foo{}); !ok { @@ -240,7 +245,7 @@ func TestNullValues(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true}, AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { + }).GetError(); err != nil { t.Errorf("Not error should raise when test null value") } @@ -258,7 +263,7 @@ func TestNullValues(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true}, AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { + }).GetError(); err != nil { t.Errorf("Not error should raise when test null value") } @@ -275,7 +280,7 @@ func TestNullValues(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 100.11, Valid: true}, AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { + }).GetError(); err == nil { t.Errorf("Can't save because of name can't be null") } } @@ -287,7 +292,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) { } var nv2 NullValue - if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { + if err := DB.Where(nv1).FirstOrCreate(&nv2).GetError(); err != nil { t.Errorf("Should not raise any error, but got %v", err) } @@ -295,7 +300,7 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) { t.Errorf("first or create with nullvalues") } - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { + if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).GetError(); err != nil { t.Errorf("Should not raise any error, but got %v", err) } @@ -307,11 +312,11 @@ func TestNullValuesWithFirstOrCreate(t *testing.T) { func TestTransaction(t *testing.T) { tx := DB.Begin() u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { + if err := tx.Save(&u).GetError(); err != nil { t.Errorf("No error should raise") } - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err != nil { t.Errorf("Should find saved record") } @@ -321,23 +326,23 @@ func TestTransaction(t *testing.T) { tx.Rollback() - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { + if err := tx.First(&User{}, "name = ?", "transcation").GetError(); err == nil { t.Errorf("Should not find record after rollback") } tx2 := DB.Begin() u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { + if err := tx2.Save(&u2).GetError(); err != nil { t.Errorf("No error should raise") } - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + if err := tx2.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil { t.Errorf("Should find saved record") } tx2.Commit() - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + if err := DB.First(&User{}, "name = ?", "transcation-2").GetError(); err != nil { t.Errorf("Should be able to find committed record") } } @@ -436,7 +441,7 @@ func TestRaw(t *testing.T) { } DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).GetError() != gorm.RecordNotFound { t.Error("Raw sql to update records") } } @@ -568,14 +573,14 @@ func TestHstore(t *testing.T) { t.Skip() } - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").GetError(); err != nil { fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) } DB.Exec("drop table details") - if err := DB.CreateTable(&Details{}).Error; err != nil { + if err := DB.CreateTable(&Details{}).GetError(); err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } @@ -589,7 +594,7 @@ func TestHstore(t *testing.T) { DB.Save(&d) var d2 Details - if err := DB.First(&d2).Error; err != nil { + if err := DB.First(&d2).GetError(); err != nil { t.Errorf("Got error when tried to fetch details: %+v", err) } @@ -647,7 +652,7 @@ func TestOpenExistingDB(t *testing.T) { } var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { + if db.Where("name = ?", "jnfeinstein").First(&user).GetError() == gorm.RecordNotFound { t.Errorf("Should have found existing record") } } diff --git a/migration_test.go b/migration_test.go index 0411872e..284ed5f9 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,7 +7,7 @@ import ( ) func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { + if err := DB.DropTableIfExists(&User{}).GetError(); err != nil { fmt.Printf("Got error when try to delete table users, %+v\n", err) } @@ -20,13 +20,13 @@ func runMigration() { DB.DropTable(value) } - if err := DB.AutoMigrate(values...).Error; err != nil { + if err := DB.AutoMigrate(values...).GetError(); err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } } func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { + if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").GetError(); err != nil { t.Errorf("Got error when tried to create index: %+v", err) } @@ -35,7 +35,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email should have index idx_email_email") } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").GetError(); err != nil { t.Errorf("Got error when tried to remove index: %+v", err) } @@ -43,7 +43,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email's index idx_email_email should be deleted") } - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil { t.Errorf("Got error when tried to create index: %+v", err) } @@ -51,7 +51,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email should have index idx_email_email_and_user_id") } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil { t.Errorf("Got error when tried to remove index: %+v", err) } @@ -59,7 +59,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { + if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").GetError(); err != nil { t.Errorf("Got error when tried to create index: %+v", err) } @@ -67,7 +67,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email should have index idx_email_email_and_user_id") } - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() == nil { t.Errorf("Should get to create duplicate record when having unique index") } @@ -81,7 +81,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") } - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { + if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").GetError(); err != nil { t.Errorf("Got error when tried to remove index: %+v", err) } @@ -89,7 +89,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { + if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).GetError() != nil { t.Errorf("Should be able to create duplicated emails after remove unique index") } } @@ -110,7 +110,7 @@ func (b BigEmail) TableName() string { func TestAutoMigration(t *testing.T) { DB.AutoMigrate(&Address{}) - if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { + if err := DB.Table("emails").AutoMigrate(&BigEmail{}).GetError(); err != nil { t.Errorf("Auto Migrate should not raise any error") } diff --git a/model_struct.go b/model_struct.go index d80165c8..860f1a42 100644 --- a/model_struct.go +++ b/model_struct.go @@ -14,7 +14,7 @@ import ( "github.com/jinzhu/inflection" ) -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { +var DefaultTableNameHandler = func(db Database, defaultTableName string) string { return defaultTableName } @@ -48,7 +48,7 @@ type ModelStruct struct { defaultTableName string } -func (s *ModelStruct) TableName(db *DB) string { +func (s *ModelStruct) TableName(db Database) string { return DefaultTableNameHandler(db, s.defaultTableName) } diff --git a/pointer_test.go b/pointer_test.go index b47717f3..6db1da17 100644 --- a/pointer_test.go +++ b/pointer_test.go @@ -20,12 +20,12 @@ func TestPointerFields(t *testing.T) { var name = "pointer struct 1" var num = 100 pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { + if DB.Create(&pointerStruct).GetError() != nil { t.Errorf("Failed to save pointer struct") } var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { + if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).GetError(); err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { t.Errorf("Failed to query saved pointer struct") } @@ -38,47 +38,47 @@ func TestPointerFields(t *testing.T) { } var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { + if err := DB.Create(&nilPointerStruct).GetError(); err != nil { t.Errorf("Failed to save nil pointer struct", err) } var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil { t.Errorf("Failed to query saved nil pointer struct", err) } var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { + if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).GetError(); err != nil { t.Errorf("Failed to query saved nil pointer struct", err) } var partialNilPointerStruct1 = PointerStruct{Num: &num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { + if err := DB.Create(&partialNilPointerStruct1).GetError(); err != nil { t.Errorf("Failed to save partial nil pointer struct", err) } var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { + if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || *pointerStruct3.Num != num { t.Errorf("Failed to query saved partial nil pointer struct", err) } var normalStruct3 NormalStruct - if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { + if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).GetError(); err != nil || normalStruct3.Num != num { t.Errorf("Failed to query saved partial pointer struct", err) } var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { + if err := DB.Create(&partialNilPointerStruct2).GetError(); err != nil { t.Errorf("Failed to save partial nil pointer struct", err) } var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { + if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || *pointerStruct4.Name != name { t.Errorf("Failed to query saved partial nil pointer struct", err) } var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { + if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).GetError(); err != nil || normalStruct4.Name != name { t.Errorf("Failed to query saved partial pointer struct", err) } } diff --git a/postgres.go b/postgres.go index 3b083dfa..af9fded0 100644 --- a/postgres.go +++ b/postgres.go @@ -94,7 +94,7 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b } func (postgres) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError()) } func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/preload.go b/preload.go index d12995f3..cb55b153 100644 --- a/preload.go +++ b/preload.go @@ -115,7 +115,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError()) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { @@ -146,7 +146,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) } results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError()) resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { @@ -176,7 +176,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).GetError()) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { diff --git a/preload_test.go b/preload_test.go index 29ea39a7..5bd386bb 100644 --- a/preload_test.go +++ b/preload_test.go @@ -115,17 +115,17 @@ func TestNestedPreload1(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { + if err := DB.Create(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil { t.Error(err) } @@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").GetError(); err != gorm.RecordNotFound { t.Error(err) } } @@ -159,7 +159,7 @@ func TestNestedPreload2(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -178,12 +178,12 @@ func TestNestedPreload2(t *testing.T) { }, }, } - if err := DB.Create(&want).Error; err != nil { + if err := DB.Create(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil { t.Error(err) } @@ -213,7 +213,7 @@ func TestNestedPreload3(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -223,12 +223,12 @@ func TestNestedPreload3(t *testing.T) { {Level1: Level1{Value: "value2"}}, }, } - if err := DB.Create(&want).Error; err != nil { + if err := DB.Create(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil { t.Error(err) } @@ -258,7 +258,7 @@ func TestNestedPreload4(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -270,12 +270,12 @@ func TestNestedPreload4(t *testing.T) { }, }, } - if err := DB.Create(&want).Error; err != nil { + if err := DB.Create(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil { t.Error(err) } @@ -306,22 +306,22 @@ func TestNestedPreload5(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { + if err := DB.Create(&want[0]).GetError(); err != nil { t.Error(err) } want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { + if err := DB.Create(&want[1]).GetError(); err != nil { t.Error(err) } var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).GetError(); err != nil { t.Error(err) } @@ -351,7 +351,7 @@ func TestNestedPreload6(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -371,7 +371,7 @@ func TestNestedPreload6(t *testing.T) { }, }, } - if err := DB.Create(&want[0]).Error; err != nil { + if err := DB.Create(&want[0]).GetError(); err != nil { t.Error(err) } @@ -390,12 +390,12 @@ func TestNestedPreload6(t *testing.T) { }, }, } - if err := DB.Create(&want[1]).Error; err != nil { + if err := DB.Create(&want[1]).GetError(); err != nil { t.Error(err) } var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level2s.Level1s").Find(&got).GetError(); err != nil { t.Error(err) } @@ -425,7 +425,7 @@ func TestNestedPreload7(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -436,7 +436,7 @@ func TestNestedPreload7(t *testing.T) { {Level1: Level1{Value: "value2"}}, }, } - if err := DB.Create(&want[0]).Error; err != nil { + if err := DB.Create(&want[0]).GetError(); err != nil { t.Error(err) } @@ -446,12 +446,12 @@ func TestNestedPreload7(t *testing.T) { {Level1: Level1{Value: "value4"}}, }, } - if err := DB.Create(&want[1]).Error; err != nil { + if err := DB.Create(&want[1]).GetError(); err != nil { t.Error(err) } var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + if err := DB.Preload("Level2s.Level1").Find(&got).GetError(); err != nil { t.Error(err) } @@ -481,7 +481,7 @@ func TestNestedPreload8(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -494,7 +494,7 @@ func TestNestedPreload8(t *testing.T) { }, }, } - if err := DB.Create(&want[0]).Error; err != nil { + if err := DB.Create(&want[0]).GetError(); err != nil { t.Error(err) } want[1] = Level3{ @@ -505,12 +505,12 @@ func TestNestedPreload8(t *testing.T) { }, }, } - if err := DB.Create(&want[1]).Error; err != nil { + if err := DB.Create(&want[1]).GetError(); err != nil { t.Error(err) } var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got).GetError(); err != nil { t.Error(err) } @@ -555,7 +555,7 @@ func TestNestedPreload9(t *testing.T) { DB.DropTableIfExists(&Level2_1{}) DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).GetError(); err != nil { t.Error(err) } @@ -580,7 +580,7 @@ func TestNestedPreload9(t *testing.T) { }, }, } - if err := DB.Create(&want[0]).Error; err != nil { + if err := DB.Create(&want[0]).GetError(); err != nil { t.Error(err) } want[1] = Level3{ @@ -597,12 +597,12 @@ func TestNestedPreload9(t *testing.T) { }, }, } - if err := DB.Create(&want[1]).Error; err != nil { + if err := DB.Create(&want[1]).GetError(); err != nil { t.Error(err) } var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).GetError(); err != nil { t.Error(err) } @@ -634,7 +634,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -642,7 +642,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { {Value: "ru", LanguageCode: "ru"}, {Value: "en", LanguageCode: "en"}, }} - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } @@ -650,12 +650,12 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { {Value: "zh", LanguageCode: "zh"}, {Value: "de", LanguageCode: "de"}, }} - if err := DB.Save(&want2).Error; err != nil { + if err := DB.Save(&want2).GetError(); err != nil { t.Error(err) } var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil { t.Error(err) } @@ -664,7 +664,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil { t.Error(err) } @@ -673,7 +673,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -682,7 +682,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -697,7 +697,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) } - if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { + if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).GetError(); err != nil { t.Error(err) } } @@ -719,7 +719,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -727,7 +727,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { {Value: "ru"}, {Value: "en"}, }} - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } @@ -735,12 +735,12 @@ func TestManyToManyPreloadForPointer(t *testing.T) { {Value: "zh"}, {Value: "de"}, }} - if err := DB.Save(&want2).Error; err != nil { + if err := DB.Save(&want2).GetError(); err != nil { t.Error(err) } var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil { t.Error(err) } @@ -749,7 +749,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil { t.Error(err) } @@ -758,7 +758,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -767,7 +767,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -810,7 +810,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists("levels") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -824,7 +824,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { }, }, } - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } @@ -838,12 +838,12 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { }, }, } - if err := DB.Save(&want2).Error; err != nil { + if err := DB.Save(&want2).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").GetError(); err != nil { t.Error(err) } @@ -852,7 +852,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { } var got2 Level3 - if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").GetError(); err != nil { t.Error(err) } @@ -861,7 +861,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { } var got3 []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -870,7 +870,7 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { } var got4 []Level3 - if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).GetError(); err != nil { t.Error(err) } @@ -913,7 +913,7 @@ func TestNestedManyToManyPreload(t *testing.T) { DB.DropTableIfExists("level1_level2") DB.DropTableIfExists("level2_level3") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -936,12 +936,12 @@ func TestNestedManyToManyPreload(t *testing.T) { }, } - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil { t.Error(err) } @@ -949,7 +949,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound { t.Error(err) } } @@ -978,7 +978,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { DB.DropTableIfExists(&Level3{}) DB.DropTableIfExists("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -993,12 +993,12 @@ func TestNestedManyToManyPreload2(t *testing.T) { }, } - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").GetError(); err != nil { t.Error(err) } @@ -1006,7 +1006,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").GetError(); err != gorm.RecordNotFound { t.Error(err) } } @@ -1035,7 +1035,7 @@ func TestNilPointerSlice(t *testing.T) { DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).GetError(); err != nil { t.Error(err) } @@ -1045,17 +1045,17 @@ func TestNilPointerSlice(t *testing.T) { Value: "native", }, }} - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(&want).GetError(); err != nil { t.Error(err) } want2 := Level1{Value: "Tom", Level2: nil} - if err := DB.Save(&want2).Error; err != nil { + if err := DB.Save(&want2).GetError(); err != nil { t.Error(err) } var got []Level1 - if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { + if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).GetError(); err != nil { t.Error(err) } diff --git a/query_test.go b/query_test.go index 274e8e9b..ac6f3937 100644 --- a/query_test.go +++ b/query_test.go @@ -31,7 +31,7 @@ func TestFirstAndLast(t *testing.T) { t.Errorf("Find first record as slice") } - if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil { + if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).GetError() != nil { t.Errorf("Should not raise any error when order with Join table") } } @@ -242,15 +242,15 @@ func TestSearchWithEmptyChain(t *testing.T) { user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} DB.Save(&user1).Save(&user2).Save(&user3) - if DB.Where("").Where("").First(&User{}).Error != nil { + if DB.Where("").Where("").First(&User{}).GetError() != nil { t.Errorf("Should not raise any error if searching with empty strings") } - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { + if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil { t.Errorf("Should not raise any error if searching with empty struct") } - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { + if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).GetError() != nil { t.Errorf("Should not raise any error if searching with empty map") } } @@ -359,7 +359,7 @@ func TestCount(t *testing.T) { var count, count1, count2 int64 var users []User - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).GetError(); err != nil { t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) } @@ -381,7 +381,7 @@ func TestNot(t *testing.T) { DB := DB.Where("role = ?", "not") var users1, users2, users3, users4, users5, users6, users7, users8 []User - if DB.Find(&users1).RowsAffected != 4 { + if DB.Find(&users1).GetRowsAffected() != 4 { t.Errorf("should find 4 not users") } DB.Not(users1[0].Id).Find(&users2) @@ -598,7 +598,7 @@ func TestSelectWithArrayInput(t *testing.T) { func TestCurrentDatabase(t *testing.T) { databaseName := DB.CurrentDatabase() - if err := DB.Error; err != nil { + if err := DB.GetError(); err != nil { t.Errorf("Problem getting current db name: %s", err) } if databaseName == "" { diff --git a/scope.go b/scope.go index a11d4ec4..ad12c5f2 100644 --- a/scope.go +++ b/scope.go @@ -207,7 +207,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) { case func(s *DB): newDB := scope.NewDB() f(newDB) - scope.Err(newDB.Error) + scope.Err(newDB.GetError()) case func() error: scope.Err(f()) case func(s *Scope) error: @@ -215,7 +215,7 @@ func (scope *Scope) CallMethod(name string, checkError bool) { case func(s *DB) error: newDB := scope.NewDB() scope.Err(f(newDB)) - scope.Err(newDB.Error) + scope.Err(newDB.GetError()) default: scope.Err(fmt.Errorf("unsupported function %v", name)) } @@ -262,7 +262,7 @@ type tabler interface { } type dbTabler interface { - TableName(*DB) string + TableName(Database) string } // TableName get table name diff --git a/scope_private.go b/scope_private.go index a154c426..daf718ae 100644 --- a/scope_private.go +++ b/scope_private.go @@ -444,17 +444,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).GetError()) } else if relationship.Kind == "belongs_to" { - query := toScope.db + var query Database = toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(foreignKey); ok { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) } } - scope.Err(query.Find(value).Error) + scope.Err(query.Find(value).GetError()) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db + var query Database = toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) @@ -464,16 +464,16 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - scope.Err(query.Find(value).Error) + scope.Err(query.Find(value).GetError()) } } else { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).GetError()) } return scope } else if toField != nil { sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).GetError()) return scope } } @@ -525,7 +525,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).GetError()) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } diff --git a/scope_test.go b/scope_test.go index 42458995..2e3d03c4 100644 --- a/scope_test.go +++ b/scope_test.go @@ -6,16 +6,16 @@ import ( ) func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}).(*gorm.DB) } func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}).(*gorm.DB) } -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { +func NameIn(names []string) func(d *gorm.DB) *gorm.DB{ return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) + return d.Where("name in (?)", names).(*gorm.DB) } } diff --git a/slice_test.go b/slice_test.go index 21410548..c4eb9e29 100644 --- a/slice_test.go +++ b/slice_test.go @@ -7,7 +7,7 @@ import ( ) func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { + if err := DB.AutoMigrate(&RecordWithSlice{}).GetError(); err != nil { t.Errorf("Should create table with slice values correctly: %s", err) } @@ -19,13 +19,13 @@ func TestScannableSlices(t *testing.T) { }, } - if err := DB.Save(&r1).Error; err != nil { + if err := DB.Save(&r1).GetError(); err != nil { t.Errorf("Should save record with slice values") } var r2 RecordWithSlice - if err := DB.Find(&r2).Error; err != nil { + if err := DB.Find(&r2).GetError(); err != nil { t.Errorf("Should fetch record with slice values") } diff --git a/sqlite3.go b/sqlite3.go index d052d2c1..3071206f 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -62,7 +62,7 @@ func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool } func (sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).GetError()) } func (sqlite3) CurrentDatabase(scope *Scope) (name string) { diff --git a/update_test.go b/update_test.go index 75877488..1304c9f5 100644 --- a/update_test.go +++ b/update_test.go @@ -56,17 +56,17 @@ func TestUpdate(t *testing.T) { t.Errorf("Product should not be changed to 789") } - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil { t.Error("No error should raise when update with CamelCase") } - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { + if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).GetError() != nil { t.Error("No error should raise when update_column with CamelCase") } var products []Product DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { + if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(products)) { t.Error("RowsAffected should be correct when do batch update") } @@ -95,7 +95,7 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { var animals []Animal DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).GetRowsAffected(); count != int64(len(animals)) { t.Error("RowsAffected should be correct when do batch update") } @@ -402,8 +402,8 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) { newAge := int64(100) user.BillingAddress.Address1 = "second street" db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) + if db.GetRowsAffected() != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.GetRowsAffected()) } // Verify that Age now=`newAge`.