From d4e73a0980275f86950997be852ba9e8c9606ac9 Mon Sep 17 00:00:00 2001 From: hanfulin Date: Wed, 10 Apr 2024 19:54:07 +0800 Subject: [PATCH] simplify func --- callbacks/create.go | 165 +++++++++++++++++++++----------------------- 1 file changed, 80 insertions(+), 85 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8b7846b6..7feb560e 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -239,7 +240,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) _, updateTrackTime = stmt.Get("gorm:update_track_time") - isZero bool ) stmt.Settings.Delete("gorm:update_track_time") @@ -255,91 +255,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - rValLen := stmt.ReflectValue.Len() - if rValLen == 0 { - stmt.AddError(gorm.ErrEmptySlice) - return - } - - stmt.SQL.Grow(rValLen * 18) - stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) - values.Values = make([][]interface{}, rValLen) - - defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} - for i := 0; i < rValLen; i++ { - rv := reflect.Indirect(stmt.ReflectValue.Index(i)) - if !rv.IsValid() { - stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) - return - } - - values.Values[i] = make([]interface{}, len(values.Columns)) - for idx, column := range values.Columns { - field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { - if field.DefaultValueInterface != nil { - values.Values[i][idx] = field.DefaultValueInterface - stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) - } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - stmt.AddError(field.Set(stmt.Context, rv, curTime)) - values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) - } - } else if field.AutoUpdateTime > 0 && updateTrackTime { - stmt.AddError(field.Set(stmt.Context, rv, curTime)) - values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) - } - } - - for _, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { - if len(defaultValueFieldsHavingValue[field]) == 0 { - defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) - } - defaultValueFieldsHavingValue[field][i] = rvOfvalue - } - } - } - } - - for _, field := range stmt.Schema.FieldsWithDefaultDBValue { - if vs, ok := defaultValueFieldsHavingValue[field]; ok { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) - } - } - } - } + processSliceOrArray(&values, stmt, selectColumns, restricted, updateTrackTime, curTime) case reflect.Struct: - values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} - for idx, column := range values.Columns { - field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { - if field.DefaultValueInterface != nil { - values.Values[0][idx] = field.DefaultValueInterface - stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) - } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) - values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) - } - } else if field.AutoUpdateTime > 0 && updateTrackTime { - stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) - values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) - } - } - - for _, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { - if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - values.Values[0] = append(values.Values[0], rvOfvalue) - } - } - } + processStruct(&values, stmt, selectColumns, restricted, updateTrackTime, curTime) default: stmt.AddError(gorm.ErrInvalidData) } @@ -394,3 +312,80 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { return values } +func processSliceOrArray(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { + rValLen := stmt.ReflectValue.Len() + if rValLen == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + + processRowValues(values, stmt, i, rv, selectColumns, restricted, updateTrackTime, curTime) + } + + processDefaultValues(values, stmt, defaultValueFieldsHavingValue) +} + +func processStruct(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + rv := reflect.Indirect(stmt.ReflectValue) + processRowValues(values, stmt, 0, rv, selectColumns, restricted, updateTrackTime, curTime) + processDefaultValues(values, stmt, nil) +} + +func processRowValues(values *clause.Values, stmt *gorm.Statement, rowIndex int, rv reflect.Value, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { + var isZero bool + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[rowIndex][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { + setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime) + } else if field.AutoUpdateTime > 0 && updateTrackTime { + setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime) + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + values.Values[rowIndex] = append(values.Values[rowIndex], rvOfvalue) + } + } + } +} + +func processDefaultValues(values *clause.Values, stmt *gorm.Statement, defaultValueFieldsHavingValue map[*schema.Field][]interface{}) { + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } + } + } +} + +func setDefaultValueOrAutoTime(values *clause.Values, stmt *gorm.Statement, field *schema.Field, rowIndex, idx int, rv reflect.Value, curTime time.Time) { + if field.DefaultValueInterface != nil { + values.Values[rowIndex][idx] = field.DefaultValueInterface + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[rowIndex][idx], _ = field.ValueOf(stmt.Context, rv) + } +}