Code optimize (#4415)
* optimize gormSourceDir replace * fmt.Errorf adjust and Optimize for-break * strings trim * feat: avoid using the same name field and if..else optimization adjustment * optimization callbacks/create.go Create func if...else logic * fix: callbacks/create.go Create func * fix FileWithLineNum func and add gormSourceDir unit test * remove debug print and utils_filenum_test.go
This commit is contained in:
		
							parent
							
								
									00b252559f
								
							
						
					
					
						commit
						50e85e14d4
					
				| @ -26,7 +26,7 @@ func (db *DB) Association(column string) *Association { | |||||||
| 		association.Relationship = db.Statement.Schema.Relationships.Relations[column] | 		association.Relationship = db.Statement.Schema.Relationships.Relations[column] | ||||||
| 
 | 
 | ||||||
| 		if association.Relationship == nil { | 		if association.Relationship == nil { | ||||||
| 			association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) | 			association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) | 		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) | ||||||
| @ -355,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 				} else if ev.Type().Elem().AssignableTo(elemType) { | 				} else if ev.Type().Elem().AssignableTo(elemType) { | ||||||
| 					fieldValue = reflect.Append(fieldValue, ev.Elem()) | 					fieldValue = reflect.Append(fieldValue, ev.Elem()) | ||||||
| 				} else { | 				} else { | ||||||
| 					association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) | 					association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				if elemType.Kind() == reflect.Struct { | 				if elemType.Kind() == reflect.Struct { | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								callbacks.go
									
									
									
									
									
								
							| @ -212,7 +212,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *callback) Remove(name string) error { | func (c *callback) Remove(name string) error { | ||||||
| 	c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) | 	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) | ||||||
| 	c.name = name | 	c.name = name | ||||||
| 	c.remove = true | 	c.remove = true | ||||||
| 	c.processor.callbacks = append(c.processor.callbacks, c) | 	c.processor.callbacks = append(c.processor.callbacks, c) | ||||||
| @ -220,7 +220,7 @@ func (c *callback) Remove(name string) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c *callback) Replace(name string, fn func(*DB)) error { | func (c *callback) Replace(name string, fn func(*DB)) error { | ||||||
| 	c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) | 	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) | ||||||
| 	c.name = name | 	c.name = name | ||||||
| 	c.handler = fn | 	c.handler = fn | ||||||
| 	c.replace = true | 	c.replace = true | ||||||
| @ -250,7 +250,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | |||||||
| 	for _, c := range cs { | 	for _, c := range cs { | ||||||
| 		// show warning message the callback name already exists
 | 		// show warning message the callback name already exists
 | ||||||
| 		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { | 		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { | ||||||
| 			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) | 			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) | ||||||
| 		} | 		} | ||||||
| 		names = append(names, c.name) | 		names = append(names, c.name) | ||||||
| 	} | 	} | ||||||
| @ -266,7 +266,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | |||||||
| 					// if before callback already sorted, append current callback just after it
 | 					// if before callback already sorted, append current callback just after it
 | ||||||
