diff --git a/callbacks/row.go b/callbacks/row.go index beaa189e..19510716 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,8 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows, ok := db.Get("rows"); ok && isRows.(bool) { - db.Statement.Settings.Delete("rows") + if isRows := db.PopQueryType(); isRows { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index e6fe4666..0fa7d79f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -499,7 +499,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().Set("rows", false) + tx := db.getInstance().PushQueryType(false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -509,7 +509,7 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().Set("rows", true) + tx := db.getInstance().PushQueryType(true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/gorm.go b/gorm.go index 9a70c3d2..25502b66 100644 --- a/gorm.go +++ b/gorm.go @@ -340,6 +340,22 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } +func (db *DB) PushQueryType(rows bool) *DB { + tx := db.getInstance() + tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows) + return tx +} + +func (db *DB) PopQueryType() bool { + length := len(db.Statement.queryTypes) + if length == 0 { + return false + } + value := db.Statement.queryTypes[length-1] + db.Statement.queryTypes = db.Statement.queryTypes[:length-1] + return value +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/statement.go b/statement.go index bc959f0b..095a3c13 100644 --- a/statement.go +++ b/statement.go @@ -46,6 +46,7 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + queryTypes []bool } type join struct { @@ -543,6 +544,11 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } + if len(stmt.queryTypes) > 0 { + newStmt.queryTypes = make([]bool, len(stmt.queryTypes)) + copy(newStmt.queryTypes, stmt.queryTypes) + } + stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) return true diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea..61f4ef3c 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -71,4 +71,15 @@ func TestScopes(t *testing.T) { if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { t.Errorf("select max(id)") } + + var user User + if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { + var maxID int64 + if err := db.Raw("select max(id) from users").Scan(&maxID).Error; err != nil { + return db + } + return db.Raw("select * from users where id = ?", maxID) + }).Scan(&user).Error; err != nil { + t.Errorf("failed to find user, got err: %v", err) + } }