
Use golangci/golangci-lint-action instead of reviewdog/action-golangci-lint, as the latter was not reporting any failures even when there were some. Report code coverage with codecov/codecov-action; I have set some flags per dialect and Go version. Several linter findings have been fixed, and some linters disabled, so the build can pass.
333 lines
8.0 KiB
Go
333 lines
8.0 KiB
Go
package gorm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"time"
|
|
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
func initializeCallbacks(db *DB) *callbacks {
|
|
return &callbacks{
|
|
processors: map[string]*processor{
|
|
"create": {db: db},
|
|
"query": {db: db},
|
|
"update": {db: db},
|
|
"delete": {db: db},
|
|
"row": {db: db},
|
|
"raw": {db: db},
|
|
},
|
|
}
|
|
}
|
|
|
|
// callbacks is gorm's callback manager. It holds one processor per
// operation kind, keyed by "create", "query", "update", "delete", "row"
// and "raw"; the accessor methods below return the matching processor.
type callbacks struct {
	// processors maps an operation kind to the processor that runs its
	// callback chain.
	processors map[string]*processor
}
|
|
|
|
// processor owns the callback chain for one kind of operation.
type processor struct {
	// db is used to evaluate callback match predicates during compile and
	// to reach the logger.
	db *DB

	// Clauses is copied into Statement.BuildClauses by Execute when the
	// statement has none of its own.
	Clauses []string

	// fns is the compiled, ordered list of handlers that Execute runs.
	fns []func(*DB)

	// callbacks holds every registration, removal and replacement entry
	// appended by Register/Remove/Replace. compile filters it and
	// sortCallbacks may reorder it in place.
	callbacks []*callback
}
|
|
|
|
// callback is one entry in a processor's callback list. It is created by
// the processor/callback builder methods and appended to the processor by
// Register, Remove or Replace.
type callback struct {
	// name identifies the callback. Several entries may share a name; when
	// compiling, the last such entry in the processor's list decides.
	name string

	// before/after name another callback this one must run before/after.
	// "*" places it first/last among the callbacks already sorted.
	before string
	after  string

	// remove marks an entry created by Remove. When it is the last entry
	// for its name, that name is left out of the compiled handlers.
	remove bool

	// replace is set by Replace and suppresses the duplicated-name warning.
	replace bool

	// match is an optional predicate. compile keeps this entry only when
	// match is nil or returns true for the processor's db.
	match func(*DB) bool

	// handler is the function placed into the compiled chain.
	handler func(*DB)

	// processor is the processor this entry is registered on.
	processor *processor
}
|
|
|
|
func (cs *callbacks) Create() *processor {
|
|
return cs.processors["create"]
|
|
}
|
|
|
|
func (cs *callbacks) Query() *processor {
|
|
return cs.processors["query"]
|
|
}
|
|
|
|
func (cs *callbacks) Update() *processor {
|
|
return cs.processors["update"]
|
|
}
|
|
|
|
func (cs *callbacks) Delete() *processor {
|
|
return cs.processors["delete"]
|
|
}
|
|
|
|
func (cs *callbacks) Row() *processor {
|
|
return cs.processors["row"]
|
|
}
|
|
|
|
func (cs *callbacks) Raw() *processor {
|
|
return cs.processors["raw"]
|
|
}
|
|
|
|
// Execute runs this processor's compiled callback chain (p.fns) against db
// and returns the *DB the chain ran on, which may differ from the argument
// if a scope replaced it.
//
// Before running the chain it:
//   - applies pending scopes, repeating while scopes keep being queued;
//   - defaults stmt.BuildClauses to p.Clauses when the statement has none;
//   - copies Model into Dest or Dest into Model when only one is set;
//   - parses the model schema, recording parse errors on db;
//   - resolves stmt.ReflectValue by dereferencing Dest, allocating nil
//     pointers along the way when they are addressable.
//
// Afterwards it traces any generated SQL through db.Logger, clears the SQL
// buffer and vars unless the statement's DB is in DryRun mode, and undoes
// the BuildClauses default it applied.
func (p *processor) Execute(db *DB) *DB {
	// call scopes
	// The outer loop re-checks db.Statement.scopes, so scopes queued while
	// running a scope are applied too.
	for len(db.Statement.scopes) > 0 {
		scopes := db.Statement.scopes
		db.Statement.scopes = nil
		for _, scope := range scopes {
			db = scope(db)
		}
	}

	// curTime is the start time reported to Logger.Trace. stmt is captured
	// after scopes ran, so it belongs to the final db.
	var (
		curTime           = time.Now()
		stmt              = db.Statement
		resetBuildClauses bool
	)

	// Use the processor's default clauses only for this call; they are
	// cleared again at the end if they were applied here.
	if len(stmt.BuildClauses) == 0 {
		stmt.BuildClauses = p.Clauses
		resetBuildClauses = true
	}

	// assign model values
	if stmt.Model == nil {
		stmt.Model = stmt.Dest
	} else if stmt.Dest == nil {
		stmt.Dest = stmt.Model
	}

	// parse model values
	// schema.ErrUnsupportedDataType is tolerated when a table name, table
	// expression or raw SQL already tells the statement what to target;
	// any other parse error is always recorded.
	if stmt.Model != nil {
		if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
			// Given the outer condition, for ErrUnsupportedDataType the table
			// checks here always hold, so that error gets the hint message.
			if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
			} else {
				db.AddError(err)
			}
		}
	}

	// assign stmt.ReflectValue
	// Follow every pointer level of Dest. A nil pointer that is addressable
	// (i.e. reached through a previous Elem) is allocated so that callbacks
	// have something to write into.
	if stmt.Dest != nil {
		stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
		for stmt.ReflectValue.Kind() == reflect.Ptr {
			if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
				stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
			}

			stmt.ReflectValue = stmt.ReflectValue.Elem()
		}
		// A nil pointer that could not be allocated dereferences to an
		// invalid Value.
		if !stmt.ReflectValue.IsValid() {
			db.AddError(ErrInvalidValue)
		}
	}

	// Run the compiled chain in order.
	for _, f := range p.fns {
		f(db)
	}

	// The SQL text is rendered inside the closure, so the logger decides
	// whether Explain runs at all; db.Error is evaluated at call time.
	if stmt.SQL.Len() > 0 {
		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
		}, db.Error)
	}

	// In DryRun mode the built SQL and vars are left on the statement.
	if !stmt.DB.DryRun {
		stmt.SQL.Reset()
		stmt.Vars = nil
	}

	if resetBuildClauses {
		stmt.BuildClauses = nil
	}

	return db
}
|
|
|
|
func (p *processor) Get(name string) func(*DB) {
|
|
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
|
if v := p.callbacks[i]; v.name == name && !v.remove {
|
|
return v.handler
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *processor) Before(name string) *callback {
|
|
return &callback{before: name, processor: p}
|
|
}
|
|
|
|
func (p *processor) After(name string) *callback {
|
|
return &callback{after: name, processor: p}
|
|
}
|
|
|
|
func (p *processor) Match(fc func(*DB) bool) *callback {
|
|
return &callback{match: fc, processor: p}
|
|
}
|
|
|
|
func (p *processor) Register(name string, fn func(*DB)) error {
|
|
return (&callback{processor: p}).Register(name, fn)
|
|
}
|
|
|
|
func (p *processor) Remove(name string) error {
|
|
return (&callback{processor: p}).Remove(name)
|
|
}
|
|
|
|
func (p *processor) Replace(name string, fn func(*DB)) error {
|
|
return (&callback{processor: p}).Replace(name, fn)
|
|
}
|
|
|
|
func (p *processor) compile() (err error) {
|
|
var callbacks []*callback
|
|
for _, callback := range p.callbacks {
|
|
if callback.match == nil || callback.match(p.db) {
|
|
callbacks = append(callbacks, callback)
|
|
}
|
|
}
|
|
p.callbacks = callbacks
|
|
|
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
|
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *callback) Before(name string) *callback {
|
|
c.before = name
|
|
return c
|
|
}
|
|
|
|
func (c *callback) After(name string) *callback {
|
|
c.after = name
|
|
return c
|
|
}
|
|
|
|
func (c *callback) Register(name string, fn func(*DB)) error {
|
|
c.name = name
|
|
c.handler = fn
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
func (c *callback) Remove(name string) error {
|
|
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
|
c.name = name
|
|
c.remove = true
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
func (c *callback) Replace(name string, fn func(*DB)) error {
|
|
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
|
c.name = name
|
|
c.handler = fn
|
|
c.replace = true
|
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
|
return c.processor.compile()
|
|
}
|
|
|
|
// getRIndex returns the index of the last occurrence of str in strs, or -1
// when str is not present.
func getRIndex(strs []string, str string) int {
	pos := len(strs)
	for pos > 0 {
		pos--
		if strs[pos] == str {
			return pos
		}
	}
	return -1
}
|
|
|
|
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|
names := make([]string, 0)
|
|
var (
|
|
sorted []string
|
|
sortCallback func(*callback) error
|
|
)
|
|
sort.Slice(cs, func(i, j int) bool {
|
|
return cs[j].before == "*" || cs[j].after == "*"
|
|
})
|
|
|
|
for _, c := range cs {
|
|
// show warning message the callback name already exists
|
|
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
|
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
|
|
}
|
|
names = append(names, c.name)
|
|
}
|
|
|
|
sortCallback = func(c *callback) error {
|
|
if c.before != "" { // if defined before callback
|
|
if c.before == "*" && len(sorted) > 0 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
sorted = append([]string{c.name}, sorted...)
|
|
}
|
|
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
// if before callback already sorted, append current callback just after it
|
|
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
|
} else if curIdx > sortedIdx {
|
|
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
|
|
}
|
|
} else if idx := getRIndex(names, c.before); idx != -1 {
|
|
// if before callback exists
|
|
cs[idx].after = c.name
|
|
}
|
|
}
|
|
|
|
if c.after != "" { // if defined after callback
|
|
if c.after == "*" && len(sorted) > 0 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
sorted = append(sorted, c.name)
|
|
}
|
|
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
|
// if after callback sorted, append current callback to last
|
|
sorted = append(sorted, c.name)
|
|
} else if curIdx < sortedIdx {
|
|
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
|
|
}
|
|
} else if idx := getRIndex(names, c.after); idx != -1 {
|
|
// if after callback exists but haven't sorted
|
|
// set after callback's before callback to current callback
|
|
after := cs[idx]
|
|
|
|
if after.before == "" {
|
|
after.before = c.name
|
|
}
|
|
|
|
if err := sortCallback(after); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := sortCallback(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// if current callback haven't been sorted, append it to last
|
|
if getRIndex(sorted, c.name) == -1 {
|
|
sorted = append(sorted, c.name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, c := range cs {
|
|
if err = sortCallback(c); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
for _, name := range sorted {
|
|
if idx := getRIndex(names, name); !cs[idx].remove {
|
|
fns = append(fns, cs[idx].handler)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|