diff --git a/association.go b/association.go index 572f1526..62c25b71 100644 --- a/association.go +++ b/association.go @@ -26,7 +26,7 @@ func (db *DB) Association(column string) *Association { association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { - association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) } db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) @@ -355,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } else if ev.Type().Elem().AssignableTo(elemType) { fieldValue = reflect.Append(fieldValue, ev.Elem()) } else { - association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) } if elemType.Kind() == reflect.Struct { diff --git a/callbacks.go b/callbacks.go index 26e9c40d..02e741e7 100644 --- a/callbacks.go +++ b/callbacks.go @@ -212,7 +212,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -220,7 +220,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -250,7 +250,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } @@ -266,7 +266,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) } } else if idx := getRIndex(names, c.before); idx != -1 { // if before callback exists @@ -284,7 +284,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) } else if curIdx < sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) } } else if idx := getRIndex(names, c.after); idx != -1 { // if after callback exists but haven't sorted diff --git a/callbacks/create.go b/callbacks/create.go index 727bd380..e46d3d05 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -33,75 +33,81 @@ func BeforeCreate(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) { if config.WithReturning { return CreateWithReturning - } else { - return func(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } + } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{}) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + return func(db *gorm.DB) { + if db.Error != nil { + // maybe record logger TODO + return + } - db.Statement.Build(db.Statement.BuildClauses...) - } + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() + db.Statement.Build(db.Statement.BuildClauses...) + } - if db.RowsAffected > 0 { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + if err != nil { + db.AddError(err) + return + } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } - } - } else { - db.AddError(err) + db.RowsAffected, _ = result.RowsAffected() + if !(db.RowsAffected > 0) { + return + } + + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } - } else { - db.AddError(err) + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } + } else { + db.AddError(err) } } + } } } diff --git a/finisher_api.go b/finisher_api.go index 7b8afabd..f4fa5c76 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -190,16 +190,17 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if tx.Error != nil || int(result.RowsAffected) < batchSize { break - } else { - resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - if result.Statement.Schema.PrioritizedPrimaryField == nil { - tx.AddError(ErrPrimaryKeyRequired) - break - } else { - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) - } } + + // Optimize for-break + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } + + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected diff --git a/gorm.go b/gorm.go index e105a933..7f7bad26 100644 --- a/gorm.go +++ b/gorm.go @@ -409,7 +409,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac } ref.ForeignKey = f } else { - return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) } } @@ -422,7 +422,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation.JoinTable = joinSchema } else { - return fmt.Errorf("failed to found relation: %v", field) + return fmt.Errorf("failed to found relation: %s", field) } return nil diff --git a/migrator/migrator.go b/migrator/migrator.go index 1800ab54..03ffdd02 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -119,13 +119,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err } } } @@ -294,16 +291,20 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - if !field.IgnoreMigration { - return m.DB.Exec( - "ALTER TABLE ? ADD ? ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), - ).Error - } - return nil + // avoid using the same name field + f := stmt.Schema.LookUpField(field) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", field) } - return fmt.Errorf("failed to look up field with name: %s", field) + + if !f.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), + ).Error + } + + return nil }) } diff --git a/schema/field.go b/schema/field.go index 5dbc96f1..9efaa44a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -198,28 +198,28 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Bool if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) } } case reflect.String: @@ -227,7 +227,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") - field.DefaultValue = strings.Trim(field.DefaultValue, "\"") + field.DefaultValue = strings.Trim(field.DefaultValue, `"`) field.DefaultValueInterface = field.DefaultValue } case reflect.Struct: @@ -392,7 +392,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } } else { - schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } } @@ -423,12 +423,12 @@ func (field *Field) setupValuerAndSetter() { } else { v = v.Field(-idx - 1) - if v.Type().Elem().Kind() == reflect.Struct { - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } + if v.Type().Elem().Kind() != reflect.Struct { + return nil, true + } + + if !v.IsNil() { + v = v.Elem() } else { return nil, true } @@ -736,7 +736,7 @@ func (field *Field) setupValuerAndSetter() { if t, err := now.Parse(data); err == nil { field.ReflectValueOf(value).Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: return fallbackSetter(value, v, field.Set) @@ -765,7 +765,7 @@ func (field *Field) setupValuerAndSetter() { } fieldValue.Elem().Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: return fallbackSetter(value, v, field.Set) diff --git a/schema/naming.go b/schema/naming.go index d53942e4..47e313a7 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -74,7 +74,9 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) + formattedName := strings.Replace(strings.Join([]string{ + prefix, table, name, + }, "_"), ".", "_", -1) if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() diff --git a/schema/relationship.go b/schema/relationship.go index c7abc234..db496e30 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -85,7 +85,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) } } @@ -143,11 +143,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -159,7 +159,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) } } @@ -203,7 +203,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := schema.LookUpField(foreignKey); field != nil { ownForeignFields = append(ownForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } @@ -215,7 +215,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { refForeignFields = append(refForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } @@ -379,7 +379,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %v: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) } } diff --git a/schema/schema.go b/schema/schema.go index 1ce88fa5..8ade2ed7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -45,9 +45,9 @@ type Schema struct { func (schema Schema) String() string { if schema.ModelType.Name() == "" { - return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } - return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) + return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema Schema) MakeSlice() reflect.Value { @@ -86,7 +86,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { @@ -275,7 +275,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/schema/utils.go b/schema/utils.go index add22047..e005cc74 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -178,17 +178,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa } return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues - } else { - columns := make([]clause.Column, len(foreignKeys)) - for idx, key := range foreignKeys { - columns[idx] = clause.Column{Table: table, Name: key} - } - - for idx, r := range foreignValues { - queryValues[idx] = r - } - return columns, queryValues } + + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + + return columns, queryValues } type embeddedNamer struct { diff --git a/utils/utils.go b/utils/utils.go index ce6f35df..3261138f 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" @@ -15,17 +15,20 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + // Here is the directory to get the gorm source code. Here, the filepath.Dir mode is enough, + // and the filepath is compatible with various operating systems + gormSourceDir = filepath.Dir(filepath.Dir(file)) } +// FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { - for i := 2; i < 15; i++ { + for i := 1; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { return file + ":" + strconv.FormatInt(int64(line), 10) } } + return "" }