Batch insert through a new callback
This commit is contained in:
		
							parent
							
								
									9acaa33324
								
							
						
					
					
						commit
						1c01606add
					
				
							
								
								
									
										36
									
								
								callback.go
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								callback.go
									
									
									
									
									
								
							@ -15,12 +15,13 @@ var DefaultCallback = &Callback{}
 | 
			
		||||
//   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
 | 
			
		||||
//   Field `processors` contains all callback processors, will be used to generate above callbacks in order
 | 
			
		||||
type Callback struct {
 | 
			
		||||
	creates    []*func(scope *Scope)
 | 
			
		||||
	updates    []*func(scope *Scope)
 | 
			
		||||
	deletes    []*func(scope *Scope)
 | 
			
		||||
	queries    []*func(scope *Scope)
 | 
			
		||||
	rowQueries []*func(scope *Scope)
 | 
			
		||||
	processors []*CallbackProcessor
 | 
			
		||||
	creates      []*func(scope *Scope)
 | 
			
		||||
	createsBatch []*func(scope *Scope)
 | 
			
		||||
	updates      []*func(scope *Scope)
 | 
			
		||||
	deletes      []*func(scope *Scope)
 | 
			
		||||
	queries      []*func(scope *Scope)
 | 
			
		||||
	rowQueries   []*func(scope *Scope)
 | 
			
		||||
	processors   []*CallbackProcessor
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CallbackProcessor contains callback informations
 | 
			
		||||
@ -37,12 +38,13 @@ type CallbackProcessor struct {
 | 
			
		||||
 | 
			
		||||
func (c *Callback) clone() *Callback {
 | 
			
		||||
	return &Callback{
 | 
			
		||||
		creates:    c.creates,
 | 
			
		||||
		updates:    c.updates,
 | 
			
		||||
		deletes:    c.deletes,
 | 
			
		||||
		queries:    c.queries,
 | 
			
		||||
		rowQueries: c.rowQueries,
 | 
			
		||||
		processors: c.processors,
 | 
			
		||||
		creates:      c.creates,
 | 
			
		||||
		createsBatch: c.createsBatch,
 | 
			
		||||
		updates:      c.updates,
 | 
			
		||||
		deletes:      c.deletes,
 | 
			
		||||
		queries:      c.queries,
 | 
			
		||||
		rowQueries:   c.rowQueries,
 | 
			
		||||
		processors:   c.processors,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -58,6 +60,11 @@ func (c *Callback) Create() *CallbackProcessor {
 | 
			
		||||
	return &CallbackProcessor{kind: "create", parent: c}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateBatch could be used to register callbacks for creating objects in bulk operations
 | 
			
		||||
func (c *Callback) CreateBatch() *CallbackProcessor {
 | 
			
		||||
	return &CallbackProcessor{kind: "creates_batch", parent: c}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
 | 
			
		||||
func (c *Callback) Update() *CallbackProcessor {
 | 
			
		||||
	return &CallbackProcessor{kind: "update", parent: c}
 | 
			
		||||
@ -217,13 +224,15 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
 | 
			
		||||
 | 
			
		||||
// reorder all registered processors, and reset CRUD callbacks
 | 
			
		||||
func (c *Callback) reorder() {
 | 
			
		||||
	var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
 | 
			
		||||
	var creates, createsBatch, updates, deletes, queries, rowQueries []*CallbackProcessor
 | 
			
		||||
 | 
			
		||||
	for _, processor := range c.processors {
 | 
			
		||||
		if processor.name != "" {
 | 
			
		||||
			switch processor.kind {
 | 
			
		||||
			case "create":
 | 
			
		||||
				creates = append(creates, processor)
 | 
			
		||||
			case "creates_batch":
 | 
			
		||||
				createsBatch = append(createsBatch, processor)
 | 
			
		||||
			case "update":
 | 
			
		||||
				updates = append(updates, processor)
 | 
			
		||||
			case "delete":
 | 
			
		||||
@ -237,6 +246,7 @@ func (c *Callback) reorder() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.creates = sortProcessors(creates)
 | 
			
		||||
	c.createsBatch = sortProcessors(createsBatch)
 | 
			
		||||
	c.updates = sortProcessors(updates)
 | 
			
		||||
	c.deletes = sortProcessors(deletes)
 | 
			
		||||
	c.queries = sortProcessors(queries)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										144
									
								
								callback_create_batch.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								callback_create_batch.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,144 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Define callbacks for batch creating
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.CreateBatch().Register("gorm:begin_transaction", beginTransactionCallback)
 | 
			
		||||
	DefaultCallback.CreateBatch().Register("gorm:create_batch", createBatchCallback)
 | 
			
		||||
	DefaultCallback.CreateBatch().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createCallback the callback used to insert data into database
 | 
			
		||||
func createBatchCallback(scope *Scope) {
 | 
			
		||||
	value := scope.IndirectValue()
 | 
			
		||||
 | 
			
		||||
	if value.Kind() != reflect.Slice {
 | 
			
		||||
		scope.Err(fmt.Errorf("createBatchCallback cannot be called for non-slice value, %+v given", value.Interface()))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		columns                      []string                        // one-dimensional array of strings containing columns
 | 
			
		||||
		blankColumnsWithDefaultValue []string                        // one-dimensional array of strings containing columns
 | 
			
		||||
		placeholders                 = make([][]string, value.Len()) // two-dimensional array of strings containing value placeholders
 | 
			
		||||
		structFields                 = scope.GetModelStruct().StructFields
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// Filling up the columns
 | 
			
		||||
	for _, field := range fields(scope) {
 | 
			
		||||
		// We don't treat non-normal fields on batch operations (relationships, etc)
 | 
			
		||||
		if !field.IsNormal {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if field.IsBlank && field.HasDefaultValue {
 | 
			
		||||
			blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
 | 
			
		||||
			scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
 | 
			
		||||
		} else if !field.IsPrimaryKey || !field.IsBlank {
 | 
			
		||||
			columns = append(columns, scope.Quote(field.DBName))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Filling up the placeholders
 | 
			
		||||
	for elementIndex := 0; elementIndex < value.Len(); elementIndex++ {
 | 
			
		||||
		valuePlaceholders := []string{}
 | 
			
		||||
 | 
			
		||||
		for _, structField := range structFields {
 | 
			
		||||
			// When inserting, the primary key is usually auto-increment
 | 
			
		||||
			if !structField.IsPrimaryKey {
 | 
			
		||||
				fieldValue := reflect.Indirect(value.Index(elementIndex)).FieldByName(structField.Names[0]).Interface()
 | 
			
		||||
				valuePlaceholders = append(valuePlaceholders, scope.AddToVars(fieldValue))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		placeholders[elementIndex] = valuePlaceholders
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		returningColumn = "*"
 | 
			
		||||
		quotedTableName = scope.QuotedTableName()
 | 
			
		||||
		primaryField    = scope.PrimaryField()
 | 
			
		||||
		extraOption     string
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if str, ok := scope.Get("gorm:insert_option"); ok {
 | 
			
		||||
		extraOption = fmt.Sprint(str)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if primaryField != nil {
 | 
			
		||||
		returningColumn = scope.Quote(primaryField.DBName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
 | 
			
		||||
 | 
			
		||||
	scope.Raw(fmt.Sprintf(
 | 
			
		||||
		"INSERT INTO %v (%v) VALUES %v%v%v",
 | 
			
		||||
		scope.QuotedTableName(),
 | 
			
		||||
		strings.Join(columns, ","),
 | 
			
		||||
		strings.Join(joinValuePlaceholders(placeholders), ","),
 | 
			
		||||
		addExtraSpaceIfExist(extraOption),
 | 
			
		||||
		addExtraSpaceIfExist(lastInsertIDReturningSuffix),
 | 
			
		||||
	))
 | 
			
		||||
 | 
			
		||||
	// Executing the query
 | 
			
		||||
	// TODO(drgomesp): Do we really need this check?
 | 
			
		||||
	if lastInsertIDReturningSuffix == "" || primaryField == nil {
 | 
			
		||||
		if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			// set rows affected count
 | 
			
		||||
			scope.db.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
 | 
			
		||||
			if firstInsertedID, err := result.LastInsertId(); scope.Err(err) == nil {
 | 
			
		||||
				fillPrimaryKeys(structFields, firstInsertedID, &value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fillPrimaryKeys(structFields []*StructField, firstInsertedID int64, values *reflect.Value) {
 | 
			
		||||
	for _, structField := range structFields {
 | 
			
		||||
		for i := 0; i < values.Len(); i++ {
 | 
			
		||||
			field := reflect.Indirect(values.Index(i)).FieldByName(structField.Names[0])
 | 
			
		||||
 | 
			
		||||
			if field.IsValid() && field.CanSet() {
 | 
			
		||||
				if field.Kind() == reflect.Int64 || field.Kind() == reflect.Int32 || field.Kind() == reflect.Int8 || field.Kind() == reflect.Int {
 | 
			
		||||
					id := firstInsertedID + int64(i)
 | 
			
		||||
 | 
			
		||||
					if !field.OverflowInt(id) {
 | 
			
		||||
						field.SetInt(id)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func joinValuePlaceholders(placeholders [][]string) []string {
 | 
			
		||||
	var valuePlaceholders []string
 | 
			
		||||
 | 
			
		||||
	for _, placeholder := range placeholders {
 | 
			
		||||
		valuePlaceholders = append(valuePlaceholders, fmt.Sprintf("(%s)", strings.Join(placeholder, ",")))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return valuePlaceholders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fields(scope *Scope) []*Field {
 | 
			
		||||
	var (
 | 
			
		||||
		indirectScopeValue = scope.IndirectValue()
 | 
			
		||||
		structFields       = scope.GetModelStruct().StructFields
 | 
			
		||||
		fields             = make([]*Field, len(structFields))
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for i, structField := range structFields {
 | 
			
		||||
		fieldValue := reflect.Indirect(indirectScopeValue.Index(0)).FieldByName(structField.Names[0])
 | 
			
		||||
		fields[i] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fields
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										6
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.go
									
									
									
									
									
								
							@ -407,6 +407,12 @@ func (s *DB) Create(value interface{}) *DB {
 | 
			
		||||
	return scope.callCallbacks(s.parent.callbacks.creates).db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateBatch inserts multiple values into the database via a bulk operation
 | 
			
		||||
func (s *DB) CreateBatch(value interface{}) *DB {
 | 
			
		||||
	scope := s.clone().NewScope(value)
 | 
			
		||||
	return scope.callCallbacks(s.parent.callbacks.createsBatch).db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
 | 
			
		||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
 | 
			
		||||
	return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user