avoid adding attributes

This commit is contained in:
black 2023-03-13 17:33:57 +08:00
parent 92a360708d
commit 672c48b74c
4 changed files with 23 additions and 12 deletions

View File

@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) {
return
}
if isRows := db.Statement.QueryTypes.Pop(); isRows {
if types, ok := db.Statement.Settings.Load("rows"); ok && types.(*gorm.QueryTypes).Pop() {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -500,7 +500,10 @@ func (db *DB) Count(count *int64) (tx *DB) {
func (db *DB) Row() *sql.Row {
tx := db.getInstance()
tx.Statement.QueryTypes.Push(false)
value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{})
value.(*QueryTypes).Push(false)
tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
@ -511,7 +514,10 @@ func (db *DB) Row() *sql.Row {
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance()
tx.Statement.QueryTypes.Push(true)
value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{})
value.(*QueryTypes).Push(true)
tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {

View File

@ -35,7 +35,6 @@ type Statement struct {
Omits []string // omit columns
Joins []join
Preloads map[string][]interface{}
QueryTypes QueryTypes
Settings sync.Map
ConnPool ConnPool
Schema *schema.Schema
@ -88,19 +87,19 @@ func (q *QueryTypes) Pop() bool {
return element.Value.(bool)
}
func (q *QueryTypes) clone() QueryTypes {
func (q *QueryTypes) Clone() interface{} {
q.mux.Lock()
defer q.mux.Unlock()
if q.list == nil {
return QueryTypes{}
return &QueryTypes{}
}
cloneList := list.New()
for e := q.list.Front(); e != nil; e = e.Next() {
cloneList.PushFront(e.Value)
}
return QueryTypes{list: cloneList}
return &QueryTypes{list: cloneList}
}
// StatementModifier statement modifier interface
@ -589,16 +588,22 @@ func (stmt *Statement) clone() *Statement {
copy(newStmt.scopes, stmt.scopes)
}
newStmt.QueryTypes = stmt.QueryTypes.clone()
stmt.Settings.Range(func(k, v interface{}) bool {
newStmt.Settings.Store(k, v)
if cloneable, ok := v.(Cloneable); ok {
newStmt.Settings.Store(k, cloneable.Clone())
} else {
newStmt.Settings.Store(k, v)
}
return true
})
return newStmt
}
type Cloneable interface {
Clone() interface{}
}
// SetColumn set column's value
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method

View File

@ -64,13 +64,13 @@ func TestNameMatcher(t *testing.T) {
}
func TestQueryTypes(t *testing.T) {
types := QueryTypes{}
types := &QueryTypes{}
values := []bool{true, false, false, true}
for _, value := range values {
types.Push(value)
}
clone := types.clone()
clone := types.Clone().(*QueryTypes)
for _, value := range values {
actual := clone.Pop()
if actual != value {