feat: make errors more detailed when setting value to fields

This commit is contained in:
iseki 2025-07-09 13:07:00 +08:00
parent 2f4925e017
commit d770e5b1c5
No known key found for this signature in database
GPG Key ID: B5F8E8F1CE406872
3 changed files with 47 additions and 11 deletions

View File

@ -183,7 +183,7 @@ func Create(config *Config) func(db *gorm.DB) {
_, isZero := pkField.ValueOf(db.Statement.Context, rv) _, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero { if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) db.AddError(newSetFieldValueError(pkField, pkField.Set(db.Statement.Context, rv, insertID)))
insertID -= pkField.AutoIncrementIncrement insertID -= pkField.AutoIncrementIncrement
} }
} }
@ -195,7 +195,7 @@ func Create(config *Config) func(db *gorm.DB) {
} }
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) db.AddError(newSetFieldValueError(pkField, pkField.Set(db.Statement.Context, rv, insertID)))
insertID += pkField.AutoIncrementIncrement insertID += pkField.AutoIncrementIncrement
} }
} }
@ -203,7 +203,7 @@ func Create(config *Config) func(db *gorm.DB) {
case reflect.Struct: case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero { if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) db.AddError(newSetFieldValueError(pkField, pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)))
} }
} }
} }
@ -289,13 +289,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, rv, field.DefaultValueInterface)))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, rv, curTime)))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, rv, curTime)))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} }
@ -331,13 +331,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, stmt.ReflectValue, curTime)))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) stmt.AddError(newSetFieldValueError(field, field.Set(stmt.Context, stmt.ReflectValue, curTime)))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} }
@ -351,7 +351,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
} }
} }
default: default:
stmt.AddError(gorm.ErrInvalidData) stmt.AddError(fmt.Errorf("%w: expected slice, array or struct, but got %v", gorm.ErrInvalidData, stmt.ReflectValue.Kind()))
} }
} }

36
callbacks/errors.go Normal file
View File

@ -0,0 +1,36 @@
package callbacks
import (
"fmt"
"gorm.io/gorm/schema"
)
type _SetFieldValueError struct {
Field *schema.Field
Err error
}
func (e _SetFieldValueError) Error() string {
return fmt.Sprintf("error when set value for field %s: %v", e.Field.Name, e.Err)
}
func (e _SetFieldValueError) Unwrap() error {
return e.Err
}
func newSetFieldValueError(field *schema.Field, e error) error {
if e == nil {
return nil
}
//goland:noinspection GoTypeAssertionOnErrors
if we, ok := e.(*_SetFieldValueError); ok && we.Field == field {
return e
}
if field == nil {
panic("field is nil")
}
return &_SetFieldValueError{
Field: field,
Err: e,
}
}

View File

@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) db.AddError(newSetFieldValueError(rel.Field, rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])))
} }
} }
} }