feat: go code style adjust and optimize code for callbacks package
This commit is contained in:
		
							parent
							
								
									5e64ac7de9
								
							
						
					
					
						commit
						7f6ca2ee20
					
				@ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
				switch db.Statement.ReflectValue.Kind() {
 | 
									switch db.Statement.ReflectValue.Kind() {
 | 
				
			||||||
				case reflect.Slice, reflect.Array:
 | 
									case reflect.Slice, reflect.Array:
 | 
				
			||||||
					var (
 | 
										var (
 | 
				
			||||||
						objs      = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
 | 
											rValLen   = db.Statement.ReflectValue.Len()
 | 
				
			||||||
 | 
											objs      = make([]reflect.Value, 0, rValLen)
 | 
				
			||||||
						fieldType = rel.Field.FieldType
 | 
											fieldType = rel.Field.FieldType
 | 
				
			||||||
						isPtr     = fieldType.Kind() == reflect.Ptr
 | 
											isPtr     = fieldType.Kind() == reflect.Ptr
 | 
				
			||||||
					)
 | 
										)
 | 
				
			||||||
@ -49,22 +50,21 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
										elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | 
				
			||||||
					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | 
										for i := 0; i < rValLen; i++ {
 | 
				
			||||||
						obj := db.Statement.ReflectValue.Index(i)
 | 
											obj := db.Statement.ReflectValue.Index(i)
 | 
				
			||||||
 | 
											if reflect.Indirect(obj).Kind() != reflect.Struct {
 | 
				
			||||||
						if reflect.Indirect(obj).Kind() == reflect.Struct {
 | 
					 | 
				
			||||||
							if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
 | 
					 | 
				
			||||||
								rv := rel.Field.ReflectValueOf(obj) // relation reflect value
 | 
					 | 
				
			||||||
								objs = append(objs, obj)
 | 
					 | 
				
			||||||
								if isPtr {
 | 
					 | 
				
			||||||
									elems = reflect.Append(elems, rv)
 | 
					 | 
				
			||||||
								} else {
 | 
					 | 
				
			||||||
									elems = reflect.Append(elems, rv.Addr())
 | 
					 | 
				
			||||||
								}
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
						} else {
 | 
					 | 
				
			||||||
							break
 | 
												break
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
											if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
 | 
				
			||||||
 | 
												rv := rel.Field.ReflectValueOf(obj) // relation reflect value
 | 
				
			||||||
 | 
												objs = append(objs, obj)
 | 
				
			||||||
 | 
												if isPtr {
 | 
				
			||||||
 | 
													elems = reflect.Append(elems, rv)
 | 
				
			||||||
 | 
												} else {
 | 
				
			||||||
 | 
													elems = reflect.Append(elems, rv.Addr())
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					if elems.Len() > 0 {
 | 
										if elems.Len() > 0 {
 | 
				
			||||||
 | 
				
			|||||||
@ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		switch stmt.ReflectValue.Kind() {
 | 
							switch stmt.ReflectValue.Kind() {
 | 
				
			||||||
		case reflect.Slice, reflect.Array:
 | 
							case reflect.Slice, reflect.Array:
 | 
				
			||||||
			stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
 | 
								rValLen := stmt.ReflectValue.Len()
 | 
				
			||||||
			values.Values = make([][]interface{}, stmt.ReflectValue.Len())
 | 
								stmt.SQL.Grow(rValLen * 18)
 | 
				
			||||||
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
 | 
								values.Values = make([][]interface{}, rValLen)
 | 
				
			||||||
			if stmt.ReflectValue.Len() == 0 {
 | 
								if rValLen == 0 {
 | 
				
			||||||
				stmt.AddError(gorm.ErrEmptySlice)
 | 
									stmt.AddError(gorm.ErrEmptySlice)
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for i := 0; i < stmt.ReflectValue.Len(); i++ {
 | 
								defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
 | 
				
			||||||
 | 
								for i := 0; i < rValLen; i++ {
 | 
				
			||||||
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 | 
									rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 | 
				
			||||||
				if !rv.IsValid() {
 | 
									if !rv.IsValid() {
 | 
				
			||||||
					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
 | 
										stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
 | 
				
			||||||
@ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
				for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
									for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
				
			||||||
					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
										if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
				
			||||||
						if v, isZero := field.ValueOf(rv); !isZero {
 | 
											if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
 | 
				
			||||||
							if len(defaultValueFieldsHavingValue[field]) == 0 {
 | 
												if len(defaultValueFieldsHavingValue[field]) == 0 {
 | 
				
			||||||
								defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
 | 
													defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
							defaultValueFieldsHavingValue[field][i] = v
 | 
												defaultValueFieldsHavingValue[field][i] = rvOfvalue
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
								for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
 | 
				
			||||||
				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
									if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
				
			||||||
					if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
 | 
										if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
 | 
				
			||||||
						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
											values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
 | 
				
			||||||
						values.Values[0] = append(values.Values[0], v)
 | 
											values.Values[0] = append(values.Values[0], rvOfvalue)
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
				
			|||||||
@ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !db.DryRun && db.Error == nil {
 | 
							if !db.DryRun && db.Error == nil {
 | 
				
			||||||
			if ok, mode := hasReturning(db, supportReturning); ok {
 | 
								ok, mode := hasReturning(db, supportReturning)
 | 
				
			||||||
				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
 | 
								if !ok {
 | 
				
			||||||
					gorm.Scan(rows, db, mode)
 | 
					 | 
				
			||||||
					rows.Close()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
									result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
				if db.AddError(err) == nil {
 | 
									if db.AddError(err) == nil {
 | 
				
			||||||
					db.RowsAffected, _ = result.RowsAffected()
 | 
										db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
 | 
				
			||||||
 | 
									gorm.Scan(rows, db, mode)
 | 
				
			||||||
 | 
									rows.Close()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
 | 
				
			|||||||
		fieldValues := make([]interface{}, len(joinForeignFields))
 | 
							fieldValues := make([]interface{}, len(joinForeignFields))
 | 
				
			||||||
		joinFieldValues := make([]interface{}, len(joinRelForeignFields))
 | 
							joinFieldValues := make([]interface{}, len(joinRelForeignFields))
 | 
				
			||||||
		for i := 0; i < joinResults.Len(); i++ {
 | 
							for i := 0; i < joinResults.Len(); i++ {
 | 
				
			||||||
 | 
								joinIndexValue := joinResults.Index(i)
 | 
				
			||||||
			for idx, field := range joinForeignFields {
 | 
								for idx, field := range joinForeignFields {
 | 
				
			||||||
				fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
 | 
									fieldValues[idx], _ = field.ValueOf(joinIndexValue)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for idx, field := range joinRelForeignFields {
 | 
								for idx, field := range joinRelForeignFields {
 | 
				
			||||||
				joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
 | 
									joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
 | 
								if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
 | 
				
			||||||
 | 
				
			|||||||
@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) {
 | 
				
			|||||||
		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
							result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			db.AddError(err)
 | 
								db.AddError(err)
 | 
				
			||||||
		} else {
 | 
								return
 | 
				
			||||||
			db.RowsAffected, _ = result.RowsAffected()
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							db.RowsAffected, _ = result.RowsAffected()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -7,16 +7,17 @@ import (
 | 
				
			|||||||
func RowQuery(db *gorm.DB) {
 | 
					func RowQuery(db *gorm.DB) {
 | 
				
			||||||
	if db.Error == nil {
 | 
						if db.Error == nil {
 | 
				
			||||||
		BuildQuerySQL(db)
 | 
							BuildQuerySQL(db)
 | 
				
			||||||
 | 
							if db.DryRun {
 | 
				
			||||||
		if !db.DryRun {
 | 
								return
 | 
				
			||||||
			if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
 | 
					 | 
				
			||||||
				db.Statement.Settings.Delete("rows")
 | 
					 | 
				
			||||||
				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			db.RowsAffected = -1
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
 | 
				
			||||||
 | 
								db.Statement.Settings.Delete("rows")
 | 
				
			||||||
 | 
								db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							db.RowsAffected = -1
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) {
 | 
				
			|||||||
func CommitOrRollbackTransaction(db *gorm.DB) {
 | 
					func CommitOrRollbackTransaction(db *gorm.DB) {
 | 
				
			||||||
	if !db.Config.SkipDefaultTransaction {
 | 
						if !db.Config.SkipDefaultTransaction {
 | 
				
			||||||
		if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
 | 
							if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
 | 
				
			||||||
			if db.Error == nil {
 | 
								if db.Error != nil {
 | 
				
			||||||
				db.Commit()
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				db.Rollback()
 | 
									db.Rollback()
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									db.Commit()
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			db.Statement.ConnPool = db.ConnPool
 | 
								db.Statement.ConnPool = db.ConnPool
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
				
			|||||||
		case reflect.Slice, reflect.Array:
 | 
							case reflect.Slice, reflect.Array:
 | 
				
			||||||
			if size := stmt.ReflectValue.Len(); size > 0 {
 | 
								if size := stmt.ReflectValue.Len(); size > 0 {
 | 
				
			||||||
				var primaryKeyExprs []clause.Expression
 | 
									var primaryKeyExprs []clause.Expression
 | 
				
			||||||
				for i := 0; i < stmt.ReflectValue.Len(); i++ {
 | 
									for i := 0; i < size; i++ {
 | 
				
			||||||
					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
 | 
										var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
 | 
				
			||||||
					var notZero bool
 | 
										var notZero bool
 | 
				
			||||||
					for idx, field := range stmt.Schema.PrimaryFields {
 | 
										for idx, field := range stmt.Schema.PrimaryFields {
 | 
				
			||||||
 | 
				
			|||||||
@ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Migrator) GetTables() (tableList []string, err error) {
 | 
					func (m Migrator) GetTables() (tableList []string, err error) {
 | 
				
			||||||
	return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error
 | 
						err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
 | 
				
			||||||
 | 
							Scan(&tableList).Error
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Migrator) CreateTable(values ...interface{}) error {
 | 
					func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										6
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								scan.go
									
									
									
									
									
								
							@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
 | 
				
			|||||||
type ScanMode uint8
 | 
					type ScanMode uint8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	ScanInitialized         ScanMode = 1 << 0
 | 
						ScanInitialized         ScanMode = 1 << iota // 1
 | 
				
			||||||
	ScanUpdate                       = 1 << 1
 | 
						ScanUpdate                                   // 2
 | 
				
			||||||
	ScanOnConflictDoNothing          = 1 << 2
 | 
						ScanOnConflictDoNothing                      // 4
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
					func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user