Use callbacks to record sql instead

This commit is contained in:
Ian Tan 2017-11-23 17:25:40 +08:00
parent c2a28c63c3
commit 1a384b3c0c

View File

@ -1,21 +1,72 @@
package gorm package gorm
import ( import (
"regexp" "fmt"
) )
// Recorder satisfies the logger interface // Recorder satisfies the logger interface
type Recorder struct { type Recorder struct {
stmts []Stmt stmts []Stmt
preload []searchPreload // store it on Recorder
} }
// Stmt represents a sql statement. It can be an Exec or Query // Stmt represents a sql statement. It can be an Exec, Query, or QueryRow
type Stmt struct { type Stmt struct {
stmtType string kind string // can be Query, Exec, QueryRow
preload string // contains schema if it is a preload query
sql string sql string
args []interface{} args []interface{}
} }
func recordQueryCallback(scope *Scope) {
r, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
stmt := Stmt{
kind: "query",
sql: scope.SQL,
args: scope.SQLVars,
}
recorder := r.(*Recorder)
if len(recorder.preload) > 0 {
// this will cause the scope.SQL to mutate to the preload query
scope.prepareQuerySQL()
stmt.preload = recorder.preload[0].schema
// spew.Printf("_____PRELOADING_____\r\n%s\r\n", stmt.preload)
// spew.Printf("_____SQL_____\r\n%s\r\n", scope.SQL)
// we just want to pop the first element off
recorder.preload = recorder.preload[1:]
}
recorder.Record(stmt)
}
func recordPreloadCallback(scope *Scope) {
// this callback runs _before_ gorm:preload
// it should record the next thing to be preloaded
recorder, ok := scope.Get("gorm:recorder")
if !ok {
panic(fmt.Errorf("Expected a recorder to be set, but got none"))
}
if len(scope.Search.preload) > 0 {
// spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload))
recorder.(*Recorder).preload = scope.Search.preload
}
}
// Record records a Stmt for use when SQL is finally executed
func (r *Recorder) Record(stmt Stmt) {
r.stmts = append(r.stmts, stmt)
}
func getStmtFromLog(values ...interface{}) Stmt { func getStmtFromLog(values ...interface{}) Stmt {
var statement Stmt var statement Stmt
@ -72,7 +123,6 @@ type Expecter struct {
adapter Adapter adapter Adapter
gorm *DB gorm *DB
recorder *Recorder recorder *Recorder
preload []string // records fields to be preloaded
} }
// NewDefaultExpecter returns a Expecter powered by go-sqlmock // NewDefaultExpecter returns a Expecter powered by go-sqlmock
@ -87,7 +137,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
noop, _ := NewNoopDB() noop, _ := NewNoopDB()
gorm := &DB{ gorm := &DB{
db: noop, db: noop,
logger: recorder, logger: defaultLogger,
logMode: 2, logMode: 2,
values: map[string]interface{}{}, values: map[string]interface{}{},
callbacks: DefaultCallback, callbacks: DefaultCallback,
@ -95,6 +145,11 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
} }
gorm.parent = gorm gorm.parent = gorm
gorm = gorm.Set("gorm:recorder", recorder)
gorm = gorm.Set("gorm:preload_idx", 0)
gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil
} }
@ -119,36 +174,14 @@ func (h *Expecter) AssertExpectations() error {
// First triggers a Query // First triggers a Query
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
var q ExpectedQuery
h.gorm.First(out, where...) h.gorm.First(out, where...)
return h.adapter.ExpectQuery(h.recorder.stmts...)
if empty := h.recorder.IsEmpty(); empty {
panic("No recorded statements")
}
for _, stmt := range h.recorder.stmts {
q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
}
return q
} }
// Find triggers a Query // Find triggers a Query
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
var (
stmts []string
)
h.gorm.Find(out, where...) h.gorm.Find(out, where...)
return h.adapter.ExpectQuery(h.recorder.stmts...)
if empty := h.recorder.IsEmpty(); empty {
panic("No recorded statements")
}
for _, stmt := range h.recorder.stmts {
stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
}
return h.adapter.ExpectQuery(stmts...)
} }
// Preload clones the expecter and sets a preload condition on gorm.DB // Preload clones the expecter and sets a preload condition on gorm.DB