From 801a271d0760865d714a4e3532a273dbf2998676 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 12 Jan 2016 12:16:22 +0800 Subject: [PATCH] Fix Association Count --- association.go | 23 ++++++++++++--------- association_test.go | 50 +++++++++++++++++++++++++++++++++++---------- join_table_test.go | 2 +- main_test.go | 5 +++-- structs_test.go | 6 ++++-- 5 files changed, 60 insertions(+), 26 deletions(-) diff --git a/association.go b/association.go index 5cc32e1c..4660cb27 100644 --- a/association.go +++ b/association.go @@ -297,13 +297,16 @@ func (association *Association) Clear() *Association { } func (association *Association) Count() int { - count := -1 - relationship := association.Field.Relationship - scope := association.Scope - newScope := scope.New(association.Field.Field.Interface()) + var ( + count = 0 + relationship = association.Field.Relationship + scope = association.Scope + fieldValue = association.Field.Field.Interface() + newScope = scope.New(fieldValue) + ) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Model(fieldValue).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { @@ -316,16 +319,16 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } else if relationship.Kind == "belongs_to" { query := scope.DB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + for idx, primaryKey := range relationship.AssociationForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), field.Field.Interface()) } } - query.Table(newScope.TableName()).Count(&count) + query.Model(fieldValue).Count(&count) } return count diff --git a/association_test.go b/association_test.go index 0e61d51f..29a65292 100644 --- a/association_test.go +++ b/association_test.go @@ -19,7 +19,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Got errors when save post", err.Error()) } - if post.Category.Id == 0 || post.MainCategory.Id == 0 { + if post.Category.ID == 0 || post.MainCategory.ID == 0 { t.Errorf("Category's primary key should be updated") } @@ -46,11 +46,11 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Query belongs to relations with Related") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } - if DB.Model(&post).Association("MainCategory").Count() == 1 { + if DB.Model(&post).Association("MainCategory").Count() != 1 { t.Errorf("Post's main category count should be 1") } @@ -60,7 +60,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Append(&category2) - if category2.Id == 0 { + if category2.ID == 0 { t.Errorf("Category should has ID when created with Append") } @@ -71,7 +71,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Append") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -81,7 +81,7 @@ func TestBelongsTo(t *testing.T) { } DB.Model(&post).Association("Category").Replace(&category3) - if category3.Id == 0 { + if category3.ID == 0 { t.Errorf("Category should has ID when created with Replace") } @@ -91,7 +91,7 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be updated with Replace") } - if DB.Model(&post).Association("Category").Count() == 1 { + if DB.Model(&post).Association("Category").Count() != 1 { t.Errorf("Post's category count should be 1") } @@ -117,8 +117,8 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Category should be deleted with Delete") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Delete") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Delete, but got %v", count) } // Clear @@ -144,8 +144,36 @@ func TestBelongsTo(t *testing.T) { t.Errorf("Should not find any category after Clear") } - if DB.Model(&post).Association("Category").Count() == 0 { - t.Errorf("Post's category count should be 0 after Clear") + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after Clear, but got %v", count) + } + + // Check Association mode with soft delete + category6 := Category{ + Name: "Category 6", + } + DB.Model(&post).Association("Category").Append(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 after Append, but got %v", count) + } + + DB.Delete(&category6) + + if count := DB.Model(&post).Association("Category").Count(); count != 0 { + t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) + } + + if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { + t.Errorf("Post's category is not findable after Delete") + } + + if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { + t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) + } + + if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { + t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) } } diff --git a/join_table_test.go b/join_table_test.go index 3353aee2..70e792ed 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -39,7 +39,7 @@ func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) - return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) + return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { diff --git a/main_test.go b/main_test.go index e6c703e4..65467d73 100644 --- a/main_test.go +++ b/main_test.go @@ -33,8 +33,9 @@ func init() { // DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // DB.SetLogger(log.New(os.Stdout, "\r\n", 0)) - // DB.LogMode(true) - DB.LogMode(false) + if os.Getenv("DEBUG") == "true" { + DB.LogMode(true) + } DB.DB().SetMaxIdleConns(10) diff --git a/structs_test.go b/structs_test.go index a3dfa8b1..8f529952 100644 --- a/structs_test.go +++ b/structs_test.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" + "github.com/jinzhu/gorm" + "reflect" "time" ) @@ -154,12 +156,12 @@ type Post struct { } type Category struct { - Id int64 + gorm.Model Name string } type Comment struct { - Id int64 + gorm.Model PostId int64 Content string Post Post