simplify func
This commit is contained in:
parent
1e13fd7543
commit
d4e73a0980
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -239,7 +240,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||
isZero bool
|
||||
)
|
||||
stmt.Settings.Delete("gorm:update_track_time")
|
||||
|
||||
@ -255,91 +255,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[i][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||
}
|
||||
defaultValueFieldsHavingValue[field][i] = rvOfvalue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
processSliceOrArray(&values, stmt, selectColumns, restricted, updateTrackTime, curTime)
|
||||
case reflect.Struct:
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
}
|
||||
}
|
||||
}
|
||||
processStruct(&values, stmt, selectColumns, restricted, updateTrackTime, curTime)
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
@ -394,3 +312,80 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
|
||||
return values
|
||||
}
|
||||
func processSliceOrArray(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
processRowValues(values, stmt, i, rv, selectColumns, restricted, updateTrackTime, curTime)
|
||||
}
|
||||
|
||||
processDefaultValues(values, stmt, defaultValueFieldsHavingValue)
|
||||
}
|
||||
|
||||
func processStruct(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
rv := reflect.Indirect(stmt.ReflectValue)
|
||||
processRowValues(values, stmt, 0, rv, selectColumns, restricted, updateTrackTime, curTime)
|
||||
processDefaultValues(values, stmt, nil)
|
||||
}
|
||||
|
||||
func processRowValues(values *clause.Values, stmt *gorm.Statement, rowIndex int, rv reflect.Value, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
|
||||
var isZero bool
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[rowIndex][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||
setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime)
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[rowIndex] = append(values.Values[rowIndex], rvOfvalue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processDefaultValues(values *clause.Values, stmt *gorm.Statement, defaultValueFieldsHavingValue map[*schema.Field][]interface{}) {
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setDefaultValueOrAutoTime(values *clause.Values, stmt *gorm.Statement, field *schema.Field, rowIndex, idx int, rv reflect.Value, curTime time.Time) {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[rowIndex][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[rowIndex][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user