diff --git a/internal/logging/custom_handler.go b/internal/logging/custom_handler.go new file mode 100644 index 0000000..02eb730 --- /dev/null +++ b/internal/logging/custom_handler.go @@ -0,0 +1,207 @@ +package logging + +import ( + "bytes" + "context" + "io" + "log/slog" + "runtime" + "slices" + "strings" + "sync" + "text/template" + "time" +) + +type FormattedHandler struct { + mu *sync.Mutex + out io.Writer + opts Options + attrs map[string]slog.Value + groups []string + groupLvl int +} + +type Options struct { + Level slog.Leveler + Format string + ReplaceAttr func(groups []string, attr slog.Attr) slog.Attr +} +type locData struct { + FileName string + Function string + Line int +} + +func NewFormattedHandler(out io.Writer, options Options) *FormattedHandler { + h := &FormattedHandler{ + opts: options, + out: out, + mu: &sync.Mutex{}, + groups: make([]string, 0), + } + if h.opts.Format == "" { + h.opts.Format = "{{.Time}} [{{.Level}}]" + } + if h.opts.Level == nil { + h.opts.Level = slog.LevelInfo + } + return h +} + +func (f *FormattedHandler) Enabled(ctx context.Context, level slog.Level) bool { + return level >= f.opts.Level.Level() +} + +func (f *FormattedHandler) Handle(ctx context.Context, r slog.Record) error { + bufp := allocBuf() + buf := *bufp + defer func() { + *bufp = buf + freeBuf(bufp) + }() + rep := f.opts.ReplaceAttr + key := slog.LevelKey + val := r.Level + if rep == nil { + r.AddAttrs(slog.String(key, val.String())) + } else { + nattr := slog.Any(key, val) + nattr.Value = rep(f.groups, nattr).Value + r.AddAttrs(nattr) + } + + f.mu.Lock() + defer f.mu.Unlock() + tctx, tmpl := f.newFmtCtx(r) + wr := bytes.NewBuffer(buf) + parsed, err := tmpl.Parse(f.opts.Format) + if err != nil { + return err + } + err = parsed.Execute(wr, tctx) + if err != nil { + return err + } + + wr.WriteByte('\n') + _, err = f.out.Write(wr.Bytes()) + return err +} + +func (f *FormattedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return f + } + nf := f.clone() + bufp := allocBuf() + buf := *bufp + defer func() { + *bufp = buf + freeBuf(bufp) + }() + s := f.newState(bytes.NewBuffer(buf)) + defer s.free() + pos := s.buf.Len() + s.startGroups() + if !s.appendAttrs(attrs) { + s.buf.Truncate(pos) + } else { + nf.groupLvl = len(nf.groups) + } + return nf +} + +func (f *FormattedHandler) WithGroup(name string) slog.Handler { + if name == "" { + return f + } + f2 := f.clone() + f2.groups = append(f2.groups, name) + return f2 +} + +func (f *FormattedHandler) clone() *FormattedHandler { + return &FormattedHandler{ + opts: f.opts, + groups: slices.Clip(f.groups), + out: f.out, + mu: f.mu, + groupLvl: f.groupLvl, + } +} + +type tmplData struct { + Level string + Message string + RawTime time.Time + Time string + PC uintptr + Location locData + Record slog.Record +} + +func hasBuiltInKey(a slog.Attr) bool { + return a.Key == slog.MessageKey || + a.Key == slog.TimeKey || + a.Key == slog.SourceKey +} + +func (f *FormattedHandler) newFmtCtx(r slog.Record) (ctx *tmplData, tmpl *template.Template) { + tmpl = template.New("log") + ctx = &tmplData{ + Message: r.Message, + RawTime: r.Time, + PC: r.PC, + Location: locData{}, + } + if !r.Time.IsZero() { + ctx.Time = r.Time.Format(time.RFC3339Nano) + } + r.Attrs(func(a slog.Attr) bool { + if a.Key == slog.LevelKey { + str := strings.ToUpper(a.Value.String()) + if rep := f.opts.ReplaceAttr; rep != nil { + str = strings.ToUpper(a.Value.String()) + } + ctx.Level = str + } + return true + }) + if r.PC != 0 { + frames := runtime.CallersFrames([]uintptr{r.PC}) + frame, _ := frames.Next() + ctx.Location.FileName = frame.File + ctx.Location.Function = frame.Function + ctx.Location.Line = frame.Line + } + fm := make(map[string]any) + fm["rest"] = func() string { + bb := new(bytes.Buffer) + s := f.newState(bb) + defer s.free() + s.begin(r) + return s.buf.String() + } + tmpl = tmpl.Funcs(fm) + return +} + +var bufPool = sync.Pool{ + New: func() any { + b := make([]byte, 0, 4096) + return &b + }, +} + +func allocBuf() *[]byte { + return bufPool.Get().(*[]byte) +} + +func freeBuf(b *[]byte) { + const maxBufferSize = 16 << 10 + if cap(*b) <= maxBufferSize { + *b = (*b)[:0] + bufPool.Put(b) + } +} diff --git a/internal/logging/custom_handler_state.go b/internal/logging/custom_handler_state.go new file mode 100644 index 0000000..e312cf1 --- /dev/null +++ b/internal/logging/custom_handler_state.go @@ -0,0 +1,145 @@ +package logging + +import ( + "bytes" + "fmt" + "log/slog" + "strings" + "sync" + "time" +) + +func (f *FormattedHandler) newState(sb *bytes.Buffer) state { + s := state{ + fh: f, + buf: sb, + } + if f.opts.ReplaceAttr != nil { + s.groups = groupPool.Get().(*[]string) + *s.groups = append(*s.groups, f.groups[:f.groupLvl]...) + } + return s +} + +type state struct { + buf *bytes.Buffer + fh *FormattedHandler + groups *[]string +} + +func (s *state) startGroups() { + for _, n := range s.fh.groups[s.fh.groupLvl:] { + s.startGroup(n) + } +} + +func (s *state) startGroup(name string) { + s.buf.WriteByte('\n') + if s.groups != nil { + *s.groups = append(*s.groups, name) + } +} + +func (s *state) endGroup() { + if s.groups != nil { + *s.groups = (*s.groups)[:len(*s.groups)-1] + } +} + +func (s *state) appendAttr(a slog.Attr) bool { + a.Value = a.Value.Resolve() + if rep := s.fh.opts.ReplaceAttr; rep != nil && a.Value.Kind() != slog.KindGroup { + var gs []string + if s.groups != nil { + gs = *s.groups + } + a = rep(gs, a) + a.Value = a.Value.Resolve() + } + if a.Equal(slog.Attr{}) || + hasBuiltInKey(a) || + a.Key == slog.LevelKey { + return false + } + if a.Value.Kind() == slog.KindGroup { + pos := s.buf.Len() + attrs := a.Value.Group() + if len(attrs) > 0 { + if a.Key != "" { + s.startGroup(a.Key) + } + if !s.appendAttrs(attrs) { + s.buf.Truncate(pos) + return false + } + if a.Key != "" { + s.endGroup() + } + } + } else { + s.writeAttr(a) + } + return true +} + +func (s *state) appendAttrs(as []slog.Attr) bool { + nonEmpty := false + for _, a := range as { + if s.appendAttr(a) { + nonEmpty = true + } + } + return nonEmpty +} + +func (s *state) writeAttr(a slog.Attr) { + if s.buf.Len() > 0 { + s.buf.WriteString(";") + } + if len(*s.groups) > 0 { + s.buf.WriteString(fmt.Sprintf("%*s", len(*s.groups)*2, "")) + s.buf.WriteString(strings.Join(*s.groups, ".")) + s.buf.WriteString(".") + } + s.buf.WriteString(a.Key) + s.buf.WriteString("=") + switch a.Value.Kind() { + case slog.KindDuration: + s.buf.WriteString(a.Value.Duration().String()) + case slog.KindTime: + s.buf.WriteString(a.Value.Time().Format(time.RFC3339Nano)) + default: + s.buf.WriteString(fmt.Sprintf("%+v", a.Value.Any())) + } +} + +func (s *state) begin(r slog.Record) { + if r.NumAttrs() > 0 { + pos := s.buf.Len() + s.startGroups() + empty := true + r.Attrs(func(a slog.Attr) bool { + isBuiltIn := hasBuiltInKey(a) || a.Key == slog.LevelKey + if !isBuiltIn && s.appendAttr(a) { + empty = false + } + return true + }) + if empty { + s.buf.Truncate(pos) + } + } +} + +func (s *state) free() { + if gs := s.groups; gs != nil { + *gs = (*gs)[:0] + groupPool.Put(gs) + } + s.buf.Reset() +} + +var groupPool = sync.Pool{New: func() any { + s := make([]string, 0, 10) + return &s +}} diff --git a/internal/logging/custom_handler_test.go b/internal/logging/custom_handler_test.go new file mode 100644 index 0000000..9dd690d --- /dev/null +++ b/internal/logging/custom_handler_test.go @@ -0,0 +1,34 @@ +package logging + +import ( + "context" + "log/slog" + "os" + "testing" +) + +const LevelQ = slog.Level(-6) + +func TestDoAFlip(t *testing.T) { + t.Name() + replacer := func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.LevelKey { + level := a.Value.Any().(slog.Level) + switch level { + case LevelQ: + a.Value = slog.StringValue("q") + } + } + return a + } + h := NewFormattedHandler(os.Stderr, Options{ + Format: "{{.Time}} [{{.Level}}] {{.Message}} | {{ rest }}", + Level: LevelQ, + ReplaceAttr: replacer, + }) + logger := slog.New(h) + slog.SetDefault(logger) + + logger.Debug("hello", "btfash", true) + logger.Log(context.TODO(), LevelQ, "hi") +} diff --git a/model_migration.go b/model_migration.go index 0e31bc3..db18fff 100644 --- a/model_migration.go +++ b/model_migration.go @@ -142,7 +142,8 @@ ON UPDATE CASCADE;`, rel.RelatedModel.TableName, field.ColumnName, ) dq := fmt.Sprintf(`ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s;`, m.TableName, fk) - fmt.Printf("%s\n%s\n", dq, q) + engine.logSql("drop constraint", dq) + engine.logSql("alter table", q) if _, err := engine.conn.Exec(engine.ctx, dq); err != nil { return err } @@ -156,7 +157,7 @@ ON UPDATE CASCADE;`, func (m *Model) migrate(engine *Engine) error { sql := m.createTableSql() - fmt.Println(sql) + engine.logSql("create table", sql) if !engine.dryRun { _, err := engine.conn.Exec(engine.ctx, sql) if err != nil { @@ -173,7 +174,7 @@ func (m *Model) migrate(engine *Engine) error { engine.m2mSeen[rel.RelatedModel.Name] = true } jtsql := m.createJoinTableSql(relName) - fmt.Println(jtsql) + engine.logSql("crate join table", jtsql) if !engine.dryRun { _, err := engine.conn.Exec(engine.ctx, jtsql) if err != nil { diff --git a/query_populate.go b/query_populate.go index fe80327..c7a8f67 100644 --- a/query_populate.go +++ b/query_populate.go @@ -132,7 +132,7 @@ func (q *Query) populateHas(rel *Relationship, parent reflect.Value, parentIds [ aq, aa := sb.Select(ccols...). From(rel.RelatedModel.TableName). Where(fmt.Sprintf("%s IN (%s)", fk, MakePlaceholders(len(parentIds))), parentIds...).MustSQL() - fmt.Printf("[POPULATE] %s %+v\n", aq, aa) + q.engine.logQuery("populate", aq, aa) rows, err := q.engine.conn.Query(q.ctx, aq, aa...) if err != nil { return reflect.Value{}, err @@ -245,7 +245,7 @@ func (q *Query) populateManyToMany(rel *Relationship, parent reflect.Value, pare rel.relatedID().ColumnName, rel.RelatedModel.TableName)). Where(fmt.Sprintf("jt.%s_id IN (%s)", rel.Model.TableName, inPlaceholders), parentIds...).MustSQL() - fmt.Printf("[POPULATE/JOIN] %s %+v\n", mq, ma) + q.engine.logQuery("populate/join", mq, ma) rows, err := q.engine.conn.Query(q.ctx, mq, ma...) if err != nil { return reflect.Value{}, err @@ -306,7 +306,7 @@ func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value Where(fmt.Sprintf("%s IN (%s)", childIdField.ColumnName, MakePlaceholders(len(childIDs)), ), childIDs...).MustSQL() - fmt.Printf("[POPULATE/BELONGS-TO] %s %+v\n", qs, qa) + q.engine.logQuery("populate/belongs-to", qs, qa) rows, err := q.engine.conn.Query(q.ctx, qs, qa...) if err != nil { return reflect.Value{}, err @@ -350,7 +350,7 @@ func (q *Query) populateBelongsTo(rel *Relationship, childrenSlice reflect.Value parentIdField.ColumnName, MakePlaceholders(len(parentKeyValues))), parentKeyValues...). MustSQL() - fmt.Printf("[POPULATE/BELONGS-TO->PARENT] %s %+v\n", pquery, pqargs) + q.engine.logQuery("populate/belongs-to->parent", pquery, pqargs) parentRows, err := q.engine.conn.Query(q.ctx, pquery, pqargs...) if err != nil { return reflect.Value{}, err diff --git a/query_tail.go b/query_tail.go index ea496e4..4198200 100644 --- a/query_tail.go +++ b/query_tail.go @@ -19,7 +19,7 @@ func (q *Query) Find(dest any) error { return err } qq, qa := sqlb.MustSQL() - fmt.Printf("[FIND] %s %+v\n", qq, qa) + q.engine.logQuery("find", qq, qa) if maybeSlice.Kind() == reflect.Struct { row := q.engine.conn.QueryRow(q.ctx, qq, qa...) @@ -102,7 +102,7 @@ func (q *Query) UpdateRaw(values map[string]any) (int64, error) { } } sql, args := stmt.MustSQL() - fmt.Printf("[UPDATE/RAW] %s %+v\n", sql, args) + q.engine.logQuery("update/raw", sql, args) q.tx, err = q.engine.conn.Begin(q.ctx) if err != nil { return 0, err @@ -134,7 +134,7 @@ func (q *Query) Delete() (int64, error) { sqlb := sb.Delete(q.model.TableName).Where(subQuery) sql, sqla := sqlb.MustSQL() - fmt.Printf("[DELETE] %s %+v\n", sql, sqla) + q.engine.logQuery("delete", sql, sqla) cmdTag, err := q.tx.Exec(q.ctx, sql, sqla...) if err != nil { return 0, fmt.Errorf("failed to delete: %w", err) @@ -284,7 +284,7 @@ func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any } if doInsert { var nid any - fmt.Printf("[INSERT] %s %+v\n", qq, qa) + q.engine.logQuery("insert", qq, qa) row := q.tx.QueryRow(q.ctx, qq, qa...) err := row.Scan(&nid) if err != nil { @@ -292,7 +292,7 @@ func (q *Query) doSave(val reflect.Value, model *Model, parentFks map[string]any } pkField.Set(reflect.ValueOf(nid)) } else { - fmt.Printf("[UPDATE] %s %+v\n", qq, qa) + q.engine.logQuery("update", qq, qa) _, err := q.tx.Exec(q.ctx, qq, qa...) if err != nil { return nil, fmt.Errorf("update failed for model %s: %w", model.Name, err) diff --git a/relationship.go b/relationship.go index a4a3970..4b2b1df 100644 --- a/relationship.go +++ b/relationship.go @@ -113,7 +113,7 @@ func (r *Relationship) joinDelete(pk, fk any, q *Query) error { dq = dq.Where(fmt.Sprintf("%s_id = ?", r.RelatedModel.TableName), fk) } ds, aa := dq.MustSQL() - fmt.Printf("[DELETE/JOIN] %s %+v \n", ds, logTrunc(200, aa)) + q.engine.logQuery("delete/join", ds, aa) if !q.engine.dryRun { _, err := q.tx.Exec(q.ctx, ds, aa...) return err