Insert SQL builder
This commit is contained in:
parent
f0a88b68a9
commit
7bb9e6f00b
78
dialects/common/sqlbuilder/insert.go
Normal file
78
dialects/common/sqlbuilder/insert.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package sqlbuilder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildInsertSQL build insert SQL
|
||||||
|
func BuildInsertSQL(tx *gorm.DB) (s *bytes.Buffer, args []interface{}, defaultFieldsSlice [][]*model.Field, err error) {
|
||||||
|
var (
|
||||||
|
dialect = tx.Dialect()
|
||||||
|
assignmentsChan = GetAssignmentFields(tx)
|
||||||
|
tableNameChan = GetTable(tx)
|
||||||
|
)
|
||||||
|
defer close(tableNameChan)
|
||||||
|
|
||||||
|
s = bytes.NewBufferString("INSERT INTO ")
|
||||||
|
s.WriteString(dialect.Quote(<-tableNameChan))
|
||||||
|
|
||||||
|
if assignments := <-assignmentsChan; len(assignments) > 0 {
|
||||||
|
columns := []string{}
|
||||||
|
defaultFields := []*model.Field{}
|
||||||
|
|
||||||
|
// Write columns (column1, column2, column3)
|
||||||
|
s.WriteString(" (")
|
||||||
|
|
||||||
|
// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
|
||||||
|
valueBuffer := bytes.NewBufferString("VALUES ")
|
||||||
|
|
||||||
|
for idx, fields := range assignments {
|
||||||
|
var primaryField *model.Field
|
||||||
|
if idx != 0 {
|
||||||
|
valueBuffer.WriteString(",")
|
||||||
|
}
|
||||||
|
valueBuffer.WriteString(" (")
|
||||||
|
|
||||||
|
for j, field := range fields {
|
||||||
|
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
|
||||||
|
primaryField = field
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx == 0 {
|
||||||
|
columns = append(columns, field.Field.DBName)
|
||||||
|
if j != 0 {
|
||||||
|
s.WriteString(", ")
|
||||||
|
}
|
||||||
|
s.WriteString(dialect.Quote(field.Field.DBName))
|
||||||
|
}
|
||||||
|
|
||||||
|
if j != 0 {
|
||||||
|
valueBuffer.WriteString(", ")
|
||||||
|
}
|
||||||
|
valueBuffer.WriteString("?")
|
||||||
|
|
||||||
|
if field.IsBlank {
|
||||||
|
args = append(args, nil)
|
||||||
|
if field.HasDefaultValue {
|
||||||
|
defaultFields = append(defaultFields, field)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args = append(args, field.Value.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultFieldsSlice = append(defaultFieldsSlice, append([]*model.Field{primaryField}, defaultFields...))
|
||||||
|
valueBuffer.WriteString(")")
|
||||||
|
}
|
||||||
|
s.WriteString(") ")
|
||||||
|
|
||||||
|
_, err = valueBuffer.WriteTo(s)
|
||||||
|
} else {
|
||||||
|
s.WriteString(" DEFAULT VALUES")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -1,13 +1,11 @@
|
|||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/dialects/common/sqlbuilder"
|
"github.com/jinzhu/gorm/dialects/common/sqlbuilder"
|
||||||
"github.com/jinzhu/gorm/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dialect Sqlite3 Dialect for GORM
|
// Dialect Sqlite3 Dialect for GORM
|
||||||
@ -22,80 +20,26 @@ func (dialect Dialect) Quote(name string) string {
|
|||||||
|
|
||||||
// Insert insert
|
// Insert insert
|
||||||
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
|
||||||
var (
|
s, args, defaultFieldsSlice, err := sqlbuilder.BuildInsertSQL(tx)
|
||||||
args []interface{}
|
|
||||||
assignmentsChan = sqlbuilder.GetAssignmentFields(tx)
|
|
||||||
tableNameChan = sqlbuilder.GetTable(tx)
|
|
||||||
primaryFields []*model.Field
|
|
||||||
)
|
|
||||||
|
|
||||||
s := bytes.NewBufferString("INSERT INTO ")
|
|
||||||
s.WriteString(dialect.Quote(<-tableNameChan))
|
|
||||||
|
|
||||||
if assignments := <-assignmentsChan; len(assignments) > 0 {
|
|
||||||
columns := []string{}
|
|
||||||
|
|
||||||
// Write columns (column1, column2, column3)
|
|
||||||
s.WriteString(" (")
|
|
||||||
|
|
||||||
// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
|
|
||||||
valueBuffer := bytes.NewBufferString("VALUES ")
|
|
||||||
|
|
||||||
for idx, fields := range assignments {
|
|
||||||
var primaryField *model.Field
|
|
||||||
if idx != 0 {
|
|
||||||
valueBuffer.WriteString(",")
|
|
||||||
}
|
|
||||||
valueBuffer.WriteString(" (")
|
|
||||||
|
|
||||||
for j, field := range fields {
|
|
||||||
if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
|
|
||||||
primaryField = field
|
|
||||||
}
|
|
||||||
|
|
||||||
if idx == 0 {
|
|
||||||
columns = append(columns, field.Field.DBName)
|
|
||||||
if j != 0 {
|
|
||||||
s.WriteString(", ")
|
|
||||||
}
|
|
||||||
s.WriteString(dialect.Quote(field.Field.DBName))
|
|
||||||
}
|
|
||||||
|
|
||||||
if j != 0 {
|
|
||||||
valueBuffer.WriteString(", ")
|
|
||||||
}
|
|
||||||
valueBuffer.WriteString("?")
|
|
||||||
|
|
||||||
if field.IsBlank {
|
|
||||||
args = append(args, nil)
|
|
||||||
} else {
|
|
||||||
args = append(args, field.Value.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
primaryFields = append(primaryFields, primaryField)
|
|
||||||
valueBuffer.WriteString(")")
|
|
||||||
}
|
|
||||||
s.WriteString(") ")
|
|
||||||
|
|
||||||
_, err = valueBuffer.WriteTo(s)
|
|
||||||
} else {
|
|
||||||
s.WriteString(" DEFAULT VALUES")
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := dialect.DB.Exec(s.String(), args...)
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var lastInsertID int64
|
result, err := dialect.DB.Exec(s.String(), args...)
|
||||||
tx.RowsAffected, _ = result.RowsAffected()
|
|
||||||
lastInsertID, err = result.LastInsertId()
|
if err == nil {
|
||||||
if len(primaryFields) == int(tx.RowsAffected) {
|
var lastInsertID int64
|
||||||
startID := lastInsertID - tx.RowsAffected + 1
|
tx.RowsAffected, _ = result.RowsAffected()
|
||||||
for i, primaryField := range primaryFields {
|
lastInsertID, err = result.LastInsertId()
|
||||||
tx.AddError(primaryField.Set(startID + int64(i)))
|
if len(defaultFieldsSlice) == int(tx.RowsAffected) {
|
||||||
|
startID := lastInsertID - tx.RowsAffected + 1
|
||||||
|
for i, defaultFields := range defaultFieldsSlice {
|
||||||
|
if len(defaultFields) > 0 {
|
||||||
|
tx.AddError(defaultFields[0].Set(startID + int64(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user