feat: go code style adjust and optimize code for callbacks package (#4861)
* feat: go code style adjust and optimize code for callbacks package * Update scan.go
This commit is contained in:
		
							parent
							
								
									b8f33a42a4
								
							
						
					
					
						commit
						9d5f315b6d
					
				| @ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | |||||||
| 				switch db.Statement.ReflectValue.Kind() { | 				switch db.Statement.ReflectValue.Kind() { | ||||||
| 				case reflect.Slice, reflect.Array: | 				case reflect.Slice, reflect.Array: | ||||||
| 					var ( | 					var ( | ||||||
| 						objs      = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) | 						rValLen   = db.Statement.ReflectValue.Len() | ||||||
|  | 						objs      = make([]reflect.Value, 0, rValLen) | ||||||
| 						fieldType = rel.Field.FieldType | 						fieldType = rel.Field.FieldType | ||||||
| 						isPtr     = fieldType.Kind() == reflect.Ptr | 						isPtr     = fieldType.Kind() == reflect.Ptr | ||||||
| 					) | 					) | ||||||
| @ -49,22 +50,21 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { | |||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) | 					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) | ||||||
| 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | 					for i := 0; i < rValLen; i++ { | ||||||
| 						obj := db.Statement.ReflectValue.Index(i) | 						obj := db.Statement.ReflectValue.Index(i) | ||||||
| 
 | 						if reflect.Indirect(obj).Kind() != reflect.Struct { | ||||||
| 						if reflect.Indirect(obj).Kind() == reflect.Struct { |  | ||||||
| 							if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
 |  | ||||||
| 								rv := rel.Field.ReflectValueOf(obj) // relation reflect value
 |  | ||||||
| 								objs = append(objs, obj) |  | ||||||
| 								if isPtr { |  | ||||||
| 									elems = reflect.Append(elems, rv) |  | ||||||
| 								} else { |  | ||||||
| 									elems = reflect.Append(elems, rv.Addr()) |  | ||||||
| 								} |  | ||||||
| 							} |  | ||||||
| 						} else { |  | ||||||
| 							break | 							break | ||||||
| 						} | 						} | ||||||
|  | 
 | ||||||
|  | 						if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
 | ||||||
|  | 							rv := rel.Field.ReflectValueOf(obj) // relation reflect value
 | ||||||
|  | 							objs = append(objs, obj) | ||||||
|  | 							if isPtr { | ||||||
|  | 								elems = reflect.Append(elems, rv) | ||||||
|  | 							} else { | ||||||
|  | 								elems = reflect.Append(elems, rv.Addr()) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					if elems.Len() > 0 { | 					if elems.Len() > 0 { | ||||||
|  | |||||||
| @ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 
 | 
 | ||||||
| 		switch stmt.ReflectValue.Kind() { | 		switch stmt.ReflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
| 			stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) | 			rValLen := stmt.ReflectValue.Len() | ||||||
| 			values.Values = make([][]interface{}, stmt.ReflectValue.Len()) | 			stmt.SQL.Grow(rValLen * 18) | ||||||
| 			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} | 			values.Values = make([][]interface{}, rValLen) | ||||||
| 			if stmt.ReflectValue.Len() == 0 { | 			if rValLen == 0 { | ||||||
| 				stmt.AddError(gorm.ErrEmptySlice) | 				stmt.AddError(gorm.ErrEmptySlice) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | 			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} | ||||||
|  | 			for i := 0; i < rValLen; i++ { | ||||||
| 				rv := reflect.Indirect(stmt.ReflectValue.Index(i)) | 				rv := reflect.Indirect(stmt.ReflectValue.Index(i)) | ||||||
| 				if !rv.IsValid() { | 				if !rv.IsValid() { | ||||||
| 					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) | 					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) | ||||||
| @ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 
 | 
 | ||||||
| 				for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | 				for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||||
| 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
| 						if v, isZero := field.ValueOf(rv); !isZero { | 						if rvOfvalue, isZero := field.ValueOf(rv); !isZero { | ||||||
| 							if len(defaultValueFieldsHavingValue[field]) == 0 { | 							if len(defaultValueFieldsHavingValue[field]) == 0 { | ||||||
| 								defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) | 								defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) | ||||||
| 							} | 							} | ||||||
| 							defaultValueFieldsHavingValue[field][i] = v | 							defaultValueFieldsHavingValue[field][i] = rvOfvalue | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| @ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 
 | 
 | ||||||
| 			for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | 			for _, field := range stmt.Schema.FieldsWithDefaultDBValue { | ||||||
| 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
| 					if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | 					if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||||
| 						values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | 						values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) | ||||||
| 						values.Values[0] = append(values.Values[0], v) | 						values.Values[0] = append(values.Values[0], rvOfvalue) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | |||||||
| @ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | 			ok, mode := hasReturning(db, supportReturning) | ||||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | 			if !ok { | ||||||
| 					gorm.Scan(rows, db, mode) |  | ||||||
| 					rows.Close() |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 				if db.AddError(err) == nil { | 				if db.AddError(err) == nil { | ||||||
| 					db.RowsAffected, _ = result.RowsAffected() | 					db.RowsAffected, _ = result.RowsAffected() | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||||
|  | 				gorm.Scan(rows, db, mode) | ||||||
|  | 				rows.Close() | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | |||||||
| 		fieldValues := make([]interface{}, len(joinForeignFields)) | 		fieldValues := make([]interface{}, len(joinForeignFields)) | ||||||
| 		joinFieldValues := make([]interface{}, len(joinRelForeignFields)) | 		joinFieldValues := make([]interface{}, len(joinRelForeignFields)) | ||||||
| 		for i := 0; i < joinResults.Len(); i++ { | 		for i := 0; i < joinResults.Len(); i++ { | ||||||
|  | 			joinIndexValue := joinResults.Index(i) | ||||||
| 			for idx, field := range joinForeignFields { | 			for idx, field := range joinForeignFields { | ||||||
| 				fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) | 				fieldValues[idx], _ = field.ValueOf(joinIndexValue) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for idx, field := range joinRelForeignFields { | 			for idx, field := range joinRelForeignFields { | ||||||
| 				joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) | 				joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { | 			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { | ||||||
|  | |||||||
| @ -9,8 +9,9 @@ func RawExec(db *gorm.DB) { | |||||||
| 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 		result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			db.AddError(err) | 			db.AddError(err) | ||||||
| 		} else { | 			return | ||||||
| 			db.RowsAffected, _ = result.RowsAffected() |  | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -7,16 +7,17 @@ import ( | |||||||
| func RowQuery(db *gorm.DB) { | func RowQuery(db *gorm.DB) { | ||||||
| 	if db.Error == nil { | 	if db.Error == nil { | ||||||
| 		BuildQuerySQL(db) | 		BuildQuerySQL(db) | ||||||
| 
 | 		if db.DryRun { | ||||||
| 		if !db.DryRun { | 			return | ||||||
| 			if isRows, ok := db.Get("rows"); ok && isRows.(bool) { |  | ||||||
| 				db.Statement.Settings.Delete("rows") |  | ||||||
| 				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 			} else { |  | ||||||
| 				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			db.RowsAffected = -1 |  | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		if isRows, ok := db.Get("rows"); ok && isRows.(bool) { | ||||||
|  | 			db.Statement.Settings.Delete("rows") | ||||||
|  | 			db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 		} else { | ||||||
|  | 			db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		db.RowsAffected = -1 | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) { | |||||||
| func CommitOrRollbackTransaction(db *gorm.DB) { | func CommitOrRollbackTransaction(db *gorm.DB) { | ||||||
| 	if !db.Config.SkipDefaultTransaction { | 	if !db.Config.SkipDefaultTransaction { | ||||||
| 		if _, ok := db.InstanceGet("gorm:started_transaction"); ok { | 		if _, ok := db.InstanceGet("gorm:started_transaction"); ok { | ||||||
| 			if db.Error == nil { | 			if db.Error != nil { | ||||||
| 				db.Commit() |  | ||||||
| 			} else { |  | ||||||
| 				db.Rollback() | 				db.Rollback() | ||||||
|  | 			} else { | ||||||
|  | 				db.Commit() | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
| 			db.Statement.ConnPool = db.ConnPool | 			db.Statement.ConnPool = db.ConnPool | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
| 			if size := stmt.ReflectValue.Len(); size > 0 { | 			if size := stmt.ReflectValue.Len(); size > 0 { | ||||||
| 				var primaryKeyExprs []clause.Expression | 				var primaryKeyExprs []clause.Expression | ||||||
| 				for i := 0; i < stmt.ReflectValue.Len(); i++ { | 				for i := 0; i < size; i++ { | ||||||
| 					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | 					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||||
| 					var notZero bool | 					var notZero bool | ||||||
| 					for idx, field := range stmt.Schema.PrimaryFields { | 					for idx, field := range stmt.Schema.PrimaryFields { | ||||||
|  | |||||||
| @ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) GetTables() (tableList []string, err error) { | func (m Migrator) GetTables() (tableList []string, err error) { | ||||||
| 	return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error | 	err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). | ||||||
|  | 		Scan(&tableList).Error | ||||||
|  | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) CreateTable(values ...interface{}) error { | func (m Migrator) CreateTable(values ...interface{}) error { | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								scan.go
									
									
									
									
									
								
							| @ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re | |||||||
| type ScanMode uint8 | type ScanMode uint8 | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	ScanInitialized         ScanMode = 1 << 0 | 	ScanInitialized         ScanMode = 1 << 0    // 1
 | ||||||
| 	ScanUpdate                       = 1 << 1 | 	ScanUpdate              ScanMode = 1 << 1    // 2
 | ||||||
| 	ScanOnConflictDoNothing          = 1 << 2 | 	ScanOnConflictDoNothing ScanMode = 1 << 2    // 4
 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 heige
						heige