Insert SQL builder

This commit is contained in:
Jinzhu 2018-03-08 22:57:11 +08:00
parent f0a88b68a9
commit 7bb9e6f00b
2 changed files with 93 additions and 71 deletions

View File

@ -0,0 +1,78 @@
package sqlbuilder
import (
"bytes"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/model"
)
// BuildInsertSQL build insert SQL
func BuildInsertSQL(tx *gorm.DB) (s *bytes.Buffer, args []interface{}, defaultFieldsSlice [][]*model.Field, err error) {
var (
dialect = tx.Dialect()
assignmentsChan = GetAssignmentFields(tx)
tableNameChan = GetTable(tx)
)
defer close(tableNameChan)
s = bytes.NewBufferString("INSERT INTO ")
s.WriteString(dialect.Quote(<-tableNameChan))
if assignments := <-assignmentsChan; len(assignments) > 0 {
columns := []string{}
defaultFields := []*model.Field{}
// Write columns (column1, column2, column3)
s.WriteString(" (")
// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
valueBuffer := bytes.NewBufferString("VALUES ")
for idx, fields := range assignments {
var primaryField *model.Field
if idx != 0 {
valueBuffer.WriteString(",")
}
valueBuffer.WriteString(" (")
for j, field := range fields {
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
primaryField = field
}
if idx == 0 {
columns = append(columns, field.Field.DBName)
if j != 0 {
s.WriteString(", ")
}
s.WriteString(dialect.Quote(field.Field.DBName))
}
if j != 0 {
valueBuffer.WriteString(", ")
}
valueBuffer.WriteString("?")
if field.IsBlank {
args = append(args, nil)
if field.HasDefaultValue {
defaultFields = append(defaultFields, field)
}
} else {
args = append(args, field.Value.Interface())
}
}
defaultFieldsSlice = append(defaultFieldsSlice, append([]*model.Field{primaryField}, defaultFields...))
valueBuffer.WriteString(")")
}
s.WriteString(") ")
_, err = valueBuffer.WriteTo(s)
} else {
s.WriteString(" DEFAULT VALUES")
}
return
}

View File

@ -1,13 +1,11 @@
package sqlite
import (
"bytes"
"database/sql"
"fmt"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/common/sqlbuilder"
"github.com/jinzhu/gorm/model"
)
// Dialect Sqlite3 Dialect for GORM
@ -22,80 +20,26 @@ func (dialect Dialect) Quote(name string) string {
// Insert insert
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
var (
args []interface{}
assignmentsChan = sqlbuilder.GetAssignmentFields(tx)
tableNameChan = sqlbuilder.GetTable(tx)
primaryFields []*model.Field
)
s := bytes.NewBufferString("INSERT INTO ")
s.WriteString(dialect.Quote(<-tableNameChan))
if assignments := <-assignmentsChan; len(assignments) > 0 {
columns := []string{}
// Write columns (column1, column2, column3)
s.WriteString(" (")
// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
valueBuffer := bytes.NewBufferString("VALUES ")
for idx, fields := range assignments {
var primaryField *model.Field
if idx != 0 {
valueBuffer.WriteString(",")
}
valueBuffer.WriteString(" (")
for j, field := range fields {
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
primaryField = field
}
if idx == 0 {
columns = append(columns, field.Field.DBName)
if j != 0 {
s.WriteString(", ")
}
s.WriteString(dialect.Quote(field.Field.DBName))
}
if j != 0 {
valueBuffer.WriteString(", ")
}
valueBuffer.WriteString("?")
if field.IsBlank {
args = append(args, nil)
} else {
args = append(args, field.Value.Interface())
}
}
primaryFields = append(primaryFields, primaryField)
valueBuffer.WriteString(")")
}
s.WriteString(") ")
_, err = valueBuffer.WriteTo(s)
} else {
s.WriteString(" DEFAULT VALUES")
}
result, err := dialect.DB.Exec(s.String(), args...)
s, args, defaultFieldsSlice, err := sqlbuilder.BuildInsertSQL(tx)
if err == nil {
var lastInsertID int64
tx.RowsAffected, _ = result.RowsAffected()
lastInsertID, err = result.LastInsertId()
if len(primaryFields) == int(tx.RowsAffected) {
startID := lastInsertID - tx.RowsAffected + 1
for i, primaryField := range primaryFields {
tx.AddError(primaryField.Set(startID + int64(i)))
result, err := dialect.DB.Exec(s.String(), args...)
if err == nil {
var lastInsertID int64
tx.RowsAffected, _ = result.RowsAffected()
lastInsertID, err = result.LastInsertId()
if len(defaultFieldsSlice) == int(tx.RowsAffected) {
startID := lastInsertID - tx.RowsAffected + 1
for i, defaultFields := range defaultFieldsSlice {
if len(defaultFields) > 0 {
tx.AddError(defaultFields[0].Set(startID + int64(i)))
}
}
}
}
}
return
}