Fix Scan struct with primary key, close #3357
This commit is contained in:
		
							parent
							
								
									9a101c8a08
								
							
						
					
					
						commit
						dbaa6b0ec3
					
				| @ -79,6 +79,8 @@ func (p *processor) Execute(db *DB) { | |||||||
| 
 | 
 | ||||||
| 	if stmt.Model == nil { | 	if stmt.Model == nil { | ||||||
| 		stmt.Model = stmt.Dest | 		stmt.Model = stmt.Dest | ||||||
|  | 	} else if stmt.Dest == nil { | ||||||
|  | 		stmt.Dest = stmt.Model | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if stmt.Model != nil { | 	if stmt.Model != nil { | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun { | 		if !db.DryRun { | ||||||
| 			if _, ok := db.Get("rows"); ok { | 			if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { | ||||||
| 				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 			} else { | 			} else { | ||||||
| 				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | |||||||
| @ -331,13 +331,13 @@ func (db *DB) Count(count *int64) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Row() *sql.Row { | func (db *DB) Row() *sql.Row { | ||||||
| 	tx := db.getInstance() | 	tx := db.getInstance().InstanceSet("rows", false) | ||||||
| 	tx.callbacks.Row().Execute(tx) | 	tx.callbacks.Row().Execute(tx) | ||||||
| 	return tx.Statement.Dest.(*sql.Row) | 	return tx.Statement.Dest.(*sql.Row) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Rows() (*sql.Rows, error) { | func (db *DB) Rows() (*sql.Rows, error) { | ||||||
| 	tx := db.Set("rows", true) | 	tx := db.getInstance().InstanceSet("rows", true) | ||||||
| 	tx.callbacks.Row().Execute(tx) | 	tx.callbacks.Row().Execute(tx) | ||||||
| 	return tx.Statement.Dest.(*sql.Rows), tx.Error | 	return tx.Statement.Dest.(*sql.Rows), tx.Error | ||||||
| } | } | ||||||
| @ -345,8 +345,14 @@ func (db *DB) Rows() (*sql.Rows, error) { | |||||||
| // Scan scan value to a struct
 | // Scan scan value to a struct
 | ||||||
| func (db *DB) Scan(dest interface{}) (tx *DB) { | func (db *DB) Scan(dest interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Dest = dest | 	if rows, err := tx.Rows(); err != nil { | ||||||
| 	tx.callbacks.Query().Execute(tx) | 		tx.AddError(err) | ||||||
|  | 	} else { | ||||||
|  | 		defer rows.Close() | ||||||
|  | 		if rows.Next() { | ||||||
|  | 			tx.ScanRows(rows, dest) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -379,7 +385,10 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { | |||||||
| 	tx := db.getInstance() | 	tx := db.getInstance() | ||||||
| 	tx.Error = tx.Statement.Parse(dest) | 	tx.Error = tx.Statement.Parse(dest) | ||||||
| 	tx.Statement.Dest = dest | 	tx.Statement.Dest = dest | ||||||
| 	tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) | 	tx.Statement.ReflectValue = reflect.ValueOf(dest) | ||||||
|  | 	for tx.Statement.ReflectValue.Kind() == reflect.Ptr { | ||||||
|  | 		tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() | ||||||
|  | 	} | ||||||
| 	Scan(rows, tx, true) | 	Scan(rows, tx, true) | ||||||
| 	return tx.Error | 	return tx.Error | ||||||
| } | } | ||||||
|  | |||||||
| @ -3,13 +3,14 @@ package logger | |||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"gorm.io/gorm/utils" |  | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 	"unicode" | 	"unicode" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func isPrintable(s []byte) bool { | func isPrintable(s []byte) bool { | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| // Migrator returns migrator
 | // Migrator returns migrator
 | ||||||
| func (db *DB) Migrator() Migrator { | func (db *DB) Migrator() Migrator { | ||||||
| 	return db.Dialector.Migrator(db) | 	return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AutoMigrate run auto migration for given models
 | // AutoMigrate run auto migration for given models
 | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -16,14 +17,25 @@ func TestScan(t *testing.T) { | |||||||
| 	DB.Save(&user1).Save(&user2).Save(&user3) | 	DB.Save(&user1).Save(&user2).Save(&user3) | ||||||
| 
 | 
 | ||||||
| 	type result struct { | 	type result struct { | ||||||
|  | 		ID   uint | ||||||
| 		Name string | 		Name string | ||||||
| 		Age  int | 		Age  int | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var res result | 	var res result | ||||||
| 	DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) | 	DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) | ||||||
| 	if res.Name != user3.Name || res.Age != int(user3.Age) { | 	if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { | ||||||
| 		t.Errorf("Scan into struct should work") | 		t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) | ||||||
|  | 	if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { | ||||||
|  | 		t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) | ||||||
|  | 	if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { | ||||||
|  | 		t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var doubleAgeRes = &result{} | 	var doubleAgeRes = &result{} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu