refactor write back pk value for create

This commit is contained in:
方圣卿 2023-11-15 16:54:06 +08:00
parent e7301c10a8
commit a26ce789b2

View File

@ -107,11 +107,6 @@ func Create(config *Config) func(db *gorm.DB) {
return return
} }
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
insertID, err := result.LastInsertId() insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0 insertOk := err == nil && insertID > 0
if !insertOk { if !insertOk {
@ -119,56 +114,18 @@ func Create(config *Config) func(db *gorm.DB) {
return return
} }
switch db.Statement.ReflectValue.Kind() { var (
case reflect.Slice, reflect.Array: pkField *schema.Field
if config.LastInsertIDReversed { pkFieldName = "@id"
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { )
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if isZero { pkField = db.Statement.Schema.PrioritizedPrimaryField
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
if db.Statement.Dest != nil {
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}
pkFieldName := "@id"
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
} }
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) { switch values := db.Statement.Dest.(type) {
case map[string]interface{}: case map[string]interface{}:
values[pkFieldName] = insertID values[pkFieldName] = insertID
@ -189,6 +146,44 @@ func Create(config *Config) func(db *gorm.DB) {
} }
insertID += schema.DefaultAutoIncrementIncrement insertID += schema.DefaultAutoIncrementIncrement
} }
default:
if pkField == nil {
return
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
} }
} }
} }