diff --git a/dialects/common/sqlbuilder/insert.go b/dialects/common/sqlbuilder/insert.go new file mode 100644 index 00000000..6b70e44b --- /dev/null +++ b/dialects/common/sqlbuilder/insert.go @@ -0,0 +1,78 @@ +package sqlbuilder + +import ( + "bytes" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/model" +) + +// BuildInsertSQL build insert SQL +func BuildInsertSQL(tx *gorm.DB) (s *bytes.Buffer, args []interface{}, defaultFieldsSlice [][]*model.Field, err error) { + var ( + dialect = tx.Dialect() + assignmentsChan = GetAssignmentFields(tx) + tableNameChan = GetTable(tx) + ) + defer close(tableNameChan) + + s = bytes.NewBufferString("INSERT INTO ") + s.WriteString(dialect.Quote(<-tableNameChan)) + + if assignments := <-assignmentsChan; len(assignments) > 0 { + columns := []string{} + defaultFields := []*model.Field{} + + // Write columns (column1, column2, column3) + s.WriteString(" (") + + // Write values (v1, v2, v3), (v2-1, v2-2, v2-3) + valueBuffer := bytes.NewBufferString("VALUES ") + + for idx, fields := range assignments { + var primaryField *model.Field + if idx != 0 { + valueBuffer.WriteString(",") + } + valueBuffer.WriteString(" (") + + for j, field := range fields { + if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" { + primaryField = field + } + + if idx == 0 { + columns = append(columns, field.Field.DBName) + if j != 0 { + s.WriteString(", ") + } + s.WriteString(dialect.Quote(field.Field.DBName)) + } + + if j != 0 { + valueBuffer.WriteString(", ") + } + valueBuffer.WriteString("?") + + if field.IsBlank { + args = append(args, nil) + if field.HasDefaultValue { + defaultFields = append(defaultFields, field) + } + } else { + args = append(args, field.Value.Interface()) + } + } + + defaultFieldsSlice = append(defaultFieldsSlice, append([]*model.Field{primaryField}, defaultFields...)) + valueBuffer.WriteString(")") + } + s.WriteString(") ") + + _, err = valueBuffer.WriteTo(s) + } else { + s.WriteString(" DEFAULT VALUES") + } + + return +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index dd3d2149..8214cb0e 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,13 +1,11 @@ package sqlite import ( - "bytes" "database/sql" "fmt" "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/dialects/common/sqlbuilder" - "github.com/jinzhu/gorm/model" ) // Dialect Sqlite3 Dialect for GORM @@ -22,80 +20,26 @@ func (dialect Dialect) Quote(name string) string { // Insert insert func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { - var ( - args []interface{} - assignmentsChan = sqlbuilder.GetAssignmentFields(tx) - tableNameChan = sqlbuilder.GetTable(tx) - primaryFields []*model.Field - ) - - s := bytes.NewBufferString("INSERT INTO ") - s.WriteString(dialect.Quote(<-tableNameChan)) - - if assignments := <-assignmentsChan; len(assignments) > 0 { - columns := []string{} - - // Write columns (column1, column2, column3) - s.WriteString(" (") - - // Write values (v1, v2, v3), (v2-1, v2-2, v2-3) - valueBuffer := bytes.NewBufferString("VALUES ") - - for idx, fields := range assignments { - var primaryField *model.Field - if idx != 0 { - valueBuffer.WriteString(",") - } - valueBuffer.WriteString(" (") - - for j, field := range fields { - if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" { - primaryField = field - } - - if idx == 0 { - columns = append(columns, field.Field.DBName) - if j != 0 { - s.WriteString(", ") - } - s.WriteString(dialect.Quote(field.Field.DBName)) - } - - if j != 0 { - valueBuffer.WriteString(", ") - } - valueBuffer.WriteString("?") - - if field.IsBlank { - args = append(args, nil) - } else { - args = append(args, field.Value.Interface()) - } - } - - primaryFields = append(primaryFields, primaryField) - valueBuffer.WriteString(")") - } - s.WriteString(") ") - - _, err = valueBuffer.WriteTo(s) - } else { - s.WriteString(" DEFAULT VALUES") - } - - result, err := dialect.DB.Exec(s.String(), args...) + s, args, defaultFieldsSlice, err := sqlbuilder.BuildInsertSQL(tx) if err == nil { - var lastInsertID int64 - tx.RowsAffected, _ = result.RowsAffected() - lastInsertID, err = result.LastInsertId() - if len(primaryFields) == int(tx.RowsAffected) { - startID := lastInsertID - tx.RowsAffected + 1 - for i, primaryField := range primaryFields { - tx.AddError(primaryField.Set(startID + int64(i))) + result, err := dialect.DB.Exec(s.String(), args...) + + if err == nil { + var lastInsertID int64 + tx.RowsAffected, _ = result.RowsAffected() + lastInsertID, err = result.LastInsertId() + if len(defaultFieldsSlice) == int(tx.RowsAffected) { + startID := lastInsertID - tx.RowsAffected + 1 + for i, defaultFields := range defaultFieldsSlice { + if len(defaultFields) > 0 { + tx.AddError(defaultFields[0].Set(startID + int64(i))) + } + } } } } + return }