diff --git a/callback_create.go b/callback_create.go index 7f21ed6a..e33a7688 100644 --- a/callback_create.go +++ b/callback_create.go @@ -70,8 +70,8 @@ func Create(scope *Scope) { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil && primaryField.IsBlank { - scope.Err(scope.SetColumn(primaryField, id)) + if autoIncrementField := scope.AutoIncrementField(); autoIncrementField != nil { + scope.Err(scope.SetColumn(autoIncrementField, id)) } } } diff --git a/model_struct.go b/model_struct.go index 10423ae2..4494669d 100644 --- a/model_struct.go +++ b/model_struct.go @@ -40,6 +40,7 @@ type StructField struct { Tag reflect.StructTag Struct reflect.StructField IsForeignKey bool + IsAutoIncrement bool Relationship *Relationship } @@ -148,6 +149,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.HasDefaultValue = true } + if _, ok := sqlSettings["AUTO_INCREMENT"]; ok { + field.IsAutoIncrement = true + } + if value, ok := gormSettings["COLUMN"]; ok { field.DBName = value } else { diff --git a/scope.go b/scope.go index 11bad777..a161d33b 100644 --- a/scope.go +++ b/scope.go @@ -109,7 +109,6 @@ func (scope *Scope) Log(v ...interface{}) { func (scope *Scope) HasError() bool { return scope.db.Error != nil } - func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { @@ -122,6 +121,17 @@ func (scope *Scope) PrimaryField() *Field { return nil } +func (scope *Scope) AutoIncrementField() *Field { + if structFields := scope.GetModelStruct().StructFields; len(structFields) > 0 { + for i := 0; i < len(structFields); i++ { + if structFields[i].IsAutoIncrement { + return scope.Fields()[structFields[i].DBName] + } + } + } + return scope.PrimaryField() +} + // PrimaryKey get the primary key's column name func (scope *Scope) PrimaryKey() string { if field := scope.PrimaryField(); field != nil { diff --git a/structs_test.go b/structs_test.go index 9a9b23d1..45cd5f33 100644 --- a/structs_test.go +++ b/structs_test.go @@ -11,7 +11,7 @@ import ( ) type User struct { - Id int64 + Id int64 `sql:"auto_increment"` Age int64 UserNum Num Name string `sql:"size:255"`