gorm/callback.go

268 lines
7.1 KiB
Go

package gorm
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// callback keeps every registered callbackProcessor together with the ordered
// hook lists that are derived from them for each operation type.
type callback struct {
	// Ordered hook lists, one per operation type. sort() rebuilds all of
	// them from processors.
	creates, updates, deletes, queries, rowQueries []*func(scope *Scope)
	// All registrations created through addProcessor, in the order they
	// were added.
	processors []*callbackProcessor
}
// callbackProcessor is one registration request against a callback: the hook
// function plus the naming and ordering information sortProcessors uses.
type callbackProcessor struct {
	// name identifies the hook; before/after optionally name the hooks it
	// should be ordered relative to.
	name, before, after string
	// replace and remove mark registrations made through Replace and Remove.
	replace, remove bool
	// typ is the operation type: "create", "update", "delete", "query" or
	// "row_query".
	typ string
	// processor points at the hook function; it is set by Register/Replace.
	processor *func(scope *Scope)
	// callback is the owning container, re-sorted after each registration.
	callback *callback
}
// addProcessor creates a processor of the given operation type that is bound
// to c, appends it to c.processors and returns it for further configuration.
func (c *callback) addProcessor(typ string) *callbackProcessor {
	created := new(callbackProcessor)
	created.typ, created.callback = typ, c
	c.processors = append(c.processors, created)
	return created
}
// clone returns a new callback that starts out with the same hook lists and
// processors as c. The slices are shared with c, not copied element by
// element.
//
// rowQueries must be carried over as well: otherwise the clone would have no
// row_query hooks until the next sort() rebuilt the list from processors.
func (c *callback) clone() *callback {
	return &callback{
		creates:    c.creates,
		updates:    c.updates,
		deletes:    c.deletes,
		queries:    c.queries,
		rowQueries: c.rowQueries,
		processors: c.processors,
	}
}
// Create starts a new registration of type "create" on c and returns its
// processor for further configuration.
func (c *callback) Create() *callbackProcessor {
	const typ = "create"
	return c.addProcessor(typ)
}
// Update starts a new registration of type "update" on c and returns its
// processor for further configuration.
func (c *callback) Update() *callbackProcessor {
	const typ = "update"
	return c.addProcessor(typ)
}
// Delete starts a new registration of type "delete" on c and returns its
// processor for further configuration.
func (c *callback) Delete() *callbackProcessor {
	const typ = "delete"
	return c.addProcessor(typ)
}
// Query starts a new registration of type "query" on c and returns its
// processor for further configuration.
func (c *callback) Query() *callbackProcessor {
	const typ = "query"
	return c.addProcessor(typ)
}
// RowQuery starts a new registration of type "row_query" on c and returns its
// processor for further configuration.
func (c *callback) RowQuery() *callbackProcessor {
	const typ = "row_query"
	return c.addProcessor(typ)
}
// Before records that the hook being registered should be ordered ahead of
// the hook called name. It returns cp so that calls can be chained.
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
	cp.before = name

	return cp
}
// After records that the hook being registered should be ordered behind the
// hook called name. It returns cp so that calls can be chained.
func (cp *callbackProcessor) After(name string) *callbackProcessor {
	cp.after = name

	return cp
}
// Register stores fc under the given name on this processor and re-sorts the
// owning callback so the hook lists reflect the new registration.
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
	cp.name, cp.processor = name, &fc
	cp.callback.sort()
}
// Remove marks the hook called name as removed, prints an informational line
// with the location reported by fileWithLineNum, and re-sorts the owning
// callback.
func (cp *callbackProcessor) Remove(name string) {
	location := fileWithLineNum()
	fmt.Printf("[info] removing callback `%v` from %v\n", name, location)

	cp.name, cp.remove = name, true
	cp.callback.sort()
}
// Replace registers fc under the given name and flags the registration as a
// replacement, so a duplicate name does not trigger the duplicate warning in
// sortProcessors. It prints an informational line with the location reported
// by fileWithLineNum and re-sorts the owning callback.
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
	location := fileWithLineNum()
	fmt.Printf("[info] replacing callback `%v` from %v\n", name, location)

	cp.name, cp.processor, cp.replace = name, &fc, true
	cp.callback.sort()
}
// getRIndex returns the index of the last element of strs equal to str, or -1
// when str does not occur in strs.
func getRIndex(strs []string, str string) int {
	// Walk backwards until a match is found; running off the front leaves
	// pos at -1, which is exactly the "not found" result.
	pos := len(strs) - 1
	for pos >= 0 && strs[pos] != str {
		pos--
	}
	return pos
}
// sortProcessors orders the given processors according to their before/after
// constraints and returns the corresponding hook functions in that order.
// Processors flagged with remove are left out of the result. When several
// processors share a name, the one registered last is the one whose function
// and remove flag are used, because lookups go through getRIndex.
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callbackProcessor)
// names mirrors cps index-for-index; sortedNames is the resulting order.
var names, sortedNames = []string{}, []string{}
for _, cp := range cps {
// A name that was already seen is only acceptable for Replace/Remove
// registrations; anything else gets a duplicate warning.
if index := getRIndex(names, cp.name); index > -1 {
if !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
}
names = append(names, cp.name)
}
// sortCallbackProcessor places c.name into sortedNames, recursing into the
// processors it refers to via before/after when those are not placed yet.
sortCallbackProcessor = func(c *callbackProcessor) {
// Already placed: nothing to do.
if getRIndex(sortedNames, c.name) > -1 {
return
}
if len(c.before) > 0 {
if index := getRIndex(sortedNames, c.before); index > -1 {
// Target already placed: insert c directly in front of it.
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(names, c.before); index > -1 {
// Target registered but not placed yet: place c first, then the target.
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
} else {
// Target unknown: just append c.
sortedNames = append(sortedNames, c.name)
}
}
if len(c.after) > 0 {
if index := getRIndex(sortedNames, c.after); index > -1 {
// Target already placed: insert c directly behind it.
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(names, c.after); index > -1 {
// Target registered but not placed yet: place the target first. If
// it has no before constraint of its own, it is given one pointing
// at c (this mutates the processor).
cp := cps[index]
if len(cp.before) == 0 {
cp.before = c.name
}
sortCallbackProcessor(cp)
} else {
// Target unknown: just append c.
sortedNames = append(sortedNames, c.name)
}
}
// Make sure c ends up in the list even when no branch above placed it.
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var funcs = []*func(scope *Scope){}
var sortedFuncs = []*func(scope *Scope){}
// Resolve each sorted name to the last processor registered under it and
// keep its function unless that processor is a removal.
for _, name := range sortedNames {
index := getRIndex(names, name)
if !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
// Collect processors whose name never made it into sortedNames.
// NOTE(review): every processor passed to sortCallbackProcessor ends up in
// sortedNames, so this looks like a defensive fallback — confirm.
for _, cp := range cps {
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
if !cp.remove {
funcs = append(funcs, cp.processor)
}
}
}
return append(sortedFuncs, funcs...)
}
// sort rebuilds the per-operation hook lists of c from c.processors. The
// processors are grouped by their typ, keeping registration order within each
// group, and every group is ordered by sortProcessors.
func (c *callback) sort() {
	grouped := make(map[string][]*callbackProcessor, 5)
	for _, p := range c.processors {
		grouped[p.typ] = append(grouped[p.typ], p)
	}

	// The groups are processed in a fixed order so that any warnings
	// printed by sortProcessors appear in a stable sequence.
	c.creates = sortProcessors(grouped["create"])
	c.updates = sortProcessors(grouped["update"])
	c.deletes = sortProcessors(grouped["delete"])
	c.queries = sortProcessors(grouped["query"])
	c.rowQueries = sortProcessors(grouped["row_query"])
}
// ForceReload calls First(scope.Value) on a new DB handle obtained from the
// scope when the "gorm:force_reload" instance setting is present. Without
// that setting it does nothing.
func ForceReload(scope *Scope) {
	_, requested := scope.InstanceGet("gorm:force_reload")
	if !requested {
		return
	}
	scope.DB().New().First(scope.Value)
}
// escapeIfNeeded decides how a default value is rendered. After trimming
// surrounding whitespace the value is returned verbatim when it is already a
// single-quoted literal or ends in ")" (an SQL expression such as now() or
// user_defined_function(args)). It is returned lower-cased when it is NULL or
// starts with "current_" (keywords that must not be quoted). Anything else is
// handed to scope.AddToVars, whose result is returned.
func escapeIfNeeded(scope *Scope, value string) string {
	trimmed := strings.TrimSpace(value)

	quotedLiteral := strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")
	sqlExpression := strings.HasSuffix(trimmed, ")")
	switch {
	case quotedLiteral, sqlExpression:
		return trimmed
	}

	switch lowered := strings.ToLower(trimmed); {
	case lowered == "null", strings.HasPrefix(lowered, "current_"):
		return lowered
	}

	// Plain values such as false or some text go through AddToVars.
	return scope.AddToVars(trimmed)
}
// handleDefaultValue returns the SQL fragment to use for field.
//
// A non-blank field is always bound through scope.AddToVars. For a blank
// field the DEFAULT setting of the "sql" struct tag is consulted:
//
//   - signed, unsigned and bool kinds: when the default parses as that kind,
//     the field's current value is rendered if it differs from the default,
//     otherwise the default text itself; both go through escapeIfNeeded. When
//     the default does not parse, the field value is bound via AddToVars.
//   - every other kind: the default text goes through escapeIfNeeded.
func handleDefaultValue(scope *Scope, field *Field) string {
	if !field.IsBlank {
		return scope.AddToVars(field.Field.Interface())
	}

	defaultValue := strings.TrimSpace(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"])
	switch field.Field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if numericValue, err := strconv.ParseInt(defaultValue, 10, 64); err == nil {
			if current := field.Field.Int(); numericValue != current {
				return escapeIfNeeded(scope, strconv.FormatInt(current, 10))
			}
			return escapeIfNeeded(scope, defaultValue)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if numericValue, err := strconv.ParseUint(defaultValue, 10, 64); err == nil {
			// Uint() must be used here: calling Int() on an unsigned
			// reflect.Value panics. The value is escaped exactly once so
			// the placeholder is not bound as a second variable.
			if current := field.Field.Uint(); numericValue != current {
				return escapeIfNeeded(scope, strconv.FormatUint(current, 10))
			}
			return escapeIfNeeded(scope, defaultValue)
		}
	case reflect.Bool:
		if boolValue, err := strconv.ParseBool(defaultValue); err == nil {
			if current := field.Field.Bool(); boolValue != current {
				return escapeIfNeeded(scope, strconv.FormatBool(current))
			}
			return escapeIfNeeded(scope, defaultValue)
		}
	default:
		// Strings and all remaining kinds use the default text as is.
		return escapeIfNeeded(scope, defaultValue)
	}

	// The default could not be parsed for the field's kind: bind the
	// field's own value.
	return scope.AddToVars(field.Field.Interface())
}
// DefaultCallback is the package-level callback container. It starts out with
// an empty, non-nil processor list.
var DefaultCallback = &callback{
	processors: make([]*callbackProcessor, 0),
}