simplify func
This commit is contained in:
		
							parent
							
								
									1e13fd7543
								
							
						
					
					
						commit
						d4e73a0980
					
				| @ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| @ -239,7 +240,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||
| 		var ( | ||||
| 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | ||||
| 			_, updateTrackTime        = stmt.Get("gorm:update_track_time") | ||||
| 			isZero                    bool | ||||
| 		) | ||||
| 		stmt.Settings.Delete("gorm:update_track_time") | ||||
| 
 | ||||
| @ -255,91 +255,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||
| 
 | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			rValLen := stmt.ReflectValue.Len() | ||||
| 			if rValLen == 0 { | ||||
| 				stmt.AddError(gorm.ErrEmptySlice) | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			stmt.SQL.Grow(rValLen * 18) | ||||
| 			stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) | ||||
| 			values.Values = make([][]interface{}, rValLen) | ||||
| 
 | ||||
| 			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} | ||||
| 			for i := 0; i < rValLen; i++ { | ||||
| 				rv := reflect.Indirect(stmt.ReflectValue.Index(i)) | ||||
| 				if !rv.IsValid() { | ||||
| 					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				values.Values[i] = make([]interface{}, len(values.Columns)) | ||||
| 				for idx, column := range values.Columns { | ||||
| 					field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 					if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { | ||||
| 						if field.DefaultValueInterface != nil { | ||||
| 							values.Values[i][idx] = field.DefaultValueInterface | ||||
| 							stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) | ||||
| 						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 							stmt.AddError(field.Set(stmt.Context, rv, curTime)) | ||||
| 							values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) | ||||
| 						} | ||||
| 					} else if field.AutoUpdateTime > 0 && updateTrackTime { | ||||
| 						stmt.AddError(field.Set(stmt.Context, rv, curTime)) | ||||
| 						values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 						if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { | ||||
| 							if len(defaultValueFieldsHavingValue[field]) == 0 { | ||||
| 								defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) | ||||
| 							} | ||||
| 							defaultValueFieldsHavingValue[field][i] = rvOfvalue | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 				if vs, ok := defaultValueFieldsHavingValue[field]; ok { | ||||
| 					values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | ||||
| 					for idx := range values.Values { | ||||
| 						if vs[idx] == nil { | ||||
| 							values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) | ||||
| 						} else { | ||||
| 							values.Values[idx] = append(values.Values[idx], vs[idx]) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			processSliceOrArray(&values, stmt, selectColumns, restricted, updateTrackTime, curTime) | ||||
| 		case reflect.Struct: | ||||
| 			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} | ||||
| 			for idx, column := range values.Columns { | ||||
| 				field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 				if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { | ||||
| 					if field.DefaultValueInterface != nil { | ||||
| 						values.Values[0][idx] = field.DefaultValueInterface | ||||
| 						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) | ||||
| 					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) | ||||
| 						values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) | ||||
| 					} | ||||
| 				} else if field.AutoUpdateTime > 0 && updateTrackTime { | ||||
| 					stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) | ||||
| 					values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { | ||||
| 					if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { | ||||
| 						values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | ||||
| 						values.Values[0] = append(values.Values[0], rvOfvalue) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			processStruct(&values, stmt, selectColumns, restricted, updateTrackTime, curTime) | ||||
| 		default: | ||||
| 			stmt.AddError(gorm.ErrInvalidData) | ||||
| 		} | ||||
| @ -394,3 +312,80 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||
| 
 | ||||
| 	return values | ||||
| } | ||||
| func processSliceOrArray(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { | ||||
| 	rValLen := stmt.ReflectValue.Len() | ||||
| 	if rValLen == 0 { | ||||
| 		stmt.AddError(gorm.ErrEmptySlice) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	stmt.SQL.Grow(rValLen * 18) | ||||
| 	stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) | ||||
| 	values.Values = make([][]interface{}, rValLen) | ||||
| 
 | ||||
| 	defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} | ||||
| 	for i := 0; i < rValLen; i++ { | ||||
| 		rv := reflect.Indirect(stmt.ReflectValue.Index(i)) | ||||
| 		if !rv.IsValid() { | ||||
| 			stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		processRowValues(values, stmt, i, rv, selectColumns, restricted, updateTrackTime, curTime) | ||||
| 	} | ||||
| 
 | ||||
| 	processDefaultValues(values, stmt, defaultValueFieldsHavingValue) | ||||
| } | ||||
| 
 | ||||
| func processStruct(values *clause.Values, stmt *gorm.Statement, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { | ||||
| 	values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} | ||||
| 	rv := reflect.Indirect(stmt.ReflectValue) | ||||
| 	processRowValues(values, stmt, 0, rv, selectColumns, restricted, updateTrackTime, curTime) | ||||
| 	processDefaultValues(values, stmt, nil) | ||||
| } | ||||
| 
 | ||||
| func processRowValues(values *clause.Values, stmt *gorm.Statement, rowIndex int, rv reflect.Value, selectColumns map[string]bool, restricted bool, updateTrackTime bool, curTime time.Time) { | ||||
| 	var isZero bool | ||||
| 	for idx, column := range values.Columns { | ||||
| 		field := stmt.Schema.FieldsByDBName[column.Name] | ||||
| 		if values.Values[rowIndex][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { | ||||
| 			setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime) | ||||
| 		} else if field.AutoUpdateTime > 0 && updateTrackTime { | ||||
| 			setDefaultValueOrAutoTime(values, stmt, field, rowIndex, idx, rv, curTime) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { | ||||
| 			if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { | ||||
| 				values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | ||||
| 				values.Values[rowIndex] = append(values.Values[rowIndex], rvOfvalue) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func processDefaultValues(values *clause.Values, stmt *gorm.Statement, defaultValueFieldsHavingValue map[*schema.Field][]interface{}) { | ||||
| 	for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||
| 		if vs, ok := defaultValueFieldsHavingValue[field]; ok { | ||||
| 			values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | ||||
| 			for idx := range values.Values { | ||||
| 				if vs[idx] == nil { | ||||
| 					values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) | ||||
| 				} else { | ||||
| 					values.Values[idx] = append(values.Values[idx], vs[idx]) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setDefaultValueOrAutoTime(values *clause.Values, stmt *gorm.Statement, field *schema.Field, rowIndex, idx int, rv reflect.Value, curTime time.Time) { | ||||
| 	if field.DefaultValueInterface != nil { | ||||
| 		values.Values[rowIndex][idx] = field.DefaultValueInterface | ||||
| 		stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) | ||||
| 	} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { | ||||
| 		stmt.AddError(field.Set(stmt.Context, rv, curTime)) | ||||
| 		values.Values[rowIndex][idx], _ = field.ValueOf(stmt.Context, rv) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 hanfulin
						hanfulin