diff --git a/chainable_api.go b/chainable_api.go index c3a02d20..dca12b08 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) - - // normal field names - if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else { tx.Statement.Selects = []string{v} for _, arg := range args { @@ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") - } else { - tx.Statement.AddClause(clause.Select{ - Distinct: db.Statement.Distinct, - Expression: clause.Expr{SQL: v, Vars: args}, - }) } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) diff --git a/finisher_api.go b/finisher_api.go index d36dc754..98a877f2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -355,29 +355,38 @@ func (db *DB) Count(count *int64) (tx *DB) { }() } + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + db.Statement.Clauses["SELECT"] = selectClause + }() + } else { + defer delete(tx.Statement.Clauses, "SELECT") + } + if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(dbName); f != nil { - dbName = f.DBName + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } } - } - if tx.Statement.Distinct { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} - } else { - expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } } tx.Statement.AddClause(clause.Select{Expression: expr}) - defer delete(tx.Statement.Clauses, "SELECT") } if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { @@ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) - tx.Statement.AddClauseIfNotExists(clause.Select{ - Distinct: tx.Statement.Distinct, - Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, - }) + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } tx.Statement.Dest = dest tx.callbacks.Query().Execute(tx) return diff --git a/tests/count_test.go b/tests/count_test.go index 55fb71e2..ffe675d9 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -3,6 +3,8 @@ package tests_test import ( "fmt" "regexp" + "sort" + "strings" "testing" "gorm.io/gorm" @@ -77,4 +79,46 @@ func TestCount(t *testing.T) { if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } + + var count6 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other", + ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count7 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", + ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count8 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", + ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) } diff --git a/tests/query_test.go b/tests/query_test.go index c4162bdc..af8bbf07 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -677,7 +677,7 @@ func TestPluckWithSelect(t *testing.T) { DB.Create(&users) var userAges []int - err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error if err != nil { t.Fatalf("got error when pluck user_age: %v", err) }