Use callbacks to record sql instead
This commit is contained in:
		
							parent
							
								
									c2a28c63c3
								
							
						
					
					
						commit
						1a384b3c0c
					
				
							
								
								
									
										97
									
								
								expecter.go
									
									
									
									
									
								
							
							
						
						
									
										97
									
								
								expecter.go
									
									
									
									
									
								
							@ -1,19 +1,70 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Recorder satisfies the logger interface
 | 
			
		||||
type Recorder struct {
 | 
			
		||||
	stmts []Stmt
 | 
			
		||||
	stmts   []Stmt
 | 
			
		||||
	preload []searchPreload // store it on Recorder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Stmt represents a sql statement. It can be an Exec or Query
 | 
			
		||||
// Stmt represents a sql statement. It can be an Exec, Query, or QueryRow
 | 
			
		||||
type Stmt struct {
 | 
			
		||||
	stmtType string
 | 
			
		||||
	sql      string
 | 
			
		||||
	args     []interface{}
 | 
			
		||||
	kind    string // can be Query, Exec, QueryRow
 | 
			
		||||
	preload string // contains schema if it is a preload query
 | 
			
		||||
	sql     string
 | 
			
		||||
	args    []interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordQueryCallback(scope *Scope) {
 | 
			
		||||
	r, ok := scope.Get("gorm:recorder")
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Errorf("Expected a recorder to be set, but got none"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stmt := Stmt{
 | 
			
		||||
		kind: "query",
 | 
			
		||||
		sql:  scope.SQL,
 | 
			
		||||
		args: scope.SQLVars,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder := r.(*Recorder)
 | 
			
		||||
 | 
			
		||||
	if len(recorder.preload) > 0 {
 | 
			
		||||
		// this will cause the scope.SQL to mutate to the preload query
 | 
			
		||||
		scope.prepareQuerySQL()
 | 
			
		||||
		stmt.preload = recorder.preload[0].schema
 | 
			
		||||
 | 
			
		||||
		// spew.Printf("_____PRELOADING_____\r\n%s\r\n", stmt.preload)
 | 
			
		||||
		// spew.Printf("_____SQL_____\r\n%s\r\n", scope.SQL)
 | 
			
		||||
 | 
			
		||||
		// we just want to pop the first element off
 | 
			
		||||
		recorder.preload = recorder.preload[1:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder.Record(stmt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func recordPreloadCallback(scope *Scope) {
 | 
			
		||||
	// this callback runs _before_ gorm:preload
 | 
			
		||||
	// it should record the next thing to be preloaded
 | 
			
		||||
	recorder, ok := scope.Get("gorm:recorder")
 | 
			
		||||
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Errorf("Expected a recorder to be set, but got none"))
 | 
			
		||||
	}
 | 
			
		||||
	if len(scope.Search.preload) > 0 {
 | 
			
		||||
		// spew.Printf("callback:preload\r\n%s\r\n", spew.Sdump(scope.Search.preload))
 | 
			
		||||
		recorder.(*Recorder).preload = scope.Search.preload
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Record records a Stmt for use when SQL is finally executed
 | 
			
		||||
func (r *Recorder) Record(stmt Stmt) {
 | 
			
		||||
	r.stmts = append(r.stmts, stmt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getStmtFromLog(values ...interface{}) Stmt {
 | 
			
		||||
@ -72,7 +123,6 @@ type Expecter struct {
 | 
			
		||||
	adapter  Adapter
 | 
			
		||||
	gorm     *DB
 | 
			
		||||
	recorder *Recorder
 | 
			
		||||
	preload  []string // records fields to be preloaded
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDefaultExpecter returns a Expecter powered by go-sqlmock
 | 
			
		||||
@ -87,7 +137,7 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
 | 
			
		||||
	noop, _ := NewNoopDB()
 | 
			
		||||
	gorm := &DB{
 | 
			
		||||
		db:        noop,
 | 
			
		||||
		logger:    recorder,
 | 
			
		||||
		logger:    defaultLogger,
 | 
			
		||||
		logMode:   2,
 | 
			
		||||
		values:    map[string]interface{}{},
 | 
			
		||||
		callbacks: DefaultCallback,
 | 
			
		||||
@ -95,6 +145,11 @@ func NewDefaultExpecter() (*DB, *Expecter, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gorm.parent = gorm
 | 
			
		||||
	gorm = gorm.Set("gorm:recorder", recorder)
 | 
			
		||||
	gorm = gorm.Set("gorm:preload_idx", 0)
 | 
			
		||||
	gorm.Callback().Query().Before("gorm:preload").Register("gorm:record_preload", recordPreloadCallback)
 | 
			
		||||
	gorm.Callback().Query().After("gorm:query").Register("gorm:record_query", recordQueryCallback)
 | 
			
		||||
	gorm.Callback().RowQuery().Before("gorm:row_query").Register("gorm:record_query", recordQueryCallback)
 | 
			
		||||
 | 
			
		||||
	return gormDb, &Expecter{adapter: adapter, gorm: gorm, recorder: recorder}, nil
 | 
			
		||||
}
 | 
			
		||||
@ -119,36 +174,14 @@ func (h *Expecter) AssertExpectations() error {
 | 
			
		||||
 | 
			
		||||
// First triggers a Query
 | 
			
		||||
func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	var q ExpectedQuery
 | 
			
		||||
	h.gorm.First(out, where...)
 | 
			
		||||
 | 
			
		||||
	if empty := h.recorder.IsEmpty(); empty {
 | 
			
		||||
		panic("No recorded statements")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, stmt := range h.recorder.stmts {
 | 
			
		||||
		q = h.adapter.ExpectQuery(regexp.QuoteMeta(stmt.sql))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return q
 | 
			
		||||
	return h.adapter.ExpectQuery(h.recorder.stmts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Find triggers a Query
 | 
			
		||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	var (
 | 
			
		||||
		stmts []string
 | 
			
		||||
	)
 | 
			
		||||
	h.gorm.Find(out, where...)
 | 
			
		||||
 | 
			
		||||
	if empty := h.recorder.IsEmpty(); empty {
 | 
			
		||||
		panic("No recorded statements")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, stmt := range h.recorder.stmts {
 | 
			
		||||
		stmts = append(stmts, regexp.QuoteMeta(stmt.sql))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return h.adapter.ExpectQuery(stmts...)
 | 
			
		||||
	return h.adapter.ExpectQuery(h.recorder.stmts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Preload clones the expecter and sets a preload condition on gorm.DB
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user