diff --git a/callback_create.go b/callback_create.go index e7ec40bb..0aca4eeb 100644 --- a/callback_create.go +++ b/callback_create.go @@ -70,8 +70,8 @@ func Create(scope *Scope) { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil && primaryField.IsBlank { - scope.Err(scope.SetColumn(primaryField, id)) + if autoIncrementField := scope.AutoIncrementField(); autoIncrementField != nil { + scope.Err(scope.SetColumn(autoIncrementField, id)) } } } diff --git a/scope.go b/scope.go index 8cf1da3e..a161d33b 100644 --- a/scope.go +++ b/scope.go @@ -109,16 +109,27 @@ func (scope *Scope) Log(v ...interface{}) { func (scope *Scope) HasError() bool { return scope.db.Error != nil } - func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - for i := 0; i < len(primaryFields); i++ { - if primaryFields[i].IsAutoIncrement { - return scope.Fields()[primaryFields[i].DBName] + if len(primaryFields) > 1 { + if field, ok := scope.Fields()["id"]; ok { + return field + } + } + return scope.Fields()[primaryFields[0].DBName] + } + return nil +} + +func (scope *Scope) AutoIncrementField() *Field { + if structFields := scope.GetModelStruct().StructFields; len(structFields) > 0 { + for i := 0; i < len(structFields); i++ { + if structFields[i].IsAutoIncrement { + return scope.Fields()[structFields[i].DBName] } } } - return nil + return scope.PrimaryField() } // PrimaryKey get the primary key's column name diff --git a/structs_test.go b/structs_test.go index 9a9b23d1..45cd5f33 100644 --- a/structs_test.go +++ b/structs_test.go @@ -11,7 +11,7 @@ import ( ) type User struct { - Id int64 + Id int64 `sql:"auto_increment"` Age int64 UserNum Num Name string `sql:"size:255"`