Add batch id tests
This commit is contained in:
parent
2e2fd54278
commit
94e06eb2f8
@ -6,6 +6,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
@ -19,12 +20,22 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
func TestBatchInsert(t *testing.T) {
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
DB.Create([]*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}})
|
||||
users := []*User{{Name: "name1", Age: 10}, {Name: "name2", Age: 20}, {Name: "name3", Age: 30}}
|
||||
|
||||
DB.Create(users)
|
||||
|
||||
spew.Dump(users)
|
||||
|
||||
for _, user := range users {
|
||||
if user.ID == 0 {
|
||||
t.Errorf("User should have primary key")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
args []interface{}
|
||||
assignmentsChan = model.GetAssignments(tx)
|
||||
tableNameChan = model.GetTable(tx)
|
||||
primaryFields []*model.Field
|
||||
)
|
||||
|
||||
s := bytes.NewBufferString("INSERT INTO ")
|
||||
@ -40,12 +41,17 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
valueBuffer := bytes.NewBufferString("VALUES ")
|
||||
|
||||
for idx, fields := range assignments {
|
||||
var primaryField *model.Field
|
||||
if idx != 0 {
|
||||
valueBuffer.WriteString(",")
|
||||
}
|
||||
valueBuffer.WriteString(" (")
|
||||
|
||||
for j, field := range fields {
|
||||
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
|
||||
primaryField = field
|
||||
}
|
||||
|
||||
if idx == 0 {
|
||||
columns = append(columns, field.Field.DBName)
|
||||
if j != 0 {
|
||||
@ -60,15 +66,13 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
valueBuffer.WriteString("?")
|
||||
|
||||
if field.IsBlank {
|
||||
if field.Field.HasDefaultValue {
|
||||
args = append(args, field.Field.DefaultValue)
|
||||
} else {
|
||||
args = append(args, nil)
|
||||
}
|
||||
} else {
|
||||
args = append(args, field.Value.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
valueBuffer.WriteString(")")
|
||||
}
|
||||
s.WriteString(") ")
|
||||
@ -78,12 +82,18 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||
s.WriteString(" DEFAULT VALUES")
|
||||
}
|
||||
|
||||
fmt.Println(s.String())
|
||||
fmt.Printf("%#v \n", args)
|
||||
if result, err := dialect.DB.Exec(s.String(), args...); err == nil {
|
||||
result, err := dialect.DB.Exec(s.String(), args...)
|
||||
|
||||
if err == nil {
|
||||
var lastInsertID int64
|
||||
tx.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
fmt.Println(err)
|
||||
lastInsertID, err = result.LastInsertId()
|
||||
if len(primaryFields) == int(tx.RowsAffected) {
|
||||
startID := lastInsertID - tx.RowsAffected + 1
|
||||
for i, primaryField := range primaryFields {
|
||||
tx.AddError(primaryField.Set(startID + int64(i)))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user