diff --git a/callbacks.go b/callbacks.go index 78f1192e..6c70b392 100644 --- a/callbacks.go +++ b/callbacks.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "errors" "fmt" "reflect" @@ -90,7 +91,7 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.Logger.Trace(curTime, func() (string, int64) { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) @@ -141,7 +142,7 @@ func (p *processor) compile() (err error) { } if p.fns, err = sortCallbacks(p.callbacks); err != nil { - logger.Default.Error("Got error when compile callbacks, got %v", err) + logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } @@ -164,7 +165,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -172,7 +173,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -199,7 +200,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } diff --git a/logger/logger.go b/logger/logger.go index ee6c0da1..24cee821 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "context" "log" "os" "time" @@ -46,10 +47,10 @@ type Config struct { // Interface logger interface type Interface interface { LogMode(LogLevel) Interface - Info(string, ...interface{}) - Warn(string, ...interface{}) - Error(string, ...interface{}) - Trace(begin time.Time, fc func() (string, int64), err error) + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) } var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ @@ -103,35 +104,35 @@ func (l logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(msg string, data ...interface{}) { +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(msg string, data ...interface{}) { +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(msg string, data ...interface{}) { +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message -func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) { +func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel > 0 { elapsed := time.Now().Sub(begin) switch { case err != nil: sql, rows := fc() l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) - case elapsed > l.SlowThreshold && l.SlowThreshold != 0: + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) case l.LogLevel >= Info: diff --git a/schema/schema.go b/schema/schema.go index 2ac6d312..3abac2ba 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "context" "errors" "fmt" "go/ast" @@ -83,7 +84,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) defer func() { if schema.err != nil { - logger.Default.Error(schema.err.Error()) + logger.Default.Error(context.Background(), schema.err.Error()) cacheStore.Delete(modelType) } }() @@ -174,7 +175,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) case "func(*gorm.DB)": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: - logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) } } }