avoid adding attributes

This commit is contained in:
black 2023-03-13 17:33:57 +08:00
parent 92a360708d
commit 672c48b74c
4 changed files with 23 additions and 12 deletions

View File

@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) {
return return
} }
if isRows := db.Statement.QueryTypes.Pop(); isRows { if types, ok := db.Statement.Settings.Load("rows"); ok && types.(*gorm.QueryTypes).Pop() {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else { } else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -500,7 +500,10 @@ func (db *DB) Count(count *int64) (tx *DB) {
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance() tx := db.getInstance()
tx.Statement.QueryTypes.Push(false)
value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{})
value.(*QueryTypes).Push(false)
tx = tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row) row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun { if !ok && tx.DryRun {
@ -511,7 +514,10 @@ func (db *DB) Row() *sql.Row {
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance() tx := db.getInstance()
tx.Statement.QueryTypes.Push(true)
value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{})
value.(*QueryTypes).Push(true)
tx = tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows) rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil { if !ok && tx.DryRun && tx.Error == nil {

View File

@ -35,7 +35,6 @@ type Statement struct {
Omits []string // omit columns Omits []string // omit columns
Joins []join Joins []join
Preloads map[string][]interface{} Preloads map[string][]interface{}
QueryTypes QueryTypes
Settings sync.Map Settings sync.Map
ConnPool ConnPool ConnPool ConnPool
Schema *schema.Schema Schema *schema.Schema
@ -88,19 +87,19 @@ func (q *QueryTypes) Pop() bool {
return element.Value.(bool) return element.Value.(bool)
} }
func (q *QueryTypes) clone() QueryTypes { func (q *QueryTypes) Clone() interface{} {
q.mux.Lock() q.mux.Lock()
defer q.mux.Unlock() defer q.mux.Unlock()
if q.list == nil { if q.list == nil {
return QueryTypes{} return &QueryTypes{}
} }
cloneList := list.New() cloneList := list.New()
for e := q.list.Front(); e != nil; e = e.Next() { for e := q.list.Front(); e != nil; e = e.Next() {
cloneList.PushFront(e.Value) cloneList.PushFront(e.Value)
} }
return QueryTypes{list: cloneList} return &QueryTypes{list: cloneList}
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -589,16 +588,22 @@ func (stmt *Statement) clone() *Statement {
copy(newStmt.scopes, stmt.scopes) copy(newStmt.scopes, stmt.scopes)
} }
newStmt.QueryTypes = stmt.QueryTypes.clone()
stmt.Settings.Range(func(k, v interface{}) bool { stmt.Settings.Range(func(k, v interface{}) bool {
if cloneable, ok := v.(Cloneable); ok {
newStmt.Settings.Store(k, cloneable.Clone())
} else {
newStmt.Settings.Store(k, v) newStmt.Settings.Store(k, v)
}
return true return true
}) })
return newStmt return newStmt
} }
type Cloneable interface {
Clone() interface{}
}
// SetColumn set column's value // SetColumn set column's value
// //
// stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu") // Hooks Method

View File

@ -64,13 +64,13 @@ func TestNameMatcher(t *testing.T) {
} }
func TestQueryTypes(t *testing.T) { func TestQueryTypes(t *testing.T) {
types := QueryTypes{} types := &QueryTypes{}
values := []bool{true, false, false, true} values := []bool{true, false, false, true}
for _, value := range values { for _, value := range values {
types.Push(value) types.Push(value)
} }
clone := types.clone() clone := types.Clone().(*QueryTypes)
for _, value := range values { for _, value := range values {
actual := clone.Pop() actual := clone.Pop()
if actual != value { if actual != value {