From 1a7ea98ac51af189177e382a7a083b11a2b9b3c2 Mon Sep 17 00:00:00 2001 From: black-06 Date: Thu, 23 Mar 2023 11:19:53 +0800 Subject: [PATCH] fix: count with group (#6157) (#6160) * fix: count with group (#6157) * add an easy-to-understand ut --- finisher_api.go | 2 +- tests/count_test.go | 30 ++++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index d647cf64..0e3c2876 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -491,7 +491,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if tx.RowsAffected != 1 { + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { *count = tx.RowsAffected } diff --git a/tests/count_test.go b/tests/count_test.go index 2199dc6d..b0dfb0b5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -11,6 +11,32 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestCountWithGroup(t *testing.T) { + DB.Create([]Company{ + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_b"}, + {Name: "company_count_group_c"}, + }) + + var count1 int64 + if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count1 != 1 { + t.Errorf("Count with group should be 1, but got count: %v", count1) + } + + var count2 int64 + if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count2 != 2 { + t.Errorf("Count with group should be 2, but got count: %v", count2) + } +} + func TestCount(t *testing.T) { var ( user1 = *GetUser("count-1", Config{}) @@ -141,8 +167,8 @@ func TestCount(t *testing.T) { } DB.Create(sameUsers) - if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) { - t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + t.Fatalf("Count should be 1, but got count: %v err %v", count11, err) } var count12 int64