optimize for logic,if logic and setupValuerAndSetter func.

This commit is contained in:
daheige 2021-06-14 10:59:00 +08:00
parent a689e56433
commit 9653eaf0fe
3 changed files with 32 additions and 33 deletions

View File

@ -22,15 +22,16 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
setupReferences := func(obj reflect.Value, elem reflect.Value) { setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) continue
db.AddError(ref.ForeignKey.Set(obj, pv)) }
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { pv, _ := ref.PrimaryKey.ValueOf(elem)
dest[ref.ForeignKey.DBName] = pv db.AddError(ref.ForeignKey.Set(obj, pv))
if _, ok := dest[rel.Name]; ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[rel.Name] = elem.Interface() dest[ref.ForeignKey.DBName] = pv
} if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
} }
} }
} }
@ -51,27 +52,24 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
}
} else {
break break
} }
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
}
} }
if elems.Len() > 0 { if elems.Len() > 0 && saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ {
for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i))
setupReferences(objs[i], elems.Index(i))
}
} }
} }
case reflect.Struct: case reflect.Struct:

View File

@ -209,7 +209,7 @@ func Preload(db *gorm.DB) {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else { } else {
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
} }
} }
} }

View File

@ -490,21 +490,22 @@ func (field *Field) setupValuerAndSetter() {
return return
} else if field.FieldType.Kind() == reflect.Ptr { } else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)
fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(field.FieldType.Elem()) { if reflectValType.AssignableTo(fieldType) {
if !fieldValue.IsValid() { if !fieldValue.IsValid() {
fieldValue = reflect.New(field.FieldType.Elem()) fieldValue = reflect.New(fieldType)
} else if fieldValue.IsNil() { } else if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(fieldType))
} }
fieldValue.Elem().Set(reflectV) fieldValue.Elem().Set(reflectV)
return return
} else if reflectValType.ConvertibleTo(field.FieldType.Elem()) { } else if reflectValType.ConvertibleTo(fieldType) {
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(fieldType))
} }
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) fieldValue.Elem().Set(reflectV.Convert(fieldType))
return return
} }
} }
@ -520,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
err = setter(value, v) err = setter(value, v)
} }
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
} }
} }