| 					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) | 					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) | ||||||
| 				} else if curIdx > sortedIdx { | 				} else if curIdx > sortedIdx { | ||||||
| 					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) | 					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) | ||||||
| 				} | 				} | ||||||
| 			} else if idx := getRIndex(names, c.before); idx != -1 { | 			} else if idx := getRIndex(names, c.before); idx != -1 { | ||||||
| 				// if before callback exists
 | 				// if before callback exists
 | ||||||
| @ -284,7 +284,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { | |||||||
| 					// if after callback sorted, append current callback to last
 | 					// if after callback sorted, append current callback to last
 | ||||||
| 					sorted = append(sorted, c.name) | 					sorted = append(sorted, c.name) | ||||||
| 				} else if curIdx < sortedIdx { | 				} else if curIdx < sortedIdx { | ||||||
| 					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) | 					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) | ||||||
| 				} | 				} | ||||||
| 			} else if idx := getRIndex(names, c.after); idx != -1 { | 			} else if idx := getRIndex(names, c.after); idx != -1 { | ||||||
| 				// if after callback exists but haven't sorted
 | 				// if after callback exists but haven't sorted
 | ||||||
|  | |||||||
| @ -33,9 +33,14 @@ func BeforeCreate(db *gorm.DB) { | |||||||
| func Create(config *Config) func(db *gorm.DB) { | func Create(config *Config) func(db *gorm.DB) { | ||||||
| 	if config.WithReturning { | 	if config.WithReturning { | ||||||
| 		return CreateWithReturning | 		return CreateWithReturning | ||||||
| 	} else { | 	} | ||||||
|  | 
 | ||||||
| 	return func(db *gorm.DB) { | 	return func(db *gorm.DB) { | ||||||
| 			if db.Error == nil { | 		if db.Error != nil { | ||||||
|  | 			// maybe record logger TODO
 | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | 		if db.Statement.Schema != nil && !db.Statement.Unscoped { | ||||||
| 			for _, c := range db.Statement.Schema.CreateClauses { | 			for _, c := range db.Statement.Schema.CreateClauses { | ||||||
| 				db.Statement.AddClause(c) | 				db.Statement.AddClause(c) | ||||||
| @ -53,10 +58,16 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 
 | 
 | ||||||
| 					if err == nil { | 			if err != nil { | ||||||
| 						db.RowsAffected, _ = result.RowsAffected() | 				db.AddError(err) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 			if !(db.RowsAffected > 0) { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 						if db.RowsAffected > 0 { |  | ||||||
| 			if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | 			if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||||
| 				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | 				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | ||||||
| 					switch db.Statement.ReflectValue.Kind() { | 					switch db.Statement.ReflectValue.Kind() { | ||||||
| @ -96,12 +107,7 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 					db.AddError(err) | 					db.AddError(err) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 						} | 
 | ||||||
| 					} else { |  | ||||||
| 						db.AddError(err) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -190,17 +190,18 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | |||||||
| 
 | 
 | ||||||
| 		if tx.Error != nil || int(result.RowsAffected) < batchSize { | 		if tx.Error != nil || int(result.RowsAffected) < batchSize { | ||||||
| 			break | 			break | ||||||
| 		} else { | 		} | ||||||
|  | 
 | ||||||
|  | 		// Optimize for-break
 | ||||||
| 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | ||||||
| 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | ||||||
| 			tx.AddError(ErrPrimaryKeyRequired) | 			tx.AddError(ErrPrimaryKeyRequired) | ||||||
| 			break | 			break | ||||||
| 			} else { | 		} | ||||||
|  | 
 | ||||||
| 		primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) | 		primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) | ||||||
| 		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | 		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||||
| 	} | 	} | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	tx.RowsAffected = rowsAffected | 	tx.RowsAffected = rowsAffected | ||||||
| 	return tx | 	return tx | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								gorm.go
									
									
									
									
									
								
							| @ -409,7 +409,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac | |||||||
| 				} | 				} | ||||||
| 				ref.ForeignKey = f | 				ref.ForeignKey = f | ||||||
| 			} else { | 			} else { | ||||||
| 				return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) | 				return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -422,7 +422,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac | |||||||
| 
 | 
 | ||||||
| 		relation.JoinTable = joinSchema | 		relation.JoinTable = joinSchema | ||||||
| 	} else { | 	} else { | ||||||
| 		return fmt.Errorf("failed to found relation: %v", field) | 		return fmt.Errorf("failed to found relation: %s", field) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | |||||||
| @ -119,16 +119,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 
 | 
 | ||||||
| 				for _, rel := range stmt.Schema.Relationships.Relations { | 				for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| 					if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { | 					if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { | ||||||
| 						if constraint := rel.ParseConstraint(); constraint != nil { | 						if constraint := rel.ParseConstraint(); constraint != nil && | ||||||
| 							if constraint.Schema == stmt.Schema { | 							constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { | ||||||
| 								if !tx.Migrator().HasConstraint(value, constraint.Name) { |  | ||||||
| 							if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { | 							if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { | ||||||
| 								return err | 								return err | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 
 | 
 | ||||||
| 					for _, chk := range stmt.Schema.ParseCheckConstraints() { | 					for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||||
| 						if !tx.Migrator().HasConstraint(value, chk.Name) { | 						if !tx.Migrator().HasConstraint(value, chk.Name) { | ||||||
| @ -294,16 +291,20 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { | |||||||
| 
 | 
 | ||||||
| func (m Migrator) AddColumn(value interface{}, field string) error { | func (m Migrator) AddColumn(value interface{}, field string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		// avoid using the same name field
 | ||||||
| 			if !field.IgnoreMigration { | 		f := stmt.Schema.LookUpField(field) | ||||||
|  | 		if f == nil { | ||||||
|  | 			return fmt.Errorf("failed to look up field with name: %s", field) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !f.IgnoreMigration { | ||||||
| 			return m.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? ADD ? ?", | 				"ALTER TABLE ? ADD ? ?", | ||||||
| 					m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), | 				m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), | ||||||
| 			).Error | 			).Error | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		return nil | 		return nil | ||||||
| 		} |  | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) |  | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -198,28 +198,28 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		field.DataType = Bool | 		field.DataType = Bool | ||||||
| 		if field.HasDefaultValue && !skipParseDefaultValue { | 		if field.HasDefaultValue && !skipParseDefaultValue { | ||||||
| 			if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { | 			if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { | ||||||
| 				schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) | 				schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||||
| 		field.DataType = Int | 		field.DataType = Int | ||||||
| 		if field.HasDefaultValue && !skipParseDefaultValue { | 		if field.HasDefaultValue && !skipParseDefaultValue { | ||||||
| 			if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { | 			if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { | ||||||
| 				schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) | 				schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||||
| 		field.DataType = Uint | 		field.DataType = Uint | ||||||
| 		if field.HasDefaultValue && !skipParseDefaultValue { | 		if field.HasDefaultValue && !skipParseDefaultValue { | ||||||
| 			if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { | 			if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { | ||||||
| 				schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) | 				schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	case reflect.Float32, reflect.Float64: | 	case reflect.Float32, reflect.Float64: | ||||||
| 		field.DataType = Float | 		field.DataType = Float | ||||||
| 		if field.HasDefaultValue && !skipParseDefaultValue { | 		if field.HasDefaultValue && !skipParseDefaultValue { | ||||||
| 			if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { | 			if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { | ||||||
| 				schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) | 				schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	case reflect.String: | 	case reflect.String: | ||||||
| @ -227,7 +227,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 
 | 
 | ||||||
| 		if field.HasDefaultValue && !skipParseDefaultValue { | 		if field.HasDefaultValue && !skipParseDefaultValue { | ||||||
| 			field.DefaultValue = strings.Trim(field.DefaultValue, "'") | 			field.DefaultValue = strings.Trim(field.DefaultValue, "'") | ||||||
| 			field.DefaultValue = strings.Trim(field.DefaultValue, "\"") | 			field.DefaultValue = strings.Trim(field.DefaultValue, `"`) | ||||||
| 			field.DefaultValueInterface = field.DefaultValue | 			field.DefaultValueInterface = field.DefaultValue | ||||||
| 		} | 		} | ||||||
| 	case reflect.Struct: | 	case reflect.Struct: | ||||||
| @ -392,7 +392,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) | 			schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -423,12 +423,12 @@ func (field *Field) setupValuerAndSetter() { | |||||||
| 				} else { | 				} else { | ||||||
| 					v = v.Field(-idx - 1) | 					v = v.Field(-idx - 1) | ||||||
| 
 | 
 | ||||||
| 					if v.Type().Elem().Kind() == reflect.Struct { | 					if v.Type().Elem().Kind() != reflect.Struct { | ||||||
| 						if !v.IsNil() { |  | ||||||
| 							v = v.Elem() |  | ||||||
| 						} else { |  | ||||||
| 						return nil, true | 						return nil, true | ||||||
| 					} | 					} | ||||||
|  | 
 | ||||||
|  | 					if !v.IsNil() { | ||||||
|  | 						v = v.Elem() | ||||||
| 					} else { | 					} else { | ||||||
| 						return nil, true | 						return nil, true | ||||||
| 					} | 					} | ||||||
| @ -736,7 +736,7 @@ func (field *Field) setupValuerAndSetter() { | |||||||
| 					if t, err := now.Parse(data); err == nil { | 					if t, err := now.Parse(data); err == nil { | ||||||
| 						field.ReflectValueOf(value).Set(reflect.ValueOf(t)) | 						field.ReflectValueOf(value).Set(reflect.ValueOf(t)) | ||||||
| 					} else { | 					} else { | ||||||
| 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) | 						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) | ||||||
| 					} | 					} | ||||||
| 				default: | 				default: | ||||||
| 					return fallbackSetter(value, v, field.Set) | 					return fallbackSetter(value, v, field.Set) | ||||||
| @ -765,7 +765,7 @@ func (field *Field) setupValuerAndSetter() { | |||||||
| 						} | 						} | ||||||
| 						fieldValue.Elem().Set(reflect.ValueOf(t)) | 						fieldValue.Elem().Set(reflect.ValueOf(t)) | ||||||
| 					} else { | 					} else { | ||||||
| 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) | 						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) | ||||||
| 					} | 					} | ||||||
| 				default: | 				default: | ||||||
| 					return fallbackSetter(value, v, field.Set) | 					return fallbackSetter(value, v, field.Set) | ||||||
|  | |||||||
| @ -74,7 +74,9 @@ func (ns NamingStrategy) IndexName(table, column string) string { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (ns NamingStrategy) formatName(prefix, table, name string) string { | func (ns NamingStrategy) formatName(prefix, table, name string) string { | ||||||
| 	formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) | 	formattedName := strings.Replace(strings.Join([]string{ | ||||||
|  | 		prefix, table, name, | ||||||
|  | 	}, "_"), ".", "_", -1) | ||||||
| 
 | 
 | ||||||
| 	if utf8.RuneCountInString(formattedName) > 64 { | 	if utf8.RuneCountInString(formattedName) > 64 { | ||||||
| 		h := sha1.New() | 		h := sha1.New() | ||||||
|  | |||||||
| @ -85,7 +85,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | |||||||
| 		case reflect.Slice: | 		case reflect.Slice: | ||||||
| 			schema.guessRelation(relation, field, guessHas) | 			schema.guessRelation(relation, field, guessHas) | ||||||
| 		default: | 		default: | ||||||
| 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) | 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -143,11 +143,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if relation.Polymorphic.PolymorphicType == nil { | 	if relation.Polymorphic.PolymorphicType == nil { | ||||||
| 		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") | 		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if relation.Polymorphic.PolymorphicID == nil { | 	if relation.Polymorphic.PolymorphicID == nil { | ||||||
| 		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") | 		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if schema.err == nil { | 	if schema.err == nil { | ||||||
| @ -159,7 +159,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | |||||||
| 		primaryKeyField := schema.PrioritizedPrimaryField | 		primaryKeyField := schema.PrioritizedPrimaryField | ||||||
| 		if len(relation.foreignKeys) > 0 { | 		if len(relation.foreignKeys) > 0 { | ||||||
| 			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { | 			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { | ||||||
| 				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) | 				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -203,7 +203,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | |||||||
| 			if field := schema.LookUpField(foreignKey); field != nil { | 			if field := schema.LookUpField(foreignKey); field != nil { | ||||||
| 				ownForeignFields = append(ownForeignFields, field) | 				ownForeignFields = append(ownForeignFields, field) | ||||||
| 			} else { | 			} else { | ||||||
| 				schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) | 				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -215,7 +215,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | |||||||
| 			if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { | 			if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { | ||||||
| 				refForeignFields = append(refForeignFields, field) | 				refForeignFields = append(refForeignFields, field) | ||||||
| 			} else { | 			} else { | ||||||
| 				schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) | 				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -379,7 +379,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu | |||||||
| 			schema.guessRelation(relation, field, guessEmbeddedHas) | 			schema.guessRelation(relation, field, guessEmbeddedHas) | ||||||
| 		// case guessEmbeddedHas:
 | 		// case guessEmbeddedHas:
 | ||||||
| 		default: | 		default: | ||||||
| 			schema.err = fmt.Errorf("invalid field found for struct %v's field %v: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) | 			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -45,9 +45,9 @@ type Schema struct { | |||||||
| 
 | 
 | ||||||
| func (schema Schema) String() string { | func (schema Schema) String() string { | ||||||
| 	if schema.ModelType.Name() == "" { | 	if schema.ModelType.Name() == "" { | ||||||
| 		return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) | 		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) | ||||||
| 	} | 	} | ||||||
| 	return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) | 	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (schema Schema) MakeSlice() reflect.Value { | func (schema Schema) MakeSlice() reflect.Value { | ||||||
| @ -86,7 +86,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) | |||||||
| 		if modelType.PkgPath() == "" { | 		if modelType.PkgPath() == "" { | ||||||
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||||
| 		} | 		} | ||||||
| 		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := cacheStore.Load(modelType); ok { | 	if v, ok := cacheStore.Load(modelType); ok { | ||||||
| @ -275,7 +275,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e | |||||||
| 		if modelType.PkgPath() == "" { | 		if modelType.PkgPath() == "" { | ||||||
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||||
| 		} | 		} | ||||||
| 		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := cacheStore.Load(modelType); ok { | 	if v, ok := cacheStore.Load(modelType); ok { | ||||||
|  | |||||||
| @ -178,7 +178,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues | 		return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues | ||||||
| 	} else { | 	} | ||||||
|  | 
 | ||||||
| 	columns := make([]clause.Column, len(foreignKeys)) | 	columns := make([]clause.Column, len(foreignKeys)) | ||||||
| 	for idx, key := range foreignKeys { | 	for idx, key := range foreignKeys { | ||||||
| 		columns[idx] = clause.Column{Table: table, Name: key} | 		columns[idx] = clause.Column{Table: table, Name: key} | ||||||
| @ -187,8 +188,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa | |||||||
| 	for idx, r := range foreignValues { | 	for idx, r := range foreignValues { | ||||||
| 		queryValues[idx] = r | 		queryValues[idx] = r | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return columns, queryValues | 	return columns, queryValues | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type embeddedNamer struct { | type embeddedNamer struct { | ||||||
|  | |||||||
| @ -3,8 +3,8 @@ package utils | |||||||
| import ( | import ( | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"path/filepath" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"regexp" |  | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @ -15,17 +15,20 @@ var gormSourceDir string | |||||||
| 
 | 
 | ||||||
| func init() { | func init() { | ||||||
| 	_, file, _, _ := runtime.Caller(0) | 	_, file, _, _ := runtime.Caller(0) | ||||||
| 	gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") | 	// Here is the directory to get the gorm source code. Here, the filepath.Dir mode is enough,
 | ||||||
|  | 	// and the filepath is compatible with various operating systems
 | ||||||
|  | 	gormSourceDir = filepath.Dir(filepath.Dir(file)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // FileWithLineNum return the file name and line number of the current file
 | ||||||
| func FileWithLineNum() string { | func FileWithLineNum() string { | ||||||
| 	for i := 2; i < 15; i++ { | 	for i := 1; i < 15; i++ { | ||||||
| 		_, file, line, ok := runtime.Caller(i) | 		_, file, line, ok := runtime.Caller(i) | ||||||
| 
 |  | ||||||
| 		if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { | 		if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { | ||||||
| 			return file + ":" + strconv.FormatInt(int64(line), 10) | 			return file + ":" + strconv.FormatInt(int64(line), 10) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return "" | 	return "" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user