package orm

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"os"
	"sort"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
	"rockfic.com/orm/internal/logging"
)

// LevelQuery enables logging of SQL queries if passed to Config.LogLevel.
// Its value (-6) is below slog.LevelDebug (-4).
const LevelQuery = slog.Level(-6)

// defaultKey is the registry key Open uses for an Engine when no default
// Engine has been registered yet.
const defaultKey = "default"

// Config holds optional settings for Open.
type Config struct {
	DryRun   bool       // when true, queries will not run on the underlying database
	LogLevel slog.Level // controls the level of information logged; defaults to slog.LevelInfo if not set
	LogTo    io.Writer  // where to write log output to; defaults to os.Stdout
}

// Engine holds the registered models, the logger and (optionally) the
// connection pool used to run queries and migrations.
type Engine struct {
	modelMap *internalModelMap // registered models, keyed by the names makeModelMap produces
	conn     *pgxpool.Pool     // nil when Open was given an empty connection string
	m2mSeen  map[string]bool   // presumably tracks many-to-many tables already handled during migration; reset between Migrate retry passes
	dryRun   bool              // when true, statements should not be sent to the database
	pgCfg    *pgxpool.Config   // pool configuration parsed from the connection string
	ctx      context.Context   // context used for database calls made by the Engine
	logger   *slog.Logger
	cfg      *Config
	levelVar *slog.LevelVar // runtime-adjustable level backing logger
	connStr  string
}

// Models parses each value in v and registers its type as a persistable model.
// Types that are already registered are left untouched.
func (e *Engine) Models(v ...any) {
	parsed := makeModelMap(v...)
	// Hold the lock for the whole merge so the existence check and the insert
	// are atomic with respect to concurrent Models calls.
	e.modelMap.Mux.Lock()
	defer e.modelMap.Mux.Unlock()
	for name, model := range parsed.Map {
		if _, exists := e.modelMap.Map[name]; !exists {
			e.modelMap.Map[name] = model
		}
	}
}

// Model creates a Query and sets its model to the one corresponding to the
// type of val.
func (e *Engine) Model(val any) *Query {
	qq := &Query{
		engine:         e,
		ctx:            context.Background(),
		wheres:         make(map[string][]any),
		orders:         make([]string, 0),
		populationTree: make(map[string]any),
		joins:          make([]string, 0),
	}
	return qq.setModel(val)
}

// QueryRaw runs sql with args on the Engine's connection pool by wrapping
// pgxpool.Pool.Query. It returns an error instead of panicking when the Engine
// has no pool (Open was given an empty connection string).
func (e *Engine) QueryRaw(sql string, args ...any) (pgx.Rows, error) {
	if e.conn == nil {
		return nil, errors.New("orm: QueryRaw called on an Engine with no database connection")
	}
	return e.conn.Query(e.ctx, sql, args...)
}

// Migrate runs non-destructive migrations to update the underlying schema,
// WITHOUT dropping tables beforehand.
//
// Models whose migration fails are retried in further passes. This presumably
// handles models that depend on others not yet migrated. Retrying stops once
// every model has migrated (nil is returned) or once a whole pass makes no
// progress. In that case the remaining per-model errors are returned joined,
// ordered by model name.
func (e *Engine) Migrate() error {
	pending := make(map[string]*Model)
	lastErr := make(map[string]error)
	for name, m := range e.modelMap.Map {
		if err := m.migrate(e); err != nil {
			pending[name] = m
			lastErr[name] = err
		}
	}
	for len(pending) > 0 {
		e.m2mSeen = make(map[string]bool)
		progressed := false
		for name, m := range pending {
			if err := m.migrate(e); err != nil {
				lastErr[name] = err
				continue
			}
			// Deleting during range is safe in Go.
			delete(pending, name)
			delete(lastErr, name)
			progressed = true
		}
		// A pass that fixes nothing would repeat forever; give up instead.
		if !progressed {
			return joinMigrateErrors(lastErr)
		}
	}
	return nil
}

// joinMigrateErrors combines per-model migration errors into a single error,
// ordered by model name so the result is deterministic.
func joinMigrateErrors(errs map[string]error) error {
	names := make([]string, 0, len(errs))
	for name := range errs {
		names = append(names, name)
	}
	sort.Strings(names)
	wrapped := make([]error, 0, len(names))
	for _, name := range names {
		wrapped = append(wrapped, fmt.Errorf("migrating model %s: %w", name, errs[name]))
	}
	return errors.Join(wrapped...)
}

