fix: nested scan (#6136)
This commit is contained in:
parent
cc2d46e5be
commit
62230f6c84
@ -11,8 +11,7 @@ func RowQuery(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
||||
db.Statement.Settings.Delete("rows")
|
||||
if isRows := db.PopQueryType(); isRows {
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
@ -499,7 +499,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance().Set("rows", false)
|
||||
tx := db.getInstance().PushQueryType(false)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
@ -509,7 +509,7 @@ func (db *DB) Row() *sql.Row {
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.getInstance().Set("rows", true)
|
||||
tx := db.getInstance().PushQueryType(true)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
|
16
gorm.go
16
gorm.go
@ -340,6 +340,22 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
|
||||
func (db *DB) PushQueryType(rows bool) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows)
|
||||
return tx
|
||||
}
|
||||
|
||||
func (db *DB) PopQueryType() bool {
|
||||
length := len(db.Statement.queryTypes)
|
||||
if length == 0 {
|
||||
return false
|
||||
}
|
||||
value := db.Statement.queryTypes[length-1]
|
||||
db.Statement.queryTypes = db.Statement.queryTypes[:length-1]
|
||||
return value
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
|
@ -46,6 +46,7 @@ type Statement struct {
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
scopes []func(*DB) *DB
|
||||
queryTypes []bool
|
||||
}
|
||||
|
||||
type join struct {
|
||||
@ -543,6 +544,11 @@ func (stmt *Statement) clone() *Statement {
|
||||
copy(newStmt.scopes, stmt.scopes)
|
||||
}
|
||||
|
||||
if len(stmt.queryTypes) > 0 {
|
||||
newStmt.queryTypes = make([]bool, len(stmt.queryTypes))
|
||||
copy(newStmt.queryTypes, stmt.queryTypes)
|
||||
}
|
||||
|
||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||
newStmt.Settings.Store(k, v)
|
||||
return true
|
||||
|
@ -71,4 +71,15 @@ func TestScopes(t *testing.T) {
|
||||
if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil {
|
||||
t.Errorf("select max(id)")
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := DB.Scopes(func(db *gorm.DB) *gorm.DB {
|
||||
var maxID int64
|
||||
if err := db.Raw("select max(id) from users").Scan(&maxID).Error; err != nil {
|
||||
return db
|
||||
}
|
||||
return db.Raw("select * from users where id = ?", maxID)
|
||||
}).Scan(&user).Error; err != nil {
|
||||
t.Errorf("failed to find user, got err: %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user