diff --git a/association.go b/association.go index abcae47d..bd2a7cdd 100644 --- a/association.go +++ b/association.go @@ -247,11 +247,12 @@ func (association *Association) Clear() error { return association.Replace() } -func (association *Association) Count() (count int) { +func (association *Association) Count() (count int64) { if association.Error == nil { var ( - tx = association.DB - conds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) + conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) ) if association.Relationship.JoinTable != nil { diff --git a/callbacks.go b/callbacks.go index 61cebc81..629b90aa 100644 --- a/callbacks.go +++ b/callbacks.go @@ -73,6 +73,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() + db.RowsAffected = 0 if stmt := db.Statement; stmt != nil { if stmt.Model == nil { stmt.Model = stmt.Dest @@ -102,7 +103,7 @@ func (p *processor) Execute(db *DB) { }, db.Error) stmt.reinit() - db.Config.statementPool.Put(stmt) + // db.Config.statementPool.Put(stmt) } } diff --git a/callbacks/query.go b/callbacks/query.go index 4a89c575..95b5ead3 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -21,6 +21,11 @@ func Query(db *gorm.DB) { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ Name: f.DBName, }) + } else { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Name: name, + Raw: true, + }) } } } @@ -85,7 +90,7 @@ func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.From{}) } - db.Statement.AddClauseIfNotExists(clauseSelect) + db.Statement.AddClause(clauseSelect) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } diff --git a/callbacks/scan.go b/callbacks/scan.go index 6ea8bf23..9ffcab4a 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -49,6 +49,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } *dest = append(*dest, v) } + case *int, *int64, *uint, *uint64: + for rows.Next() { + db.RowsAffected++ + rows.Scan(dest) + } default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/clause/values.go b/clause/values.go index a997fc26..b2f5421b 100644 --- a/clause/values.go +++ b/clause/values.go @@ -41,8 +41,5 @@ func (values Values) Build(builder Builder) { // MergeClause merge values clauses func (values Values) MergeClause(clause *Clause) { clause.Name = "" - if v, ok := clause.Expression.(Values); ok { - values.Values = append(v.Values, values.Values...) - } clause.Expression = values } diff --git a/finisher_api.go b/finisher_api.go index 1b2a7e29..6a787576 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -145,8 +145,19 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { return } -func (db *DB) Count(value interface{}) (tx *DB) { +func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() + if len(tx.Statement.Selects) == 0 { + tx.Statement.Selects = []string{"count(1)"} + } + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + } + tx.Statement.Dest = count + tx.callbacks.Query().Execute(tx) + if db.RowsAffected != 1 { + *count = db.RowsAffected + } return } diff --git a/statement.go b/statement.go index 1ea5a56c..0abf7a7e 100644 --- a/statement.go +++ b/statement.go @@ -63,6 +63,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { case clause.Table: if v.Name == clause.CurrentTable { stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } else if v.Raw { + writer.WriteString(v.Name) } else { stmt.DB.Dialector.QuoteTo(writer, v.Name) } @@ -85,6 +87,8 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } + } else if v.Raw { + writer.WriteString(v.Name) } else { stmt.DB.Dialector.QuoteTo(writer, v.Name) } @@ -275,33 +279,33 @@ func (stmt *Statement) Parse(value interface{}) (err error) { } func (stmt *Statement) reinit() { - stmt.Table = "" - stmt.Model = nil - stmt.Selects = nil - stmt.Omits = nil - stmt.ConnPool = stmt.DB.Config.ConnPool - stmt.Schema = nil - stmt.Context = context.Background() - stmt.RaiseErrorOnNotFound = false + // stmt.Table = "" + // stmt.Model = nil + // stmt.Selects = nil + // stmt.Omits = nil + // stmt.ConnPool = stmt.DB.Config.ConnPool + // stmt.Context = context.Background() + // stmt.RaiseErrorOnNotFound = false + // for k := range stmt.Clauses { + // delete(stmt.Clauses, k) + // } + + // for k := range stmt.Joins { + // delete(stmt.Joins, k) + // } + + // for k := range stmt.Preloads { + // delete(stmt.Preloads, k) + // } + + // stmt.Settings.Range(func(k, _ interface{}) bool { + // stmt.Settings.Delete(k) + // return true + // }) + + stmt.Schema = nil stmt.SQL.Reset() stmt.Vars = nil stmt.NamedVars = nil - - for k := range stmt.Clauses { - delete(stmt.Clauses, k) - } - - for k := range stmt.Joins { - delete(stmt.Joins, k) - } - - for k := range stmt.Preloads { - delete(stmt.Preloads, k) - } - - stmt.Settings.Range(func(k, _ interface{}) bool { - stmt.Settings.Delete(k) - return true - }) } diff --git a/tests/associations_test.go b/tests/associations_test.go index dc88ee03..845ee65e 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -21,4 +21,12 @@ func TestAssociationForBelongsTo(t *testing.T) { user2.Manager = &User{} DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) + + if count := DB.Model(&user).Association("Company").Count(); count != 1 { + t.Errorf("invalid company count, got %v", count) + } + + if count := DB.Model(&user).Association("Manager").Count(); count != 1 { + t.Errorf("invalid manager count, got %v", count) + } } diff --git a/tests/count_test.go b/tests/count_test.go new file mode 100644 index 00000000..960db167 --- /dev/null +++ b/tests/count_test.go @@ -0,0 +1,42 @@ +package tests_test + +import ( + "fmt" + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestCount(t *testing.T) { + var ( + user1 = *GetUser("count-1", Config{}) + user2 = *GetUser("count-2", Config{}) + user3 = *GetUser("count-3", Config{}) + users []User + count, count1, count2 int64 + ) + + DB.Save(&user1).Save(&user2).Save(&user3) + + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("multiple count in chain should works") + } + + var count3 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("No error should happen when count with group, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count for count with group, but got %v", count3) + } +}