// MigrateDropping performs a destructive migration. It DROPs (with CASCADE) the
// table of every registered model and the join table of every many-to-many(-ish)
// relationship, then recreates them via Migrate to match the models' schema.
// In dry-run mode, or when no pool is open, the DROP statements are only
// logged, not executed.
func (e *Engine) MigrateDropping() error {
	drop := func(table string) error {
		sql := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE;", table)
		e.logSql("drop table", sql)
		if e.dryRun || e.conn == nil {
			return nil
		}
		_, err := e.conn.Exec(e.ctx, sql)
		return err
	}
	for _, m := range e.modelMap.Map {
		if err := drop(m.TableName); err != nil {
			return err
		}
		for _, r := range m.Relationships {
			if r.m2mIsh() || r.Type == ManyToMany {
				if err := drop(r.ComputeJoinTable()); err != nil {
					return err
				}
			}
		}
	}
	return e.Migrate()
}

// logQuery logs sql and its args at LevelQuery. The args are passed through
// logTrunc(200, ...), presumably to cap their logged size.
func (e *Engine) logQuery(msg, sql string, args []any) {
	e.logger.Log(e.ctx, LevelQuery, msg, "sql", sql, "args", logTrunc(200, args))
}

// logSql logs sql at LevelQuery.
func (e *Engine) logSql(msg, sql string) {
	e.logger.Log(e.ctx, LevelQuery, msg, "sql", sql)
}

// Disconnect closes this Engine's connection pool, if one was opened. When
// logs go to a file other than the process's stdout/stderr, it also closes
// that file.
func (e *Engine) Disconnect() {
	if e.conn != nil {
		e.conn.Close()
	}
	if e.cfg == nil {
		return
	}
	// Never close os.Stdout/os.Stderr: the default LogTo is os.Stdout, and
	// closing it would break all later output from the process.
	if f, ok := e.cfg.LogTo.(*os.File); ok && f != os.Stdout && f != os.Stderr {
		// Best effort: nothing useful can be done if closing the log file fails.
		_ = f.Close()
	}
}

// Open creates a new Engine configured by cfg. A nil cfg means logging at
// slog.LevelInfo to os.Stdout. The Engine's logger is installed as the slog
// default. When connString is non-empty, a pgx connection pool is opened and
// the Engine is registered globally. An empty connString yields a dry-run
// Engine with no database connection.
func Open(connString string, cfg *Config) (*Engine, error) { if cfg == nil { cfg = &Config{ LogLevel: slog.LevelInfo, LogTo: os.Stdout, DryRun: connString == "", } } else { if cfg.LogTo == nil { cfg.LogTo = os.Stdout } } e := &Engine{ modelMap: &internalModelMap{ Map: make(map[string]*Model), }, m2mSeen: make(map[string]bool), dryRun: connString == "", ctx: context.Background(), levelVar: new(slog.LevelVar), cfg: cfg, } e.levelVar.Set(cfg.LogLevel) replacer := func(groups []string, a slog.Attr) slog.Attr { if a.Key == slog.LevelKey { level := a.Value.Any().(slog.Level) switch level { case LevelQuery: a.Value = slog.StringValue("query") } } return a } e.logger = slog.New(logging.NewFormattedHandler(cfg.LogTo, logging.Options{ Level: e.levelVar, ReplaceAttr: replacer, Format: "{{.Time}} [{{.Level}}] {{.Message}} | {{ rest }}", })) slog.SetDefault(e.logger) if connString != "" { engines.Mux.Lock() if len(engines.Engines) == 0 || engines.Engines[defaultKey] == nil { engines.Engines[defaultKey] = e } else { engines.Engines[connString] = e } e.connStr = "" engines.Mux.Unlock() var err error e.pgCfg, err = pgxpool.ParseConfig(connString) e.pgCfg.MinConns = 5 e.pgCfg.MaxConns = 10 e.pgCfg.MaxConnIdleTime = time.Minute * 2 e.pgCfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { oldHandler := conn.Config().OnPgError conn.Config().OnPgError = func(conn *pgconn.PgConn, pgError *pgconn.PgError) bool { e.logger.Error("ERROR ->", "err", pgError.Error()) return oldHandler(conn, pgError) } return nil } if err != nil { return nil, err } e.conn, err = pgxpool.NewWithConfig(e.ctx, e.pgCfg) if err != nil { return nil, err } } return e, nil }