From 672c48b74c5099c5a5f80ea7ffed2289d5508ec6 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 17:33:57 +0800 Subject: [PATCH] avoid adding attributes --- callbacks/row.go | 2 +- finisher_api.go | 10 ++++++++-- statement.go | 19 ++++++++++++------- statement_test.go | 4 ++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/callbacks/row.go b/callbacks/row.go index 77c93e78..5893ee2a 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows := db.Statement.QueryTypes.Pop(); isRows { + if types, ok := db.Statement.Settings.Load("rows"); ok && types.(*gorm.QueryTypes).Pop() { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index 935a0268..62e6523f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -500,7 +500,10 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance() - tx.Statement.QueryTypes.Push(false) + + value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{}) + value.(*QueryTypes).Push(false) + tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -511,7 +514,10 @@ func (db *DB) Row() *sql.Row { func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance() - tx.Statement.QueryTypes.Push(true) + + value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{}) + value.(*QueryTypes).Push(true) + tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/statement.go b/statement.go index 25029c48..7f1954b9 100644 --- a/statement.go +++ b/statement.go @@ -35,7 +35,6 @@ type Statement struct { Omits []string // omit columns Joins []join Preloads map[string][]interface{} - QueryTypes QueryTypes Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -88,19 +87,19 @@ func (q *QueryTypes) Pop() bool { return element.Value.(bool) } -func (q *QueryTypes) clone() QueryTypes { +func (q *QueryTypes) Clone() interface{} { q.mux.Lock() defer q.mux.Unlock() if q.list == nil { - return QueryTypes{} + return &QueryTypes{} } cloneList := list.New() for e := q.list.Front(); e != nil; e = e.Next() { cloneList.PushFront(e.Value) } - return QueryTypes{list: cloneList} + return &QueryTypes{list: cloneList} } // StatementModifier statement modifier interface @@ -589,16 +588,22 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } - newStmt.QueryTypes = stmt.QueryTypes.clone() - stmt.Settings.Range(func(k, v interface{}) bool { - newStmt.Settings.Store(k, v) + if cloneable, ok := v.(Cloneable); ok { + newStmt.Settings.Store(k, cloneable.Clone()) + } else { + newStmt.Settings.Store(k, v) + } return true }) return newStmt } +type Cloneable interface { + Clone() interface{} +} + // SetColumn set column's value // // stmt.SetColumn("Name", "jinzhu") // Hooks Method diff --git a/statement_test.go b/statement_test.go index a6b5f1c5..84ccc2a5 100644 --- a/statement_test.go +++ b/statement_test.go @@ -64,13 +64,13 @@ func TestNameMatcher(t *testing.T) { } func TestQueryTypes(t *testing.T) { - types := QueryTypes{} + types := &QueryTypes{} values := []bool{true, false, false, true} for _, value := range values { types.Push(value) } - clone := types.clone() + clone := types.Clone().(*QueryTypes) for _, value := range values { actual := clone.Pop() if actual != value {