Insert SQL builder
This commit is contained in:
		
							parent
							
								
									f0a88b68a9
								
							
						
					
					
						commit
						7bb9e6f00b
					
				
							
								
								
									
										78
									
								
								dialects/common/sqlbuilder/insert.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								dialects/common/sqlbuilder/insert.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
package sqlbuilder
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// BuildInsertSQL build insert SQL
 | 
			
		||||
func BuildInsertSQL(tx *gorm.DB) (s *bytes.Buffer, args []interface{}, defaultFieldsSlice [][]*model.Field, err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		dialect         = tx.Dialect()
 | 
			
		||||
		assignmentsChan = GetAssignmentFields(tx)
 | 
			
		||||
		tableNameChan   = GetTable(tx)
 | 
			
		||||
	)
 | 
			
		||||
	defer close(tableNameChan)
 | 
			
		||||
 | 
			
		||||
	s = bytes.NewBufferString("INSERT INTO ")
 | 
			
		||||
	s.WriteString(dialect.Quote(<-tableNameChan))
 | 
			
		||||
 | 
			
		||||
	if assignments := <-assignmentsChan; len(assignments) > 0 {
 | 
			
		||||
		columns := []string{}
 | 
			
		||||
		defaultFields := []*model.Field{}
 | 
			
		||||
 | 
			
		||||
		// Write columns (column1, column2, column3)
 | 
			
		||||
		s.WriteString(" (")
 | 
			
		||||
 | 
			
		||||
		// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
 | 
			
		||||
		valueBuffer := bytes.NewBufferString("VALUES ")
 | 
			
		||||
 | 
			
		||||
		for idx, fields := range assignments {
 | 
			
		||||
			var primaryField *model.Field
 | 
			
		||||
			if idx != 0 {
 | 
			
		||||
				valueBuffer.WriteString(",")
 | 
			
		||||
			}
 | 
			
		||||
			valueBuffer.WriteString(" (")
 | 
			
		||||
 | 
			
		||||
			for j, field := range fields {
 | 
			
		||||
				if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
 | 
			
		||||
					primaryField = field
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if idx == 0 {
 | 
			
		||||
					columns = append(columns, field.Field.DBName)
 | 
			
		||||
					if j != 0 {
 | 
			
		||||
						s.WriteString(", ")
 | 
			
		||||
					}
 | 
			
		||||
					s.WriteString(dialect.Quote(field.Field.DBName))
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if j != 0 {
 | 
			
		||||
					valueBuffer.WriteString(", ")
 | 
			
		||||
				}
 | 
			
		||||
				valueBuffer.WriteString("?")
 | 
			
		||||
 | 
			
		||||
				if field.IsBlank {
 | 
			
		||||
					args = append(args, nil)
 | 
			
		||||
					if field.HasDefaultValue {
 | 
			
		||||
						defaultFields = append(defaultFields, field)
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					args = append(args, field.Value.Interface())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			defaultFieldsSlice = append(defaultFieldsSlice, append([]*model.Field{primaryField}, defaultFields...))
 | 
			
		||||
			valueBuffer.WriteString(")")
 | 
			
		||||
		}
 | 
			
		||||
		s.WriteString(") ")
 | 
			
		||||
 | 
			
		||||
		_, err = valueBuffer.WriteTo(s)
 | 
			
		||||
	} else {
 | 
			
		||||
		s.WriteString(" DEFAULT VALUES")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -1,13 +1,11 @@
 | 
			
		||||
package sqlite
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/dialects/common/sqlbuilder"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Dialect Sqlite3 Dialect for GORM
 | 
			
		||||
@ -22,80 +20,26 @@ func (dialect Dialect) Quote(name string) string {
 | 
			
		||||
 | 
			
		||||
// Insert insert
 | 
			
		||||
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
 | 
			
		||||
	var (
 | 
			
		||||
		args            []interface{}
 | 
			
		||||
		assignmentsChan = sqlbuilder.GetAssignmentFields(tx)
 | 
			
		||||
		tableNameChan   = sqlbuilder.GetTable(tx)
 | 
			
		||||
		primaryFields   []*model.Field
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	s := bytes.NewBufferString("INSERT INTO ")
 | 
			
		||||
	s.WriteString(dialect.Quote(<-tableNameChan))
 | 
			
		||||
 | 
			
		||||
	if assignments := <-assignmentsChan; len(assignments) > 0 {
 | 
			
		||||
		columns := []string{}
 | 
			
		||||
 | 
			
		||||
		// Write columns (column1, column2, column3)
 | 
			
		||||
		s.WriteString(" (")
 | 
			
		||||
 | 
			
		||||
		// Write values (v1, v2, v3), (v2-1, v2-2, v2-3)
 | 
			
		||||
		valueBuffer := bytes.NewBufferString("VALUES ")
 | 
			
		||||
 | 
			
		||||
		for idx, fields := range assignments {
 | 
			
		||||
			var primaryField *model.Field
 | 
			
		||||
			if idx != 0 {
 | 
			
		||||
				valueBuffer.WriteString(",")
 | 
			
		||||
			}
 | 
			
		||||
			valueBuffer.WriteString(" (")
 | 
			
		||||
 | 
			
		||||
			for j, field := range fields {
 | 
			
		||||
				if field.Field.IsPrimaryKey && primaryField == nil || field.Field.DBName == "id" {
 | 
			
		||||
					primaryField = field
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if idx == 0 {
 | 
			
		||||
					columns = append(columns, field.Field.DBName)
 | 
			
		||||
					if j != 0 {
 | 
			
		||||
						s.WriteString(", ")
 | 
			
		||||
					}
 | 
			
		||||
					s.WriteString(dialect.Quote(field.Field.DBName))
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if j != 0 {
 | 
			
		||||
					valueBuffer.WriteString(", ")
 | 
			
		||||
				}
 | 
			
		||||
				valueBuffer.WriteString("?")
 | 
			
		||||
 | 
			
		||||
				if field.IsBlank {
 | 
			
		||||
					args = append(args, nil)
 | 
			
		||||
				} else {
 | 
			
		||||
					args = append(args, field.Value.Interface())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			primaryFields = append(primaryFields, primaryField)
 | 
			
		||||
			valueBuffer.WriteString(")")
 | 
			
		||||
		}
 | 
			
		||||
		s.WriteString(") ")
 | 
			
		||||
 | 
			
		||||
		_, err = valueBuffer.WriteTo(s)
 | 
			
		||||
	} else {
 | 
			
		||||
		s.WriteString(" DEFAULT VALUES")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result, err := dialect.DB.Exec(s.String(), args...)
 | 
			
		||||
	s, args, defaultFieldsSlice, err := sqlbuilder.BuildInsertSQL(tx)
 | 
			
		||||
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		var lastInsertID int64
 | 
			
		||||
		tx.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
		lastInsertID, err = result.LastInsertId()
 | 
			
		||||
		if len(primaryFields) == int(tx.RowsAffected) {
 | 
			
		||||
			startID := lastInsertID - tx.RowsAffected + 1
 | 
			
		||||
			for i, primaryField := range primaryFields {
 | 
			
		||||
				tx.AddError(primaryField.Set(startID + int64(i)))
 | 
			
		||||
		result, err := dialect.DB.Exec(s.String(), args...)
 | 
			
		||||
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			var lastInsertID int64
 | 
			
		||||
			tx.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
			lastInsertID, err = result.LastInsertId()
 | 
			
		||||
			if len(defaultFieldsSlice) == int(tx.RowsAffected) {
 | 
			
		||||
				startID := lastInsertID - tx.RowsAffected + 1
 | 
			
		||||
				for i, defaultFields := range defaultFieldsSlice {
 | 
			
		||||
					if len(defaultFields) > 0 {
 | 
			
		||||
						tx.AddError(defaultFields[0].Set(startID + int64(i)))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user