Add Count tests
This commit is contained in:
		
							parent
							
								
									1c39ac921b
								
							
						
					
					
						commit
						cbc4a81140
					
				| @ -247,11 +247,12 @@ func (association *Association) Clear() error { | ||||
| 	return association.Replace() | ||||
| } | ||||
| 
 | ||||
| func (association *Association) Count() (count int) { | ||||
| func (association *Association) Count() (count int64) { | ||||
| 	if association.Error == nil { | ||||
| 		var ( | ||||
| 			tx    = association.DB | ||||
| 			conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) | ||||
| 			conds      = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) | ||||
| 			modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() | ||||
| 			tx         = association.DB.Model(modelValue) | ||||
| 		) | ||||
| 
 | ||||
| 		if association.Relationship.JoinTable != nil { | ||||
|  | ||||
| @ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor { | ||||
| 
 | ||||
| func (p *processor) Execute(db *DB) { | ||||
| 	curTime := time.Now() | ||||
| 	db.RowsAffected = 0 | ||||
| 	if stmt := db.Statement; stmt != nil { | ||||
| 		if stmt.Model == nil { | ||||
| 			stmt.Model = stmt.Dest | ||||
| @ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) { | ||||
| 		}, db.Error) | ||||
| 
 | ||||
| 		stmt.reinit() | ||||
| 		db.Config.statementPool.Put(stmt) | ||||
| 		// db.Config.statementPool.Put(stmt)
 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -21,6 +21,11 @@ func Query(db *gorm.DB) { | ||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 						Name: f.DBName, | ||||
| 					}) | ||||
| 				} else { | ||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||
| 						Name: name, | ||||
| 						Raw:  true, | ||||
| 					}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| @ -85,7 +90,7 @@ func Query(db *gorm.DB) { | ||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.AddClauseIfNotExists(clauseSelect) | ||||
| 		db.Statement.AddClause(clauseSelect) | ||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | ||||
| 			} | ||||
| 			*dest = append(*dest, v) | ||||
| 		} | ||||
| 	case *int, *int64, *uint, *uint64: | ||||
| 		for rows.Next() { | ||||
| 			db.RowsAffected++ | ||||
| 			rows.Scan(dest) | ||||
| 		} | ||||
| 	default: | ||||
| 		switch db.Statement.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
|  | ||||
| @ -41,8 +41,5 @@ func (values Values) Build(builder Builder) { | ||||
| // MergeClause merge values clauses
 | ||||
| func (values Values) MergeClause(clause *Clause) { | ||||
| 	clause.Name = "" | ||||
| 	if v, ok := clause.Expression.(Values); ok { | ||||
| 		values.Values = append(v.Values, values.Values...) | ||||
| 	} | ||||
| 	clause.Expression = values | ||||
| } | ||||
|  | ||||
| @ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Count(value interface{}) (tx *DB) { | ||||
| func (db *DB) Count(count *int64) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	if len(tx.Statement.Selects) == 0 { | ||||
| 		tx.Statement.Selects = []string{"count(1)"} | ||||
| 	} | ||||
| 	if tx.Statement.Model == nil { | ||||
| 		tx.Statement.Model = tx.Statement.Dest | ||||
| 	} | ||||
| 	tx.Statement.Dest = count | ||||
| 	tx.callbacks.Query().Execute(tx) | ||||
| 	if db.RowsAffected != 1 { | ||||
| 		*count = db.RowsAffected | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										54
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								statement.go
									
									
									
									
									
								
							| @ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | ||||
| 	case clause.Table: | ||||
| 		if v.Name == clause.CurrentTable { | ||||
| 			stmt.DB.Dialector.QuoteTo(writer, stmt.Table) | ||||
| 		} else if v.Raw { | ||||
| 			writer.WriteString(v.Name) | ||||
| 		} else { | ||||
| 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | ||||
| 		} | ||||
| @ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | ||||
| 			if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { | ||||
| 				stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||
| 			} | ||||
| 		} else if v.Raw { | ||||
| 			writer.WriteString(v.Name) | ||||
| 		} else { | ||||
| 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | ||||
| 		} | ||||
| @ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) { | ||||
| } | ||||
| 
 | ||||
| func (stmt *Statement) reinit() { | ||||
| 	stmt.Table = "" | ||||
| 	stmt.Model = nil | ||||
| 	stmt.Selects = nil | ||||
| 	stmt.Omits = nil | ||||
| 	stmt.ConnPool = stmt.DB.Config.ConnPool | ||||
| 	stmt.Schema = nil | ||||
| 	stmt.Context = context.Background() | ||||
| 	stmt.RaiseErrorOnNotFound = false | ||||
| 	// stmt.Table = ""
 | ||||
| 	// stmt.Model = nil
 | ||||
| 	// stmt.Selects = nil
 | ||||
| 	// stmt.Omits = nil
 | ||||
| 	// stmt.ConnPool = stmt.DB.Config.ConnPool
 | ||||
| 	// stmt.Context = context.Background()
 | ||||
| 	// stmt.RaiseErrorOnNotFound = false
 | ||||
| 
 | ||||
| 	// for k := range stmt.Clauses {
 | ||||
| 	// 	delete(stmt.Clauses, k)
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	// for k := range stmt.Joins {
 | ||||
| 	// 	delete(stmt.Joins, k)
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	// for k := range stmt.Preloads {
 | ||||
| 	// 	delete(stmt.Preloads, k)
 | ||||
| 	// }
 | ||||
| 
 | ||||
| 	// stmt.Settings.Range(func(k, _ interface{}) bool {
 | ||||
| 	// 	stmt.Settings.Delete(k)
 | ||||
| 	// 	return true
 | ||||
| 	// })
 | ||||
| 
 | ||||
| 	stmt.Schema = nil | ||||
| 	stmt.SQL.Reset() | ||||
| 	stmt.Vars = nil | ||||
| 	stmt.NamedVars = nil | ||||
| 
 | ||||
| 	for k := range stmt.Clauses { | ||||
| 		delete(stmt.Clauses, k) | ||||
| 	} | ||||
| 
 | ||||
| 	for k := range stmt.Joins { | ||||
| 		delete(stmt.Joins, k) | ||||
| 	} | ||||
| 
 | ||||
| 	for k := range stmt.Preloads { | ||||
| 		delete(stmt.Preloads, k) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt.Settings.Range(func(k, _ interface{}) bool { | ||||
| 		stmt.Settings.Delete(k) | ||||
| 		return true | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 	user2.Manager = &User{} | ||||
| 	DB.Model(&user2).Association("Manager").Find(user2.Manager) | ||||
| 	CheckUser(t, user2, user) | ||||
| 
 | ||||
| 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | ||||
| 		t.Errorf("invalid company count, got %v", count) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { | ||||
| 		t.Errorf("invalid manager count, got %v", count) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										42
									
								
								tests/count_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/count_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestCount(t *testing.T) { | ||||
| 	var ( | ||||
| 		user1                 = *GetUser("count-1", Config{}) | ||||
| 		user2                 = *GetUser("count-2", Config{}) | ||||
| 		user3                 = *GetUser("count-3", Config{}) | ||||
| 		users                 []User | ||||
| 		count, count1, count2 int64 | ||||
| 	) | ||||
| 
 | ||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||
| 
 | ||||
| 	if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { | ||||
| 		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	if count != int64(len(users)) { | ||||
| 		t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) | ||||
| 	if count1 != 1 || count2 != 3 { | ||||
| 		t.Errorf("multiple count in chain should works") | ||||
| 	} | ||||
| 
 | ||||
| 	var count3 int64 | ||||
| 	if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | ||||
| 		t.Errorf("No error should happen when count with group, but got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count3 != 2 { | ||||
| 		t.Errorf("Should get correct count for count with group, but got %v", count3) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu