Add Count tests
This commit is contained in:
		
							parent
							
								
									1c39ac921b
								
							
						
					
					
						commit
						cbc4a81140
					
				| @ -247,11 +247,12 @@ func (association *Association) Clear() error { | |||||||
| 	return association.Replace() | 	return association.Replace() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (association *Association) Count() (count int) { | func (association *Association) Count() (count int64) { | ||||||
| 	if association.Error == nil { | 	if association.Error == nil { | ||||||
| 		var ( | 		var ( | ||||||
| 			tx    = association.DB | 			conds      = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) | ||||||
| 			conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) | 			modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() | ||||||
|  | 			tx         = association.DB.Model(modelValue) | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		if association.Relationship.JoinTable != nil { | 		if association.Relationship.JoinTable != nil { | ||||||
|  | |||||||
| @ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor { | |||||||
| 
 | 
 | ||||||
| func (p *processor) Execute(db *DB) { | func (p *processor) Execute(db *DB) { | ||||||
| 	curTime := time.Now() | 	curTime := time.Now() | ||||||
|  | 	db.RowsAffected = 0 | ||||||
| 	if stmt := db.Statement; stmt != nil { | 	if stmt := db.Statement; stmt != nil { | ||||||
| 		if stmt.Model == nil { | 		if stmt.Model == nil { | ||||||
| 			stmt.Model = stmt.Dest | 			stmt.Model = stmt.Dest | ||||||
| @ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) { | |||||||
| 		}, db.Error) | 		}, db.Error) | ||||||
| 
 | 
 | ||||||
| 		stmt.reinit() | 		stmt.reinit() | ||||||
| 		db.Config.statementPool.Put(stmt) | 		// db.Config.statementPool.Put(stmt)
 | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -21,6 +21,11 @@ func Query(db *gorm.DB) { | |||||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
| 						Name: f.DBName, | 						Name: f.DBName, | ||||||
| 					}) | 					}) | ||||||
|  | 				} else { | ||||||
|  | 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
|  | 						Name: name, | ||||||
|  | 						Raw:  true, | ||||||
|  | 					}) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -85,7 +90,7 @@ func Query(db *gorm.DB) { | |||||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		db.Statement.AddClauseIfNotExists(clauseSelect) | 		db.Statement.AddClause(clauseSelect) | ||||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 			*dest = append(*dest, v) | 			*dest = append(*dest, v) | ||||||
| 		} | 		} | ||||||
|  | 	case *int, *int64, *uint, *uint64: | ||||||
|  | 		for rows.Next() { | ||||||
|  | 			db.RowsAffected++ | ||||||
|  | 			rows.Scan(dest) | ||||||
|  | 		} | ||||||
| 	default: | 	default: | ||||||
| 		switch db.Statement.ReflectValue.Kind() { | 		switch db.Statement.ReflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
|  | |||||||
| @ -41,8 +41,5 @@ func (values Values) Build(builder Builder) { | |||||||
| // MergeClause merge values clauses
 | // MergeClause merge values clauses
 | ||||||
| func (values Values) MergeClause(clause *Clause) { | func (values Values) MergeClause(clause *Clause) { | ||||||
| 	clause.Name = "" | 	clause.Name = "" | ||||||
| 	if v, ok := clause.Expression.(Values); ok { |  | ||||||
| 		values.Values = append(v.Values, values.Values...) |  | ||||||
| 	} |  | ||||||
| 	clause.Expression = values | 	clause.Expression = values | ||||||
| } | } | ||||||
|  | |||||||
| @ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Count(value interface{}) (tx *DB) { | func (db *DB) Count(count *int64) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
|  | 	if len(tx.Statement.Selects) == 0 { | ||||||
|  | 		tx.Statement.Selects = []string{"count(1)"} | ||||||
|  | 	} | ||||||
|  | 	if tx.Statement.Model == nil { | ||||||
|  | 		tx.Statement.Model = tx.Statement.Dest | ||||||
|  | 	} | ||||||
|  | 	tx.Statement.Dest = count | ||||||
|  | 	tx.callbacks.Query().Execute(tx) | ||||||
|  | 	if db.RowsAffected != 1 { | ||||||
|  | 		*count = db.RowsAffected | ||||||
|  | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										54
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								statement.go
									
									
									
									
									
								
							| @ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | |||||||
| 	case clause.Table: | 	case clause.Table: | ||||||
| 		if v.Name == clause.CurrentTable { | 		if v.Name == clause.CurrentTable { | ||||||
| 			stmt.DB.Dialector.QuoteTo(writer, stmt.Table) | 			stmt.DB.Dialector.QuoteTo(writer, stmt.Table) | ||||||
|  | 		} else if v.Raw { | ||||||
|  | 			writer.WriteString(v.Name) | ||||||
| 		} else { | 		} else { | ||||||
| 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | ||||||
| 		} | 		} | ||||||
| @ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { | |||||||
| 			if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { | 			if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { | ||||||
| 				stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) | 				stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||||
| 			} | 			} | ||||||
|  | 		} else if v.Raw { | ||||||
|  | 			writer.WriteString(v.Name) | ||||||
| 		} else { | 		} else { | ||||||
| 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | 			stmt.DB.Dialector.QuoteTo(writer, v.Name) | ||||||
| 		} | 		} | ||||||
| @ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *Statement) reinit() { | func (stmt *Statement) reinit() { | ||||||
| 	stmt.Table = "" | 	// stmt.Table = ""
 | ||||||
| 	stmt.Model = nil | 	// stmt.Model = nil
 | ||||||
| 	stmt.Selects = nil | 	// stmt.Selects = nil
 | ||||||
| 	stmt.Omits = nil | 	// stmt.Omits = nil
 | ||||||
| 	stmt.ConnPool = stmt.DB.Config.ConnPool | 	// stmt.ConnPool = stmt.DB.Config.ConnPool
 | ||||||
| 	stmt.Schema = nil | 	// stmt.Context = context.Background()
 | ||||||
| 	stmt.Context = context.Background() | 	// stmt.RaiseErrorOnNotFound = false
 | ||||||
| 	stmt.RaiseErrorOnNotFound = false |  | ||||||
| 
 | 
 | ||||||
|  | 	// for k := range stmt.Clauses {
 | ||||||
|  | 	// 	delete(stmt.Clauses, k)
 | ||||||
|  | 	// }
 | ||||||
|  | 
 | ||||||
|  | 	// for k := range stmt.Joins {
 | ||||||
|  | 	// 	delete(stmt.Joins, k)
 | ||||||
|  | 	// }
 | ||||||
|  | 
 | ||||||
|  | 	// for k := range stmt.Preloads {
 | ||||||
|  | 	// 	delete(stmt.Preloads, k)
 | ||||||
|  | 	// }
 | ||||||
|  | 
 | ||||||
|  | 	// stmt.Settings.Range(func(k, _ interface{}) bool {
 | ||||||
|  | 	// 	stmt.Settings.Delete(k)
 | ||||||
|  | 	// 	return true
 | ||||||
|  | 	// })
 | ||||||
|  | 
 | ||||||
|  | 	stmt.Schema = nil | ||||||
| 	stmt.SQL.Reset() | 	stmt.SQL.Reset() | ||||||
| 	stmt.Vars = nil | 	stmt.Vars = nil | ||||||
| 	stmt.NamedVars = nil | 	stmt.NamedVars = nil | ||||||
| 
 |  | ||||||
| 	for k := range stmt.Clauses { |  | ||||||
| 		delete(stmt.Clauses, k) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for k := range stmt.Joins { |  | ||||||
| 		delete(stmt.Joins, k) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for k := range stmt.Preloads { |  | ||||||
| 		delete(stmt.Preloads, k) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	stmt.Settings.Range(func(k, _ interface{}) bool { |  | ||||||
| 		stmt.Settings.Delete(k) |  | ||||||
| 		return true |  | ||||||
| 	}) |  | ||||||
| } | } | ||||||
|  | |||||||
| @ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	user2.Manager = &User{} | 	user2.Manager = &User{} | ||||||
| 	DB.Model(&user2).Association("Manager").Find(user2.Manager) | 	DB.Model(&user2).Association("Manager").Find(user2.Manager) | ||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | ||||||
|  | 		t.Errorf("invalid company count, got %v", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { | ||||||
|  | 		t.Errorf("invalid manager count, got %v", count) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										42
									
								
								tests/count_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/count_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/jinzhu/gorm/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestCount(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		user1                 = *GetUser("count-1", Config{}) | ||||||
|  | 		user2                 = *GetUser("count-2", Config{}) | ||||||
|  | 		user3                 = *GetUser("count-3", Config{}) | ||||||
|  | 		users                 []User | ||||||
|  | 		count, count1, count2 int64 | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { | ||||||
|  | 		t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count != int64(len(users)) { | ||||||
|  | 		t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) | ||||||
|  | 	if count1 != 1 || count2 != 3 { | ||||||
|  | 		t.Errorf("multiple count in chain should works") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var count3 int64 | ||||||
|  | 	if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { | ||||||
|  | 		t.Errorf("No error should happen when count with group, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count3 != 2 { | ||||||
|  | 		t.Errorf("Should get correct count for count with group, but got %v", count3) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu