diff --git a/callbacks/query.go b/callbacks/query.go index c9fa160f..84b9ed98 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -105,7 +105,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - Scan(rows, db) + gorm.Scan(rows, db, false) } func Preload(db *gorm.DB) { diff --git a/finisher_api.go b/finisher_api.go index 84168e23..04b25ed2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -186,8 +186,13 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { - return nil +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + tx.Error = tx.Statement.Parse(dest) + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + Scan(rows, tx, true) + return tx.Error } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. diff --git a/callbacks/scan.go b/scan.go similarity index 91% rename from callbacks/scan.go rename to scan.go index 9ffcab4a..d2169f87 100644 --- a/callbacks/scan.go +++ b/scan.go @@ -1,15 +1,14 @@ -package callbacks +package gorm import ( "database/sql" "reflect" "strings" - "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/schema" ) -func Scan(rows *sql.Rows, db *gorm.DB) { +func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -19,7 +18,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { values[idx] = new(interface{}) } - if rows.Next() { + if initialized || rows.Next() { db.RowsAffected++ rows.Scan(values...) } @@ -39,7 +38,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { values[idx] = new(interface{}) } - for rows.Next() { + for initialized || rows.Next() { + initialized = false db.RowsAffected++ rows.Scan(values...) @@ -50,7 +50,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { *dest = append(*dest, v) } case *int, *int64, *uint, *uint64: - for rows.Next() { + for initialized || rows.Next() { + initialized = false db.RowsAffected++ rows.Scan(dest) } @@ -78,7 +79,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } } - for rows.Next() { + for initialized || rows.Next() { + initialized = false elem := reflect.New(db.Statement.Schema.ModelType).Elem() for idx, field := range fields { if field != nil { @@ -118,7 +120,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } } - if rows.Next() { + if initialized || rows.Next() { db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) @@ -128,6 +130,6 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { - db.AddError(gorm.ErrRecordNotFound) + db.AddError(ErrRecordNotFound) } } diff --git a/tests/scan_test.go b/tests/scan_test.go index f7a14636..fc6c1721 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -1,6 +1,9 @@ package tests_test import ( + "reflect" + "sort" + "strings" "testing" . "github.com/jinzhu/gorm/tests" @@ -24,7 +27,7 @@ func TestScan(t *testing.T) { } var doubleAgeRes = &result{} - if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { t.Errorf("Scan to pointer of pointer") } @@ -32,9 +35,44 @@ func TestScan(t *testing.T) { t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) } - var ress []result - DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { + var results []result + DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) < -1 + }) + + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Scan into struct map") } } + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1} + user2 := User{Name: "ScanRowsUser2", Age: 10} + user3 := User{Name: "ScanRowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") + } +}