
use golangci/golangci-lint-action instead of reviewdog/action-golangci-lint as the second was not reporting any failures even if there was some. Report code coverage with codecov/codecov-action I have set some flags per dialect and go version Several linters has been fixed, some disabled so the build can pass
712 lines
19 KiB
Go
712 lines
19 KiB
Go
package gorm
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/schema"
|
|
"gorm.io/gorm/utils"
|
|
)
|
|
|
|
// Statement statement
|
|
type Statement struct {
|
|
*DB
|
|
TableExpr *clause.Expr
|
|
Table string
|
|
Model interface{}
|
|
Unscoped bool
|
|
Dest interface{}
|
|
ReflectValue reflect.Value
|
|
Clauses map[string]clause.Clause
|
|
BuildClauses []string
|
|
Distinct bool
|
|
Selects []string // selected columns
|
|
Omits []string // omit columns
|
|
Joins []join
|
|
Preloads map[string][]interface{}
|
|
Settings sync.Map
|
|
ConnPool ConnPool
|
|
Schema *schema.Schema
|
|
Context context.Context
|
|
RaiseErrorOnNotFound bool
|
|
SkipHooks bool
|
|
SQL strings.Builder
|
|
Vars []interface{}
|
|
CurDestIndex int
|
|
attrs []interface{}
|
|
assigns []interface{}
|
|
scopes []func(*DB) *DB
|
|
}
|
|
|
|
type join struct {
|
|
Name string
|
|
Conds []interface{}
|
|
On *clause.Where
|
|
}
|
|
|
|
// StatementModifier statement modifier interface
|
|
type StatementModifier interface {
|
|
ModifyStatement(*Statement)
|
|
}
|
|
|
|
// WriteString write string
|
|
func (stmt *Statement) WriteString(str string) (int, error) {
|
|
return stmt.SQL.WriteString(str)
|
|
}
|
|
|
|
// WriteByte write byte
|
|
func (stmt *Statement) WriteByte(c byte) error {
|
|
return stmt.SQL.WriteByte(c)
|
|
}
|
|
|
|
// WriteQuoted write quoted value
|
|
func (stmt *Statement) WriteQuoted(value interface{}) {
|
|
stmt.QuoteTo(&stmt.SQL, value)
|
|
}
|
|
|
|
// QuoteTo write quoted value to writer
|
|
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|
write := func(raw bool, str string) {
|
|
if raw {
|
|
writer.WriteString(str)
|
|
} else {
|
|
stmt.DB.Dialector.QuoteTo(writer, str)
|
|
}
|
|
}
|
|
|
|
switch v := field.(type) {
|
|
case clause.Table:
|
|
if v.Name == clause.CurrentTable {
|
|
if stmt.TableExpr != nil {
|
|
stmt.TableExpr.Build(stmt)
|
|
} else {
|
|
write(v.Raw, stmt.Table)
|
|
}
|
|
} else {
|
|
write(v.Raw, v.Name)
|
|
}
|
|
|
|
if v.Alias != "" {
|
|
writer.WriteByte(' ')
|
|
write(v.Raw, v.Alias)
|
|
}
|
|
case clause.Column:
|
|
if v.Table != "" {
|
|
if v.Table == clause.CurrentTable {
|
|
write(v.Raw, stmt.Table)
|
|
} else {
|
|
write(v.Raw, v.Table)
|
|
}
|
|
writer.WriteByte('.')
|
|
}
|
|
|
|
if v.Name == clause.PrimaryKey {
|
|
if stmt.Schema == nil {
|
|
stmt.DB.AddError(ErrModelValueRequired)
|
|
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
|
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
|
} else if len(stmt.Schema.DBNames) > 0 {
|
|
write(v.Raw, stmt.Schema.DBNames[0])
|
|
}
|
|
} else {
|
|
write(v.Raw, v.Name)
|
|
}
|
|
|
|
if v.Alias != "" {
|
|
writer.WriteString(" AS ")
|
|
write(v.Raw, v.Alias)
|
|
}
|
|
case []clause.Column:
|
|
writer.WriteByte('(')
|
|
for idx, d := range v {
|
|
if idx > 0 {
|
|
writer.WriteString(",")
|
|
}
|
|
stmt.QuoteTo(writer, d)
|
|
}
|
|
writer.WriteByte(')')
|
|
case clause.Expr:
|
|
v.Build(stmt)
|
|
case string:
|
|
stmt.DB.Dialector.QuoteTo(writer, v)
|
|
case []string:
|
|
writer.WriteByte('(')
|
|
for idx, d := range v {
|
|
if idx > 0 {
|
|
writer.WriteString(",")
|
|
}
|
|
stmt.DB.Dialector.QuoteTo(writer, d)
|
|
}
|
|
writer.WriteByte(')')
|
|
default:
|
|
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
|
}
|
|
}
|
|
|
|
// Quote returns quoted value
|
|
func (stmt *Statement) Quote(field interface{}) string {
|
|
var builder strings.Builder
|
|
stmt.QuoteTo(&builder, field)
|
|
return builder.String()
|
|
}
|
|
|
|
// AddVar add var
|
|
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|
for idx, v := range vars {
|
|
if idx > 0 {
|
|
writer.WriteByte(',')
|
|
}
|
|
|
|
switch v := v.(type) {
|
|
case sql.NamedArg:
|
|
stmt.Vars = append(stmt.Vars, v.Value)
|
|
case clause.Column, clause.Table:
|
|
stmt.QuoteTo(writer, v)
|
|
case Valuer:
|
|
reflectValue := reflect.ValueOf(v)
|
|
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
|
|
stmt.AddVar(writer, nil)
|
|
} else {
|
|
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
|
}
|
|
case clause.Expr:
|
|
v.Build(stmt)
|
|
case *clause.Expr:
|
|
v.Build(stmt)
|
|
case driver.Valuer:
|
|
stmt.Vars = append(stmt.Vars, v)
|
|
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
|
case []byte:
|
|
stmt.Vars = append(stmt.Vars, v)
|
|
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
|
case []interface{}:
|
|
if len(v) > 0 {
|
|
writer.WriteByte('(')
|
|
stmt.AddVar(writer, v...)
|
|
writer.WriteByte(')')
|
|
} else {
|
|
writer.WriteString("(NULL)")
|
|
}
|
|
case *DB:
|
|
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
|
if v.Statement.SQL.Len() > 0 {
|
|
var (
|
|
vars = subdb.Statement.Vars
|
|
sql = v.Statement.SQL.String()
|
|
)
|
|
|
|
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
|
for _, vv := range vars {
|
|
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
|
bindvar := strings.Builder{}
|
|
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
|
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
|
}
|
|
|
|
subdb.Statement.SQL.Reset()
|
|
subdb.Statement.Vars = stmt.Vars
|
|
if strings.Contains(sql, "@") {
|
|
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
|
} else {
|
|
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
|
}
|
|
} else {
|
|
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
|
subdb.callbacks.Query().Execute(subdb)
|
|
}
|
|
|
|
writer.WriteString(subdb.Statement.SQL.String())
|
|
stmt.Vars = subdb.Statement.Vars
|
|
default:
|
|
switch rv := reflect.ValueOf(v); rv.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
if rv.Len() == 0 {
|
|
writer.WriteString("(NULL)")
|
|
} else {
|
|
writer.WriteByte('(')
|
|
for i := 0; i < rv.Len(); i++ {
|
|
if i > 0 {
|
|
writer.WriteByte(',')
|
|
}
|
|
stmt.AddVar(writer, rv.Index(i).Interface())
|
|
}
|
|
writer.WriteByte(')')
|
|
}
|
|
default:
|
|
stmt.Vars = append(stmt.Vars, v)
|
|
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// AddClause add clause
|
|
func (stmt *Statement) AddClause(v clause.Interface) {
|
|
if optimizer, ok := v.(StatementModifier); ok {
|
|
optimizer.ModifyStatement(stmt)
|
|
} else {
|
|
name := v.Name()
|
|
c := stmt.Clauses[name]
|
|
c.Name = name
|
|
v.MergeClause(&c)
|
|
stmt.Clauses[name] = c
|
|
}
|
|
}
|
|
|
|
// AddClauseIfNotExists add clause if not exists
|
|
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
|
if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
|
|
stmt.AddClause(v)
|
|
}
|
|
}
|
|
|
|
// BuildCondition build condition
|
|
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
|
|
if s, ok := query.(string); ok {
|
|
// if it is a number, then treats it as primary key
|
|
if _, err := strconv.Atoi(s); err != nil {
|
|
if s == "" && len(args) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
|
// looks like a where condition
|
|
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
|
}
|
|
|
|
if len(args) > 0 && strings.Contains(s, "@") {
|
|
// looks like a named query
|
|
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
|
}
|
|
|
|
if strings.Contains(strings.TrimSpace(s), " ") {
|
|
// looks like a where condition
|
|
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
|
}
|
|
|
|
if len(args) == 1 {
|
|
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
|
}
|
|
}
|
|
}
|
|
|
|
conds := make([]clause.Expression, 0, 4)
|
|
args = append([]interface{}{query}, args...)
|
|
for idx, arg := range args {
|
|
if valuer, ok := arg.(driver.Valuer); ok {
|
|
arg, _ = valuer.Value()
|
|
}
|
|
|
|
switch v := arg.(type) {
|
|
case clause.Expression:
|
|
conds = append(conds, v)
|
|
case *DB:
|
|
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
|
if where, ok := cs.Expression.(clause.Where); ok {
|
|
if len(where.Exprs) == 1 {
|
|
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
|
where.Exprs[0] = clause.AndConditions(orConds)
|
|
}
|
|
}
|
|
conds = append(conds, clause.And(where.Exprs...))
|
|
} else if cs.Expression != nil {
|
|
conds = append(conds, cs.Expression)
|
|
}
|
|
}
|
|
case map[interface{}]interface{}:
|
|
for i, j := range v {
|
|
conds = append(conds, clause.Eq{Column: i, Value: j})
|
|
}
|
|
case map[string]string:
|
|
keys := make([]string, 0, len(v))
|
|
for i := range v {
|
|
keys = append(keys, i)
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
for _, key := range keys {
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
}
|
|
case map[string]interface{}:
|
|
keys := make([]string, 0, len(v))
|
|
for i := range v {
|
|
keys = append(keys, i)
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
for _, key := range keys {
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
|
switch reflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
if _, ok := v[key].(driver.Valuer); ok {
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
} else if _, ok := v[key].(Valuer); ok {
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
} else {
|
|
// optimize reflect value length
|
|
valueLen := reflectValue.Len()
|
|
values := make([]interface{}, valueLen)
|
|
for i := 0; i < valueLen; i++ {
|
|
values[i] = reflectValue.Index(i).Interface()
|
|
}
|
|
|
|
conds = append(conds, clause.IN{Column: key, Values: values})
|
|
}
|
|
default:
|
|
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
|
}
|
|
}
|
|
default:
|
|
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
|
for reflectValue.Kind() == reflect.Ptr {
|
|
reflectValue = reflectValue.Elem()
|
|
}
|
|
|
|
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
|
selectedColumns := map[string]bool{}
|
|
if idx == 0 {
|
|
for _, v := range args[1:] {
|
|
if vs, ok := v.(string); ok {
|
|
selectedColumns[vs] = true
|
|
}
|
|
}
|
|
}
|
|
restricted := len(selectedColumns) != 0
|
|
|
|
switch reflectValue.Kind() {
|
|
case reflect.Struct:
|
|
for _, field := range s.Fields {
|
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
|
if selected || (!restricted && field.Readable) {
|
|
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
|
if field.DBName != "" {
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
|
} else if field.DataType != "" {
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case reflect.Slice, reflect.Array:
|
|
for i := 0; i < reflectValue.Len(); i++ {
|
|
for _, field := range s.Fields {
|
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
|
if selected || (!restricted && field.Readable) {
|
|
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
|
if field.DBName != "" {
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
|
} else if field.DataType != "" {
|
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if restricted {
|
|
break
|
|
}
|
|
} else if !reflectValue.IsValid() {
|
|
stmt.AddError(ErrInvalidData)
|
|
} else if len(conds) == 0 {
|
|
if len(args) == 1 {
|
|
switch reflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
// optimize reflect value length
|
|
valueLen := reflectValue.Len()
|
|
values := make([]interface{}, valueLen)
|
|
for i := 0; i < valueLen; i++ {
|
|
values[i] = reflectValue.Index(i).Interface()
|
|
}
|
|
|
|
if len(values) > 0 {
|
|
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
|
}
|
|
return conds
|
|
}
|
|
}
|
|
|
|
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
|
}
|
|
}
|
|
}
|
|
|
|
return conds
|
|
}
|
|
|
|
// Build build sql with clauses names
|
|
func (stmt *Statement) Build(clauses ...string) {
|
|
var firstClauseWritten bool
|
|
|
|
for _, name := range clauses {
|
|
if c, ok := stmt.Clauses[name]; ok {
|
|
if firstClauseWritten {
|
|
stmt.WriteByte(' ')
|
|
}
|
|
|
|
firstClauseWritten = true
|
|
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
|
|
b(c, stmt)
|
|
} else {
|
|
c.Build(stmt)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
|
return stmt.ParseWithSpecialTableName(value, "")
|
|
}
|
|
|
|
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
|
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
|
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
|
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
|
stmt.Table = tables[1]
|
|
return
|
|
}
|
|
|
|
stmt.Table = stmt.Schema.Table
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (stmt *Statement) clone() *Statement {
|
|
newStmt := &Statement{
|
|
TableExpr: stmt.TableExpr,
|
|
Table: stmt.Table,
|
|
Model: stmt.Model,
|
|
Unscoped: stmt.Unscoped,
|
|
Dest: stmt.Dest,
|
|
ReflectValue: stmt.ReflectValue,
|
|
Clauses: map[string]clause.Clause{},
|
|
Distinct: stmt.Distinct,
|
|
Selects: stmt.Selects,
|
|
Omits: stmt.Omits,
|
|
Preloads: map[string][]interface{}{},
|
|
ConnPool: stmt.ConnPool,
|
|
Schema: stmt.Schema,
|
|
Context: stmt.Context,
|
|
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
|
SkipHooks: stmt.SkipHooks,
|
|
}
|
|
|
|
if stmt.SQL.Len() > 0 {
|
|
newStmt.SQL.WriteString(stmt.SQL.String())
|
|
newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
|
|
newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
|
|
}
|
|
|
|
for k, c := range stmt.Clauses {
|
|
newStmt.Clauses[k] = c
|
|
}
|
|
|
|
for k, p := range stmt.Preloads {
|
|
newStmt.Preloads[k] = p
|
|
}
|
|
|
|
if len(stmt.Joins) > 0 {
|
|
newStmt.Joins = make([]join, len(stmt.Joins))
|
|
copy(newStmt.Joins, stmt.Joins)
|
|
}
|
|
|
|
if len(stmt.scopes) > 0 {
|
|
newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
|
|
copy(newStmt.scopes, stmt.scopes)
|
|
}
|
|
|
|
stmt.Settings.Range(func(k, v interface{}) bool {
|
|
newStmt.Settings.Store(k, v)
|
|
return true
|
|
})
|
|
|
|
return newStmt
|
|
}
|
|
|
|
// SetColumn set column's value
|
|
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
|
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
|
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
|
v[name] = value
|
|
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
|
|
for _, m := range v {
|
|
m[name] = value
|
|
}
|
|
} else if stmt.Schema != nil {
|
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
|
destValue := reflect.ValueOf(stmt.Dest)
|
|
for destValue.Kind() == reflect.Ptr {
|
|
destValue = destValue.Elem()
|
|
}
|
|
|
|
if stmt.ReflectValue != destValue {
|
|
if !destValue.CanAddr() {
|
|
destValueCanAddr := reflect.New(destValue.Type())
|
|
destValueCanAddr.Elem().Set(destValue)
|
|
stmt.Dest = destValueCanAddr.Interface()
|
|
destValue = destValueCanAddr.Elem()
|
|
}
|
|
|
|
switch destValue.Kind() {
|
|
case reflect.Struct:
|
|
field.Set(destValue, value)
|
|
default:
|
|
stmt.AddError(ErrInvalidData)
|
|
}
|
|
}
|
|
|
|
switch stmt.ReflectValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
if len(fromCallbacks) > 0 {
|
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
|
field.Set(stmt.ReflectValue.Index(i), value)
|
|
}
|
|
} else {
|
|
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
|
}
|
|
case reflect.Struct:
|
|
if !stmt.ReflectValue.CanAddr() {
|
|
stmt.AddError(ErrInvalidValue)
|
|
return
|
|
}
|
|
|
|
field.Set(stmt.ReflectValue, value)
|
|
}
|
|
} else {
|
|
stmt.AddError(ErrInvalidField)
|
|
}
|
|
} else {
|
|
stmt.AddError(ErrInvalidField)
|
|
}
|
|
}
|
|
|
|
// Changed check model changed or not when updating
|
|
func (stmt *Statement) Changed(fields ...string) bool {
|
|
modelValue := stmt.ReflectValue
|
|
switch modelValue.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
|
}
|
|
|
|
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
|
changed := func(field *schema.Field) bool {
|
|
fieldValue, _ := field.ValueOf(modelValue)
|
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
|
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
|
if fv, ok := v[field.Name]; ok {
|
|
return !utils.AssertEqual(fv, fieldValue)
|
|
} else if fv, ok := v[field.DBName]; ok {
|
|
return !utils.AssertEqual(fv, fieldValue)
|
|
}
|
|
} else {
|
|
destValue := reflect.ValueOf(stmt.Dest)
|
|
for destValue.Kind() == reflect.Ptr {
|
|
destValue = destValue.Elem()
|
|
}
|
|
|
|
changedValue, zero := field.ValueOf(destValue)
|
|
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
if len(fields) == 0 {
|
|
for _, field := range stmt.Schema.FieldsByDBName {
|
|
if changed(field) {
|
|
return true
|
|
}
|
|
}
|
|
} else {
|
|
for _, name := range fields {
|
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
|
if changed(field) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`)
|
|
|
|
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
|
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
|
results := map[string]bool{}
|
|
notRestricted := false
|
|
|
|
// select columns
|
|
for _, column := range stmt.Selects {
|
|
if stmt.Schema == nil {
|
|
results[column] = true
|
|
} else if column == "*" {
|
|
notRestricted = true
|
|
for _, dbName := range stmt.Schema.DBNames {
|
|
results[dbName] = true
|
|
}
|
|
} else if column == clause.Associations {
|
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
results[rel.Name] = true
|
|
}
|
|
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
|
results[field.DBName] = true
|
|
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
|
|
results[matches[1]] = true
|
|
} else {
|
|
results[column] = true
|
|
}
|
|
}
|
|
|
|
// omit columns
|
|
for _, omit := range stmt.Omits {
|
|
if stmt.Schema == nil {
|
|
results[omit] = false
|
|
} else if omit == "*" {
|
|
for _, dbName := range stmt.Schema.DBNames {
|
|
results[dbName] = false
|
|
}
|
|
} else if omit == clause.Associations {
|
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
|
results[rel.Name] = false
|
|
}
|
|
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
|
results[field.DBName] = false
|
|
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
|
|
results[matches[1]] = false
|
|
} else {
|
|
results[omit] = false
|
|
}
|
|
}
|
|
|
|
if stmt.Schema != nil {
|
|
for _, field := range stmt.Schema.FieldsByName {
|
|
name := field.DBName
|
|
if name == "" {
|
|
name = field.Name
|
|
}
|
|
|
|
if requireCreate && !field.Creatable {
|
|
results[name] = false
|
|
} else if requireUpdate && !field.Updatable {
|
|
results[name] = false
|
|
}
|
|
}
|
|
}
|
|
|
|
return results, !notRestricted && len(stmt.Selects) > 0
|
|
}
|