diff --git a/finisher_api.go b/finisher_api.go index b443f4b5..6d961811 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -284,8 +284,8 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) - if db.RowsAffected != 1 { - *count = db.RowsAffected + if tx.RowsAffected != 1 { + *count = tx.RowsAffected } return } diff --git a/tests/count_test.go b/tests/count_test.go index 63238089..0662ae5c 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -31,6 +32,13 @@ func TestCount(t *testing.T) { t.Errorf("multiple count in chain should works") } + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{WithConditions: true}) + tx.Count(&count1) + tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("count after new session should works") + } + var count3 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { t.Errorf("Error happened when count with group, but got %v", err)