Fix Count with complicated Select, close #3826
This commit is contained in:
		
							parent
							
								
									f655041908
								
							
						
					
					
						commit
						1ef1f0bfe4
					
				| @ -93,10 +93,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { | ||||
| 		} | ||||
| 		delete(tx.Statement.Clauses, "SELECT") | ||||
| 	case string: | ||||
| 		fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) | ||||
| 
 | ||||
| 		// normal field names
 | ||||
| 		if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { | ||||
| 		if (strings.Contains(v, " ?") || strings.Contains(v, "(?")) && len(args) > 0 { | ||||
| 			tx.Statement.AddClause(clause.Select{ | ||||
| 				Distinct:   db.Statement.Distinct, | ||||
| 				Expression: clause.Expr{SQL: v, Vars: args}, | ||||
| 			}) | ||||
| 		} else { | ||||
| 			tx.Statement.Selects = []string{v} | ||||
| 
 | ||||
| 			for _, arg := range args { | ||||
| @ -115,11 +117,6 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { | ||||
| 			} | ||||
| 
 | ||||
| 			delete(tx.Statement.Clauses, "SELECT") | ||||
| 		} else { | ||||
| 			tx.Statement.AddClause(clause.Select{ | ||||
| 				Distinct:   db.Statement.Distinct, | ||||
| 				Expression: clause.Expr{SQL: v, Vars: args}, | ||||
| 			}) | ||||
| 		} | ||||
| 	default: | ||||
| 		tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) | ||||
|  | ||||
| @ -355,29 +355,38 @@ func (db *DB) Count(count *int64) (tx *DB) { | ||||
| 		}() | ||||
| 	} | ||||
| 
 | ||||
| 	if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { | ||||
| 		defer func() { | ||||
| 			db.Statement.Clauses["SELECT"] = selectClause | ||||
| 		}() | ||||
| 	} else { | ||||
| 		defer delete(tx.Statement.Clauses, "SELECT") | ||||
| 	} | ||||
| 
 | ||||
| 	if len(tx.Statement.Selects) == 0 { | ||||
| 		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) | ||||
| 		defer delete(tx.Statement.Clauses, "SELECT") | ||||
| 	} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { | ||||
| 		expr := clause.Expr{SQL: "count(1)"} | ||||
| 
 | ||||
| 		if len(tx.Statement.Selects) == 1 { | ||||
| 			dbName := tx.Statement.Selects[0] | ||||
| 			if tx.Statement.Parse(tx.Statement.Model) == nil { | ||||
| 				if f := tx.Statement.Schema.LookUpField(dbName); f != nil { | ||||
| 					dbName = f.DBName | ||||
| 			fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) | ||||
| 			if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { | ||||
| 				if tx.Statement.Parse(tx.Statement.Model) == nil { | ||||
| 					if f := tx.Statement.Schema.LookUpField(dbName); f != nil { | ||||
| 						dbName = f.DBName | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if tx.Statement.Distinct { | ||||
| 				expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} | ||||
| 			} else { | ||||
| 				expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} | ||||
| 				if tx.Statement.Distinct { | ||||
| 					expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} | ||||
| 				} else { | ||||
| 					expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		tx.Statement.AddClause(clause.Select{Expression: expr}) | ||||
| 		defer delete(tx.Statement.Clauses, "SELECT") | ||||
| 	} | ||||
| 
 | ||||
| 	if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { | ||||
| @ -457,11 +466,13 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { | ||||
| 		tx.AddError(ErrModelValueRequired) | ||||
| 	} | ||||
| 
 | ||||
| 	fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) | ||||
| 	tx.Statement.AddClauseIfNotExists(clause.Select{ | ||||
| 		Distinct: tx.Statement.Distinct, | ||||
| 		Columns:  []clause.Column{{Name: column, Raw: len(fields) != 1}}, | ||||
| 	}) | ||||
| 	if len(tx.Statement.Selects) != 1 { | ||||
| 		fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) | ||||
| 		tx.Statement.AddClauseIfNotExists(clause.Select{ | ||||
| 			Distinct: tx.Statement.Distinct, | ||||
| 			Columns:  []clause.Column{{Name: column, Raw: len(fields) != 1}}, | ||||
| 		}) | ||||
| 	} | ||||
| 	tx.Statement.Dest = dest | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| 	return | ||||
|  | ||||
| @ -3,6 +3,8 @@ package tests_test | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| @ -77,4 +79,46 @@ func TestCount(t *testing.T) { | ||||
| 	if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { | ||||
| 		t.Errorf("count with join, got error: %v, count %v", err, count) | ||||
| 	} | ||||
| 
 | ||||
| 	var count6 int64 | ||||
| 	if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( | ||||
| 		"(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other", | ||||
| 	).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { | ||||
| 		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}} | ||||
| 	sort.SliceStable(users, func(i, j int) bool { | ||||
| 		return strings.Compare(users[i].Name, users[j].Name) < 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertEqual(t, users, expects) | ||||
| 
 | ||||
| 	var count7 int64 | ||||
| 	if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( | ||||
| 		"(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", | ||||
| 	).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { | ||||
| 		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} | ||||
| 	sort.SliceStable(users, func(i, j int) bool { | ||||
| 		return strings.Compare(users[i].Name, users[j].Name) < 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertEqual(t, users, expects) | ||||
| 
 | ||||
| 	var count8 int64 | ||||
| 	if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( | ||||
| 		"(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", | ||||
| 	).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { | ||||
| 		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} | ||||
| 	sort.SliceStable(users, func(i, j int) bool { | ||||
| 		return strings.Compare(users[i].Name, users[j].Name) < 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertEqual(t, users, expects) | ||||
| } | ||||
|  | ||||
| @ -677,7 +677,7 @@ func TestPluckWithSelect(t *testing.T) { | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	var userAges []int | ||||
| 	err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1  as user_age").Pluck("user_age", &userAges).Error | ||||
| 	err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("got error when pluck user_age: %v", err) | ||||
| 	} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu