Add ScanRows support
This commit is contained in:
		
							parent
							
								
									51c5be0503
								
							
						
					
					
						commit
						5be642a435
					
				| @ -105,7 +105,7 @@ func Query(db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| 	defer rows.Close() | 	defer rows.Close() | ||||||
| 
 | 
 | ||||||
| 	Scan(rows, db) | 	gorm.Scan(rows, db, false) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Preload(db *gorm.DB) { | func Preload(db *gorm.DB) { | ||||||
|  | |||||||
| @ -186,8 +186,13 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { | func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | ||||||
| 	return nil | 	tx := db.getInstance() | ||||||
|  | 	tx.Error = tx.Statement.Parse(dest) | ||||||
|  | 	tx.Statement.Dest = dest | ||||||
|  | 	tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) | ||||||
|  | 	Scan(rows, tx, true) | ||||||
|  | 	return tx.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
 | ||||||
|  | |||||||
| @ -1,15 +1,14 @@ | |||||||
| package callbacks | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" |  | ||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func Scan(rows *sql.Rows, db *gorm.DB) { | func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||||
| 	columns, _ := rows.Columns() | 	columns, _ := rows.Columns() | ||||||
| 	values := make([]interface{}, len(columns)) | 	values := make([]interface{}, len(columns)) | ||||||
| 
 | 
 | ||||||
| @ -19,7 +18,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			values[idx] = new(interface{}) | 			values[idx] = new(interface{}) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if rows.Next() { | 		if initialized || rows.Next() { | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| 			rows.Scan(values...) | 			rows.Scan(values...) | ||||||
| 		} | 		} | ||||||
| @ -39,7 +38,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			values[idx] = new(interface{}) | 			values[idx] = new(interface{}) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for rows.Next() { | 		for initialized || rows.Next() { | ||||||
|  | 			initialized = false | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| 			rows.Scan(values...) | 			rows.Scan(values...) | ||||||
| 
 | 
 | ||||||
| @ -50,7 +50,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			*dest = append(*dest, v) | 			*dest = append(*dest, v) | ||||||
| 		} | 		} | ||||||
| 	case *int, *int64, *uint, *uint64: | 	case *int, *int64, *uint, *uint64: | ||||||
| 		for rows.Next() { | 		for initialized || rows.Next() { | ||||||
|  | 			initialized = false | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| 			rows.Scan(dest) | 			rows.Scan(dest) | ||||||
| 		} | 		} | ||||||
| @ -78,7 +79,8 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for rows.Next() { | 			for initialized || rows.Next() { | ||||||
|  | 				initialized = false | ||||||
| 				elem := reflect.New(db.Statement.Schema.ModelType).Elem() | 				elem := reflect.New(db.Statement.Schema.ModelType).Elem() | ||||||
| 				for idx, field := range fields { | 				for idx, field := range fields { | ||||||
| 					if field != nil { | 					if field != nil { | ||||||
| @ -118,7 +120,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if rows.Next() { | 			if initialized || rows.Next() { | ||||||
| 				db.RowsAffected++ | 				db.RowsAffected++ | ||||||
| 				if err := rows.Scan(values...); err != nil { | 				if err := rows.Scan(values...); err != nil { | ||||||
| 					db.AddError(err) | 					db.AddError(err) | ||||||
| @ -128,6 +130,6 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { | 	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound { | ||||||
| 		db.AddError(gorm.ErrRecordNotFound) | 		db.AddError(ErrRecordNotFound) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -1,6 +1,9 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"sort" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "github.com/jinzhu/gorm/tests" | 	. "github.com/jinzhu/gorm/tests" | ||||||
| @ -24,7 +27,7 @@ func TestScan(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var doubleAgeRes = &result{} | 	var doubleAgeRes = &result{} | ||||||
| 	if err := DB.Debug().Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { | 	if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { | ||||||
| 		t.Errorf("Scan to pointer of pointer") | 		t.Errorf("Scan to pointer of pointer") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -32,9 +35,44 @@ func TestScan(t *testing.T) { | |||||||
| 		t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) | 		t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var ress []result | 	var results []result | ||||||
| 	DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&ress) | 	DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) | ||||||
| 	if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { | 
 | ||||||
|  | 	sort.Slice(results, func(i, j int) bool { | ||||||
|  | 		return strings.Compare(results[i].Name, results[j].Name) < -1 | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { | ||||||
| 		t.Errorf("Scan into struct map") | 		t.Errorf("Scan into struct map") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestScanRows(t *testing.T) { | ||||||
|  | 	user1 := User{Name: "ScanRowsUser1", Age: 1} | ||||||
|  | 	user2 := User{Name: "ScanRowsUser2", Age: 10} | ||||||
|  | 	user3 := User{Name: "ScanRowsUser3", Age: 20} | ||||||
|  | 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||||
|  | 
 | ||||||
|  | 	rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("Not error should happen, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type Result struct { | ||||||
|  | 		Name string | ||||||
|  | 		Age  int | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var results []Result | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		var result Result | ||||||
|  | 		if err := DB.ScanRows(rows, &result); err != nil { | ||||||
|  | 			t.Errorf("should get no error, but got %v", err) | ||||||
|  | 		} | ||||||
|  | 		results = append(results, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { | ||||||
|  | 		t.Errorf("Should find expected results") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu