diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..347b01c3 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -47,7 +47,7 @@ const ( // Writer log writer interface type Writer interface { - Printf(string, ...interface{}) + Printf(context.Context, string, ...interface{}) } // Config logger config @@ -69,9 +69,9 @@ type Interface interface { var ( // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Discard = New(loggerWrap{logger: log.New(ioutil.Discard, "", log.LstdFlags)}, Config{}) // Default Default logger - Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + Default = New(loggerWrap{logger: log.New(os.Stdout, "\r\n", log.LstdFlags)}, Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, @@ -81,6 +81,14 @@ var ( Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +type loggerWrap struct { + logger *log.Logger +} + +func (l loggerWrap) Printf(_ context.Context, msg string, data ...interface{}) { + l.logger.Printf(msg, data...) +} + // New initialize logger func New(writer Writer, config Config) Interface { var ( @@ -130,21 +138,21 @@ func (l *logger) LogMode(level LogLevel) Interface { // Info print info func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(ctx, l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(ctx, l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + l.Printf(ctx, l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } @@ -159,24 +167,24 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): sql, rows := fc() if rows == -1 { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + l.Printf(ctx, l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + l.Printf(ctx, l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) if rows == -1 { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + l.Printf(ctx, l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + l.Printf(ctx, l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) } case l.LogLevel == Info: sql, rows := fc() if rows == -1 { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + l.Printf(ctx, l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + l.Printf(ctx, l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } }