diff --git a/dialects/sqlite/main_test.go b/dialects/sqlite/main_test.go index 49b7a5cf..9039f5fc 100644 --- a/dialects/sqlite/main_test.go +++ b/dialects/sqlite/main_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + "github.com/davecgh/go-spew/spew" "github.com/jinzhu/gorm" ) @@ -19,12 +20,22 @@ func init() { } } -func TestInsert(t *testing.T) { +func TestBatchInsert(t *testing.T) { type User struct { gorm.Model Name string Age int } - DB.Create([]*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}}) + users := []*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}, {Name: "name3", Age: 30}} + + DB.Create(users) + + spew.Dump(users) + + for _, user := range users { + if user.ID == 0 { + t.Errorf("User should have primary key") + } + } } diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index afe844a3..e0ff63bc 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -25,6 +25,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { args []interface{} assignmentsChan = model.GetAssignments(tx) tableNameChan = model.GetTable(tx) + primaryFields []*model.Field ) s := bytes.NewBufferString("INSERT INTO ") @@ -40,12 +41,17 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { valueBuffer := bytes.NewBufferString("VALUES ") for idx, fields := range assignments { + var primaryField *model.Field if idx != 0 { valueBuffer.WriteString(",") } valueBuffer.WriteString(" (") for j, field := range fields { + if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" { + primaryField = field + } + if idx == 0 { columns = append(columns, field.Field.DBName) if j != 0 { @@ -60,15 +66,13 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { valueBuffer.WriteString("?") if field.IsBlank { - if field.Field.HasDefaultValue { - args = append(args, field.Field.DefaultValue) - } else { - args = append(args, nil) - } + args = append(args, nil) } else { args = append(args, field.Value.Interface()) } } + + primaryFields = append(primaryFields, primaryField) valueBuffer.WriteString(")") } s.WriteString(") ") @@ -78,12 +82,18 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { s.WriteString(" DEFAULT VALUES") } - fmt.Println(s.String()) - fmt.Printf("%#v \n", args) - if result, err := dialect.DB.Exec(s.String(), args...); err == nil { + result, err := dialect.DB.Exec(s.String(), args...) + + if err == nil { + var lastInsertID int64 tx.RowsAffected, _ = result.RowsAffected() - } else { - fmt.Println(err) + lastInsertID, err = result.LastInsertId() + if len(primaryFields) == int(tx.RowsAffected) { + startID := lastInsertID - tx.RowsAffected + 1 + for i, primaryField := range primaryFields { + tx.AddError(primaryField.Set(startID + int64(i))) + } + } } return }