Add batch id tests

This commit is contained in:
Jinzhu 2018-02-28 00:30:03 +08:00
parent 2e2fd54278
commit 94e06eb2f8
2 changed files with 33 additions and 12 deletions

View File

@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/davecgh/go-spew/spew"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -19,12 +20,22 @@ func init() {
} }
} }
func TestInsert(t *testing.T) { func TestBatchInsert(t *testing.T) {
type User struct { type User struct {
gorm.Model gorm.Model
Name string Name string
Age int Age int
} }
DB.Create([]*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}}) users := []*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}, {Name: "name3", Age: 30}}
DB.Create(users)
spew.Dump(users)
for _, user := range users {
if user.ID == 0 {
t.Errorf("User should have primary key")
}
}
} }

View File

@ -25,6 +25,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
args []interface{} args []interface{}
assignmentsChan = model.GetAssignments(tx) assignmentsChan = model.GetAssignments(tx)
tableNameChan = model.GetTable(tx) tableNameChan = model.GetTable(tx)
primaryFields []*model.Field
) )
s := bytes.NewBufferString("INSERT INTO ") s := bytes.NewBufferString("INSERT INTO ")
@ -40,12 +41,17 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
valueBuffer := bytes.NewBufferString("VALUES ") valueBuffer := bytes.NewBufferString("VALUES ")
for idx, fields := range assignments { for idx, fields := range assignments {
var primaryField *model.Field
if idx != 0 { if idx != 0 {
valueBuffer.WriteString(",") valueBuffer.WriteString(",")
} }
valueBuffer.WriteString(" (") valueBuffer.WriteString(" (")
for j, field := range fields { for j, field := range fields {
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
primaryField = field
}
if idx == 0 { if idx == 0 {
columns = append(columns, field.Field.DBName) columns = append(columns, field.Field.DBName)
if j != 0 { if j != 0 {
@ -60,15 +66,13 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
valueBuffer.WriteString("?") valueBuffer.WriteString("?")
if field.IsBlank { if field.IsBlank {
if field.Field.HasDefaultValue { args = append(args, nil)
args = append(args, field.Field.DefaultValue)
} else {
args = append(args, nil)
}
} else { } else {
args = append(args, field.Value.Interface()) args = append(args, field.Value.Interface())
} }
} }
primaryFields = append(primaryFields, primaryField)
valueBuffer.WriteString(")") valueBuffer.WriteString(")")
} }
s.WriteString(") ") s.WriteString(") ")
@ -78,12 +82,18 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
s.WriteString(" DEFAULT VALUES") s.WriteString(" DEFAULT VALUES")
} }
fmt.Println(s.String()) result, err := dialect.DB.Exec(s.String(), args...)
fmt.Printf("%#v \n", args)
if result, err := dialect.DB.Exec(s.String(), args...); err == nil { if err == nil {
var lastInsertID int64
tx.RowsAffected, _ = result.RowsAffected() tx.RowsAffected, _ = result.RowsAffected()
} else { lastInsertID, err = result.LastInsertId()
fmt.Println(err) if len(primaryFields) == int(tx.RowsAffected) {
startID := lastInsertID - tx.RowsAffected + 1
for i, primaryField := range primaryFields {
tx.AddError(primaryField.Set(startID + int64(i)))
}
}
} }
return return
} }