simplify func

This commit is contained in:
hanfulin 2024-04-10 19:54:07 +08:00
parent 1e13fd7543
commit d4e73a0980

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -239,7 +240,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
var ( var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
_, updateTrackTime = stmt.Get("gorm:update_track_time") _, updateTrackTime = stmt.Get("gorm:update_track_time")
isZero bool
) )
stmt.Settings.Delete("gorm:update_track_time") stmt.Settings.Delete("gorm:update_track_time")
@ -255,91 +255,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len() processSliceOrArray(&values, stmt, selectColumns, restricted, updateTrackTime, curTime)
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
}
defaultValueFieldsHavingValue[field][i] = rvOfvalue
}
}
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
case reflect.Struct: case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} processStruct(&values, stmt, selectColumns, restricted, updateTrackTime, curTime)
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
}
}
}
default: default:
stmt.AddError(gorm.ErrInvalidData) stmt.AddError(gorm.ErrInvalidData)
} }
@ -394,3 +312,80 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
return values return values
} }
func processSliceOrArray(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
rValLen := stmt.ReflectValue.Len()
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
processRowValues(values, stmt, i, rv, selectColumns, restricted, updateTrackTime, curTime)
}
processDefaultValues(values, stmt, defaultValueFieldsHavingValue)
}
func processStruct(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
rv := reflect.Indirect(stmt.ReflectValue)
processRowValues(values, stmt, 0, rv, selectColumns, restricted, updateTrackTime, curTime)
processDefaultValues(values, stmt, nil)
}
func processRowValues(values *clause.Values, stmt *gorm.Statement, rowIndex int, rv reflect.Value, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) {
var isZero bool
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[rowIndex][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime)
} else if field.AutoUpdateTime > 0 && updateTrackTime {
setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[rowIndex] = append(values.Values[rowIndex], rvOfvalue)
}
}
}
}
func processDefaultValues(values *clause.Values, stmt *gorm.Statement, defaultValueFieldsHavingValue map[*schema.Field][]interface{}) {
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
}
func setDefaultValueOrAutoTime(values *clause.Values, stmt *gorm.Statement, field *schema.Field, rowIndex, idx int, rv reflect.Value, curTime time.Time) {
if field.DefaultValueInterface != nil {
values.Values[rowIndex][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[rowIndex][idx], _ = field.ValueOf(stmt.Context, rv)
}
}