From 1c01606adde886808625307bcdad568a7333084a Mon Sep 17 00:00:00 2001 From: drgomesp Date: Wed, 10 May 2017 16:37:02 +0400 Subject: [PATCH] Batch insert through a new callback --- callback.go | 36 ++++++---- callback_create_batch.go | 144 +++++++++++++++++++++++++++++++++++++++ main.go | 6 ++ 3 files changed, 173 insertions(+), 13 deletions(-) create mode 100644 callback_create_batch.go diff --git a/callback.go b/callback.go index 17f75451..05580e81 100644 --- a/callback.go +++ b/callback.go @@ -15,12 +15,13 @@ var DefaultCallback = &Callback{} // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor + creates []*func(scope *Scope) + createsBatch []*func(scope *Scope) + updates []*func(scope *Scope) + deletes []*func(scope *Scope) + queries []*func(scope *Scope) + rowQueries []*func(scope *Scope) + processors []*CallbackProcessor } // CallbackProcessor contains callback informations @@ -37,12 +38,13 @@ type CallbackProcessor struct { func (c *Callback) clone() *Callback { return &Callback{ - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, + creates: c.creates, + createsBatch: c.createsBatch, + updates: c.updates, + deletes: c.deletes, + queries: c.queries, + rowQueries: c.rowQueries, + processors: c.processors, } } @@ -58,6 +60,11 @@ func (c *Callback) Create() *CallbackProcessor { return &CallbackProcessor{kind: "create", parent: c} } +// CreateBatch could be used to register callbacks for creating objects in bulk operations +func (c *Callback) CreateBatch() *CallbackProcessor { + return &CallbackProcessor{kind: "creates_batch", parent: c} +} + // Update could be used to register callbacks for updating object, refer `Create` for usage func (c *Callback) Update() *CallbackProcessor { return &CallbackProcessor{kind: "update", parent: c} @@ -217,13 +224,15 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { // reorder all registered processors, and reset CRUD callbacks func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor + var creates, createsBatch, updates, deletes, queries, rowQueries []*CallbackProcessor for _, processor := range c.processors { if processor.name != "" { switch processor.kind { case "create": creates = append(creates, processor) + case "creates_batch": + createsBatch = append(createsBatch, processor) case "update": updates = append(updates, processor) case "delete": @@ -237,6 +246,7 @@ func (c *Callback) reorder() { } c.creates = sortProcessors(creates) + c.createsBatch = sortProcessors(createsBatch) c.updates = sortProcessors(updates) c.deletes = sortProcessors(deletes) c.queries = sortProcessors(queries) diff --git a/callback_create_batch.go b/callback_create_batch.go new file mode 100644 index 00000000..2a452144 --- /dev/null +++ b/callback_create_batch.go @@ -0,0 +1,144 @@ +package gorm + +import ( + "fmt" + "reflect" + + "strings" +) + +// Define callbacks for batch creating +func init() { + DefaultCallback.CreateBatch().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.CreateBatch().Register("gorm:create_batch", createBatchCallback) + DefaultCallback.CreateBatch().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) +} + +// createCallback the callback used to insert data into database +func createBatchCallback(scope *Scope) { + value := scope.IndirectValue() + + if value.Kind() != reflect.Slice { + scope.Err(fmt.Errorf("createBatchCallback cannot be called for non-slice value, %+v given", value.Interface())) + return + } + + var ( + columns []string // one-dimensional array of strings containing columns + blankColumnsWithDefaultValue []string // one-dimensional array of strings containing columns + placeholders = make([][]string, value.Len()) // two-dimensional array of strings containing value placeholders + structFields = scope.GetModelStruct().StructFields + ) + + // Filling up the columns + for _, field := range fields(scope) { + // We don't treat non-normal fields on batch operations (relationships, etc) + if !field.IsNormal { + continue + } + + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else if !field.IsPrimaryKey || !field.IsBlank { + columns = append(columns, scope.Quote(field.DBName)) + } + } + + // Filling up the placeholders + for elementIndex := 0; elementIndex < value.Len(); elementIndex++ { + valuePlaceholders := []string{} + + for _, structField := range structFields { + // When inserting, the primary key is usually auto-increment + if !structField.IsPrimaryKey { + fieldValue := reflect.Indirect(value.Index(elementIndex)).FieldByName(structField.Names[0]).Interface() + valuePlaceholders = append(valuePlaceholders, scope.AddToVars(fieldValue)) + } + } + + placeholders[elementIndex] = valuePlaceholders + } + + var ( + returningColumn = "*" + quotedTableName = scope.QuotedTableName() + primaryField = scope.PrimaryField() + extraOption string + ) + + if str, ok := scope.Get("gorm:insert_option"); ok { + extraOption = fmt.Sprint(str) + } + + if primaryField != nil { + returningColumn = scope.Quote(primaryField.DBName) + } + + lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + + scope.Raw(fmt.Sprintf( + "INSERT INTO %v (%v) VALUES %v%v%v", + scope.QuotedTableName(), + strings.Join(columns, ","), + strings.Join(joinValuePlaceholders(placeholders), ","), + addExtraSpaceIfExist(extraOption), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), + )) + + // Executing the query + // TODO(drgomesp): Do we really need this check? + if lastInsertIDReturningSuffix == "" || primaryField == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + if firstInsertedID, err := result.LastInsertId(); scope.Err(err) == nil { + fillPrimaryKeys(structFields, firstInsertedID, &value) + } + } + } +} + +func fillPrimaryKeys(structFields []*StructField, firstInsertedID int64, values *reflect.Value) { + for _, structField := range structFields { + for i := 0; i < values.Len(); i++ { + field := reflect.Indirect(values.Index(i)).FieldByName(structField.Names[0]) + + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Int64 || field.Kind() == reflect.Int32 || field.Kind() == reflect.Int8 || field.Kind() == reflect.Int { + id := firstInsertedID + int64(i) + + if !field.OverflowInt(id) { + field.SetInt(id) + } + } + } + } + } +} + +func joinValuePlaceholders(placeholders [][]string) []string { + var valuePlaceholders []string + + for _, placeholder := range placeholders { + valuePlaceholders = append(valuePlaceholders, fmt.Sprintf("(%s)", strings.Join(placeholder, ","))) + } + + return valuePlaceholders +} + +func fields(scope *Scope) []*Field { + var ( + indirectScopeValue = scope.IndirectValue() + structFields = scope.GetModelStruct().StructFields + fields = make([]*Field, len(structFields)) + ) + + for i, structField := range structFields { + fieldValue := reflect.Indirect(indirectScopeValue.Index(0)).FieldByName(structField.Names[0]) + fields[i] = &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)} + } + + return fields +} diff --git a/main.go b/main.go index 97cff7db..7a754a3a 100644 --- a/main.go +++ b/main.go @@ -407,6 +407,12 @@ func (s *DB) Create(value interface{}) *DB { return scope.callCallbacks(s.parent.callbacks.creates).db } +// CreateBatch inserts multiple values into the database via a bulk operation +func (s *DB) CreateBatch(value interface{}) *DB { + scope := s.clone().NewScope(value) + return scope.callCallbacks(s.parent.callbacks.createsBatch).db +} + // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db