From 1a384b3c0c766342006cda84e653e14c78db689c Mon Sep 17 00:00:00 2001 From: Ian Tan Date: Thu, 23 Nov 2017 17:25:40 +0800 Subject: [PATCH] Use callbacks to record sql instead --- expecter.go | 97 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/expecter.go b/expecter.go index 41d23ce3..947d9afb 100644 --- a/expecter.go +++ b/expecter.go @@ -1,19 +1,70 @@ package gorm import ( - "regexp" + "fmt" ) // Recorder satisfies the logger interface type Recorder struct { - stmts []Stmt + stmts []Stmt + preload []searchPreload // store it on Recorder } -// Stmt represents a sql statement. It can be an Exec or Query +// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow type Stmt struct { - stmtType string - sql string - args []interface{} + kind string // can be Query, Exec, QueryRow + preload string // contains schema if it is a preload query + sql string + args []interface{} +} + +func recordQueryCallback(scope *Scope) { + r, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + + stmt := Stmt{ + kind: "query", + sql: scope.SQL, + args: scope.SQLVars, + } + + recorder := r.(*Recorder) + + if len(recorder.preload) > 0 { + // this will cause the scope.SQL to mutate to the preload query + scope.prepareQuerySQL() + stmt.preload = recorder.preload[0].schema + + // spew.Printf("_____PRELOADING_____\r\n%s\r\n", stmt.preload) + // spew.Printf("_____SQL_____\r\n%s\r\n", scope.SQL) + + // we just want to pop the first element off + recorder.preload = recorder.preload[1:] + } + + recorder.Record(stmt) +} + +func recordPreloadCallback(scope *Scope) { + // this callback runs _before_ gorm:preload + // it should record the next thing to be preloaded + recorder, ok := scope.Get("gorm:recorder") + + if !ok { + panic(fmt.Errorf("Expected a recorder to be set, but got none")) + } + if len(scope.Search.preload) > 0 { + // spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload)) + recorder.(*Recorder).preload = scope.Search.preload + } +} + +// Record records a Stmt for use when SQL is finally executed +func (r *Recorder) Record(stmt Stmt) { + r.stmts = append(r.stmts, stmt) } func getStmtFromLog(values ...interface{}) Stmt { @@ -72,7 +123,6 @@ type Expecter struct { adapter Adapter gorm *DB recorder *Recorder - preload []string // records fields to be preloaded } // NewDefaultExpecter returns a Expecter powered by go-sqlmock @@ -87,7 +137,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { noop, _ := NewNoopDB() gorm := &DB{ db: noop, - logger: recorder, + logger: defaultLogger, logMode: 2, values: map[string]interface{}{}, callbacks: DefaultCallback, @@ -95,6 +145,11 @@ func NewDefaultExpecter() (*DB, *Expecter, error) { } gorm.parent = gorm + gorm = gorm.Set("gorm:recorder", recorder) + gorm = gorm.Set("gorm:preload_idx", 0) + gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback) + gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback) + gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback) return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil } @@ -119,36 +174,14 @@ func (h *Expecter) AssertExpectations() error { // First triggers a Query func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery { - var q ExpectedQuery h.gorm.First(out, where...) - - if empty := h.recorder.IsEmpty(); empty { - panic("No recorded statements") - } - - for _, stmt := range h.recorder.stmts { - q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql)) - } - - return q + return h.adapter.ExpectQuery(h.recorder.stmts...) } // Find triggers a Query func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery { - var ( - stmts []string - ) h.gorm.Find(out, where...) - - if empty := h.recorder.IsEmpty(); empty { - panic("No recorded statements") - } - - for _, stmt := range h.recorder.stmts { - stmts = append(stmts, regexp.QuoteMeta(stmt.sql)) - } - - return h.adapter.ExpectQuery(stmts...) + return h.adapter.ExpectQuery(h.recorder.stmts...) } // Preload clones the expecter and sets a preload condition on gorm.DB