add mutex
This commit is contained in:
parent
62230f6c84
commit
aec9023a10
@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRows := db.PopQueryType(); isRows {
|
if isRows := db.Statement.QueryTypes.Pop(); isRows {
|
||||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
} else {
|
} else {
|
||||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
@ -499,7 +499,8 @@ func (db *DB) Count(count *int64) (tx *DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Row() *sql.Row {
|
func (db *DB) Row() *sql.Row {
|
||||||
tx := db.getInstance().PushQueryType(false)
|
tx := db.getInstance()
|
||||||
|
tx.Statement.QueryTypes.Push(false)
|
||||||
tx = tx.callbacks.Row().Execute(tx)
|
tx = tx.callbacks.Row().Execute(tx)
|
||||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||||
if !ok && tx.DryRun {
|
if !ok && tx.DryRun {
|
||||||
@ -509,7 +510,8 @@ func (db *DB) Row() *sql.Row {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Rows() (*sql.Rows, error) {
|
func (db *DB) Rows() (*sql.Rows, error) {
|
||||||
tx := db.getInstance().PushQueryType(true)
|
tx := db.getInstance()
|
||||||
|
tx.Statement.QueryTypes.Push(true)
|
||||||
tx = tx.callbacks.Row().Execute(tx)
|
tx = tx.callbacks.Row().Execute(tx)
|
||||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||||
if !ok && tx.DryRun && tx.Error == nil {
|
if !ok && tx.DryRun && tx.Error == nil {
|
||||||
|
16
gorm.go
16
gorm.go
@ -340,22 +340,6 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
|||||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) PushQueryType(rows bool) *DB {
|
|
||||||
tx := db.getInstance()
|
|
||||||
tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows)
|
|
||||||
return tx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) PopQueryType() bool {
|
|
||||||
length := len(db.Statement.queryTypes)
|
|
||||||
if length == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
value := db.Statement.queryTypes[length-1]
|
|
||||||
db.Statement.queryTypes = db.Statement.queryTypes[:length-1]
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Callback returns callback manager
|
// Callback returns callback manager
|
||||||
func (db *DB) Callback() *callbacks {
|
func (db *DB) Callback() *callbacks {
|
||||||
return db.callbacks
|
return db.callbacks
|
||||||
|
44
statement.go
44
statement.go
@ -34,6 +34,7 @@ type Statement struct {
|
|||||||
Omits []string // omit columns
|
Omits []string // omit columns
|
||||||
Joins []join
|
Joins []join
|
||||||
Preloads map[string][]interface{}
|
Preloads map[string][]interface{}
|
||||||
|
QueryTypes QueryTypes
|
||||||
Settings sync.Map
|
Settings sync.Map
|
||||||
ConnPool ConnPool
|
ConnPool ConnPool
|
||||||
Schema *schema.Schema
|
Schema *schema.Schema
|
||||||
@ -46,7 +47,6 @@ type Statement struct {
|
|||||||
attrs []interface{}
|
attrs []interface{}
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
scopes []func(*DB) *DB
|
scopes []func(*DB) *DB
|
||||||
queryTypes []bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type join struct {
|
type join struct {
|
||||||
@ -58,6 +58,43 @@ type join struct {
|
|||||||
JoinType clause.JoinType
|
JoinType clause.JoinType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryTypes struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
values []bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryTypes) Push(isRows bool) {
|
||||||
|
q.mux.Lock()
|
||||||
|
defer q.mux.Unlock()
|
||||||
|
q.values = append(q.values, isRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryTypes) Pop() bool {
|
||||||
|
q.mux.Lock()
|
||||||
|
defer q.mux.Unlock()
|
||||||
|
|
||||||
|
if len(q.values) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
value := q.values[len(q.values)-1]
|
||||||
|
q.values = q.values[:len(q.values)-1]
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryTypes) clone() QueryTypes {
|
||||||
|
q.mux.Lock()
|
||||||
|
defer q.mux.Unlock()
|
||||||
|
|
||||||
|
if len(q.values) == 0 {
|
||||||
|
return QueryTypes{}
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]bool, len(q.values))
|
||||||
|
copy(values, q.values)
|
||||||
|
return QueryTypes{values: values}
|
||||||
|
}
|
||||||
|
|
||||||
// StatementModifier statement modifier interface
|
// StatementModifier statement modifier interface
|
||||||
type StatementModifier interface {
|
type StatementModifier interface {
|
||||||
ModifyStatement(*Statement)
|
ModifyStatement(*Statement)
|
||||||
@ -544,10 +581,7 @@ func (stmt *Statement) clone() *Statement {
|
|||||||
copy(newStmt.scopes, stmt.scopes)
|
copy(newStmt.scopes, stmt.scopes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(stmt.queryTypes) > 0 {
|
newStmt.QueryTypes = stmt.QueryTypes.clone()
|
||||||
newStmt.queryTypes = make([]bool, len(stmt.queryTypes))
|
|
||||||
copy(newStmt.queryTypes, stmt.queryTypes)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||||
newStmt.Settings.Store(k, v)
|
newStmt.Settings.Store(k, v)
|
||||||
|
@ -82,4 +82,14 @@ func TestScopes(t *testing.T) {
|
|||||||
}).Scan(&user).Error; err != nil {
|
}).Scan(&user).Error; err != nil {
|
||||||
t.Errorf("failed to find user, got err: %v", err)
|
t.Errorf("failed to find user, got err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := DB.Scopes(func(db *gorm.DB) *gorm.DB {
|
||||||
|
var maxID int64
|
||||||
|
if err := db.Model(&User{}).Select("max(id)").Scan(&maxID).Error; err != nil {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
return db.Where("id = ?", maxID)
|
||||||
|
}).Scan(&user).Error; err != nil {
|
||||||
|
t.Errorf("failed to find user, got err: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user