From 1181263f8abab177f8a6df941ad6d45c1c810110 Mon Sep 17 00:00:00 2001 From: zaneli Date: Sun, 8 Dec 2019 23:55:19 +0900 Subject: [PATCH] Fix rows affected log for Pluck, and remove rows affected log uncountable case and getting single row case --- callback_create.go | 2 +- callback_query.go | 2 +- logger.go | 4 +++- main.go | 16 ++++++++++++---- scope.go | 16 +++++++++------- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/callback_create.go b/callback_create.go index c4d25f37..d3528439 100644 --- a/callback_create.go +++ b/callback_create.go @@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { - defer scope.trace(NowFunc()) + defer scope.trace(NowFunc(), true) var ( columns, placeholders []string diff --git a/callback_query.go b/callback_query.go index 544afd63..224a6fa4 100644 --- a/callback_query.go +++ b/callback_query.go @@ -24,7 +24,7 @@ func queryCallback(scope *Scope) { return } - defer scope.trace(NowFunc()) + defer scope.trace(NowFunc(), true) var ( isSlice, isPtr bool diff --git a/logger.go b/logger.go index 88e167dd..ee0b0d3e 100644 --- a/logger.go +++ b/logger.go @@ -106,7 +106,9 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) { } messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) + if len(values) > 5 { + messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) + } } else { messages = append(messages, "\033[31;1m") messages = append(messages, values[2:]...) diff --git a/main.go b/main.go index 3db87870..5294b330 100644 --- a/main.go +++ b/main.go @@ -366,12 +366,16 @@ func (s *DB) Scan(dest interface{}) *DB { // Row return `*sql.Row` with given conditions func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() + scope := s.NewScope(s.Value) + defer scope.trace(NowFunc(), false) + return scope.row() } // Rows return `*sql.Rows` with given conditions func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() + scope := s.NewScope(s.Value) + defer scope.trace(NowFunc(), false) + return scope.rows() } // ScanRows scan `*sql.Rows` to give struct @@ -874,8 +878,12 @@ func (s *DB) log(v ...interface{}) { } } -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { +func (s *DB) slog(sql string, t time.Time, showRowsAffected bool, vars ...interface{}) { if s.logMode == detailedLogMode { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) + if showRowsAffected { + s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) + } else { + s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) + } } } diff --git a/scope.go b/scope.go index d82cadbc..5a614324 100644 --- a/scope.go +++ b/scope.go @@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope { // Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) + defer scope.trace(NowFunc(), true) if !scope.HasError() { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { @@ -934,8 +934,6 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin } func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - result := &RowQueryResult{} scope.InstanceSet("row_query_result", result) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) @@ -944,8 +942,6 @@ func (scope *Scope) row() *sql.Row { } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - result := &RowsQueryResult{} scope.InstanceSet("row_query_result", result) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) @@ -980,6 +976,8 @@ func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { } func (scope *Scope) pluck(column string, value interface{}) *Scope { + defer scope.trace(NowFunc(), true) + dest := reflect.Indirect(reflect.ValueOf(value)) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) @@ -998,6 +996,8 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { if scope.Err(err) == nil { defer rows.Close() for rows.Next() { + scope.db.RowsAffected++ + elem := reflect.New(dest.Type().Elem()).Interface() scope.Err(rows.Scan(elem)) dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) @@ -1011,6 +1011,8 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { } func (scope *Scope) count(value interface{}) *Scope { + defer scope.trace(NowFunc(), false) + if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { if len(scope.Search.group) != 0 { if len(scope.Search.havingConditions) != 0 { @@ -1042,9 +1044,9 @@ func (scope *Scope) typeName() string { } // trace print sql log -func (scope *Scope) trace(t time.Time) { +func (scope *Scope) trace(t time.Time, showRowsAffected bool) { if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) + scope.db.slog(scope.SQL, t, showRowsAffected, scope.SQLVars...) } }