diff --git a/callbacks/create.go b/callbacks/create.go index 00bbea8b..04438232 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -107,15 +107,47 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) + var ( + pkField *schema.Field + pkFieldName = "@id" + ) + + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: + if pkField == nil { return } @@ -128,10 +160,10 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + _, isZero := pkField.ValueOf(db.Statement.Context, rv) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID -= pkField.AutoIncrementIncrement } } } else { @@ -141,53 +173,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID += pkField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) - } - } - } - if db.Statement.Dest != nil { - // append @id column with value for auto-increment primary key - // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - - pkFieldName := "@id" - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - - switch values := db.Statement.Dest.(type) { - case map[string]interface{}: - values[pkFieldName] = insertID - case *map[string]interface{}: - (*values)[pkFieldName] = insertID - case []map[string]interface{}, *[]map[string]interface{}: - mapValues, ok := values.([]map[string]interface{}) - if !ok { - if v, ok := values.(*[]map[string]interface{}); ok { - if *v != nil { - mapValues = *v - } - } - } - for _, mapValue := range mapValues { - if mapValue != nil { - mapValue[pkFieldName] = insertID - } - insertID += schema.DefaultAutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } }