From a3264c8fdb9fed1483ebb332ba9c9d4f195bd116 Mon Sep 17 00:00:00 2001 From: Sukharev Maxim Date: Wed, 4 Jan 2017 14:23:43 +0700 Subject: [PATCH] Separate callback for preSQL and postSQL --- callback.go | 31 +++++++++++++++++++++---------- callback_create.go | 3 ++- callback_query.go | 3 ++- callback_trace.go | 29 +++++++++++++++++++++++++++++ scope.go | 24 +++++++++--------------- 5 files changed, 63 insertions(+), 27 deletions(-) create mode 100644 callback_trace.go diff --git a/callback.go b/callback.go index 7b7cb604..46ce6f62 100644 --- a/callback.go +++ b/callback.go @@ -13,7 +13,8 @@ var DefaultCallback = &Callback{} // Field `deletes` contains callbacks will be call when deleting object // Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `trace` contains callbacks will be call after any sql query was executed +// Field `beforeSQL` contains callbacks will be call before any sql query was executed +// Field `afterSQL` contains callbacks will be call after any sql query was executed // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { creates []*func(scope *Scope) @@ -21,7 +22,8 @@ type Callback struct { deletes []*func(scope *Scope) queries []*func(scope *Scope) rowQueries []*func(scope *Scope) - trace []*func(scope *Scope) + beforeSQL []*func(scope *Scope) + afterSQL []*func(scope *Scope) processors []*CallbackProcessor } @@ -45,7 +47,8 @@ func (c *Callback) clone() *Callback { queries: c.queries, rowQueries: c.rowQueries, processors: c.processors, - trace: c.trace, + beforeSQL: c.beforeSQL, + afterSQL: c.afterSQL, } } @@ -82,9 +85,14 @@ func (c *Callback) RowQuery() *CallbackProcessor { return &CallbackProcessor{kind: "row_query", parent: c} } -// Trace could be used to register callbacks for tracing sql queries, refer `Create` for usage -func (c *Callback) Trace() *CallbackProcessor { - return &CallbackProcessor{kind: "trace", parent: c} +// Trace could be used to register callbacks before any sql queries, refer `Create` for usage +func (c *Callback) BeforeSQL() *CallbackProcessor { + return &CallbackProcessor{kind: "beforeSQL", parent: c} +} + +// Trace could be used to register callbacks after any sql queries, refer `Create` for usage +func (c *Callback) AfterSQL() *CallbackProcessor { + return &CallbackProcessor{kind: "afterSQL", parent: c} } // After insert a new callback after callback `callbackName`, refer `Callbacks.Create` @@ -218,7 +226,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { // reorder all registered processors, and reset CURD callbacks func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries, trace []*CallbackProcessor + var creates, updates, deletes, queries, rowQueries, beforeSQL, afterSQL []*CallbackProcessor for _, processor := range c.processors { if processor.name != "" { @@ -233,8 +241,10 @@ func (c *Callback) reorder() { queries = append(queries, processor) case "row_query": rowQueries = append(rowQueries, processor) - case "trace": - trace = append(trace, processor) + case "beforeSQL": + beforeSQL = append(beforeSQL, processor) + case "afterSQL": + afterSQL = append(afterSQL, processor) } } } @@ -244,5 +254,6 @@ func (c *Callback) reorder() { c.deletes = sortProcessors(deletes) c.queries = sortProcessors(queries) c.rowQueries = sortProcessors(rowQueries) - c.trace = sortProcessors(trace) + c.beforeSQL = sortProcessors(beforeSQL) + c.afterSQL = sortProcessors(afterSQL) } diff --git a/callback_create.go b/callback_create.go index f0709880..f1e5d181 100644 --- a/callback_create.go +++ b/callback_create.go @@ -40,7 +40,8 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.beforeSQL) + defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL) var ( columns, placeholders []string diff --git a/callback_query.go b/callback_query.go index 93782b1d..720a0722 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,7 +15,8 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { - defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.beforeSQL) + defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL) var ( isSlice, isPtr bool diff --git a/callback_trace.go b/callback_trace.go new file mode 100644 index 00000000..3697e60c --- /dev/null +++ b/callback_trace.go @@ -0,0 +1,29 @@ +package gorm + +import ( + "time" +) + +// Define callbacks for tracing +func init() { + DefaultCallback.BeforeSQL().Register("gorm:start-time", startTimeCallback) + DefaultCallback.AfterSQL().Register("gorm:log", logCallback) +} + +// startTimeCallback puts time when sql started in scope +func startTimeCallback(scope *Scope) { + scope.Set("gorm:trace-start-time", NowFunc()) +} + +// logCallback prints sql log +func logCallback(scope *Scope) { + if len(scope.SQL) <= 0 { + return + } + + t, ok := scope.Get("gorm:trace-start-time") + if !ok { + return + } + scope.db.slog(scope.SQL, t.(time.Time), scope.SQLVars...) +} diff --git a/scope.go b/scope.go index 75c33397..e140deee 100644 --- a/scope.go +++ b/scope.go @@ -8,7 +8,6 @@ import ( "regexp" "strconv" "strings" - "time" "reflect" ) @@ -346,9 +345,10 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - if !scope.HasError() { + scope.callCallbacks(scope.db.parent.callbacks.beforeSQL) + defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL) + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count @@ -884,14 +884,18 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.beforeSQL) + defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) scope.prepareQuerySQL() return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.beforeSQL) + defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) scope.prepareQuerySQL() return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) @@ -945,16 +949,6 @@ func (scope *Scope) typeName() string { return typ.Name() } -// trace print sql log -func (scope *Scope) trace(t time.Time) { - scope.Set("gorm:trace-time", t) - scope.callCallbacks(scope.db.parent.callbacks.trace) - - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - func (scope *Scope) changeableField(field *Field) bool { if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { for _, attr := range selectAttrs {