From aec9023a104189539996c23521cfffc66200f367 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 16:34:50 +0800 Subject: [PATCH] add mutex --- callbacks/row.go | 2 +- finisher_api.go | 6 ++++-- gorm.go | 16 ---------------- statement.go | 44 +++++++++++++++++++++++++++++++++++++++----- tests/scopes_test.go | 10 ++++++++++ 5 files changed, 54 insertions(+), 24 deletions(-) diff --git a/callbacks/row.go b/callbacks/row.go index 19510716..77c93e78 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows := db.PopQueryType(); isRows { + if isRows := db.Statement.QueryTypes.Pop(); isRows { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index 0fa7d79f..935a0268 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -499,7 +499,8 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().PushQueryType(false) + tx := db.getInstance() + tx.Statement.QueryTypes.Push(false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -509,7 +510,8 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().PushQueryType(true) + tx := db.getInstance() + tx.Statement.QueryTypes.Push(true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/gorm.go b/gorm.go index 25502b66..9a70c3d2 100644 --- a/gorm.go +++ b/gorm.go @@ -340,22 +340,6 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } -func (db *DB) PushQueryType(rows bool) *DB { - tx := db.getInstance() - tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows) - return tx -} - -func (db *DB) PopQueryType() bool { - length := len(db.Statement.queryTypes) - if length == 0 { - return false - } - value := db.Statement.queryTypes[length-1] - db.Statement.queryTypes = db.Statement.queryTypes[:length-1] - return value -} - // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/statement.go b/statement.go index 095a3c13..497cfedc 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,7 @@ type Statement struct { Omits []string // omit columns Joins []join Preloads map[string][]interface{} + QueryTypes QueryTypes Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -46,7 +47,6 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB - queryTypes []bool } type join struct { @@ -58,6 +58,43 @@ type join struct { JoinType clause.JoinType } +type QueryTypes struct { + mux sync.Mutex + values []bool +} + +func (q *QueryTypes) Push(isRows bool) { + q.mux.Lock() + defer q.mux.Unlock() + q.values = append(q.values, isRows) +} + +func (q *QueryTypes) Pop() bool { + q.mux.Lock() + defer q.mux.Unlock() + + if len(q.values) == 0 { + return false + } + + value := q.values[len(q.values)-1] + q.values = q.values[:len(q.values)-1] + return value +} + +func (q *QueryTypes) clone() QueryTypes { + q.mux.Lock() + defer q.mux.Unlock() + + if len(q.values) == 0 { + return QueryTypes{} + } + + values := make([]bool, len(q.values)) + copy(values, q.values) + return QueryTypes{values: values} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -544,10 +581,7 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } - if len(stmt.queryTypes) > 0 { - newStmt.queryTypes = make([]bool, len(stmt.queryTypes)) - copy(newStmt.queryTypes, stmt.queryTypes) - } + newStmt.QueryTypes = stmt.QueryTypes.clone() stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 61f4ef3c..25257918 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -82,4 +82,14 @@ func TestScopes(t *testing.T) { }).Scan(&user).Error; err != nil { t.Errorf("failed to find user, got err: %v", err) } + + if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { + var maxID int64 + if err := db.Model(&User{}).Select("max(id)").Scan(&maxID).Error; err != nil { + return db + } + return db.Where("id = ?", maxID) + }).Scan(&user).Error; err != nil { + t.Errorf("failed to find user, got err: %v", err) + } }