diff --git a/callbacks/callback.go b/callbacks/callback.go index a4382147..6c3e2c9c 100644 --- a/callbacks/callback.go +++ b/callbacks/callback.go @@ -1,9 +1,10 @@ package gorm -import "log" +import ( + "log" -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} + "github.com/jinzhu/gorm" +) // Callback is a struct that contains all CRUD callbacks // Field `creates` contains callbacks will be call when creating object @@ -13,23 +14,23 @@ var DefaultCallback = &Callback{} // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) + creates []*func(*gorm.DB) + updates []*func(*gorm.DB) + deletes []*func(*gorm.DB) + queries []*func(*gorm.DB) + rowQueries []*func(*gorm.DB) processors []*CallbackProcessor } // CallbackProcessor contains callback informations type CallbackProcessor struct { - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler + name string // current callback's name + before string // register current callback before a callback + after string // register current callback after a callback + replace bool // replace callbacks with same name + remove bool // delete callbacks with same name + kind string // callback type: create, update, delete, query, row_query + processor *func(*gorm.DB) // callback handler parent *Callback } @@ -45,7 +46,7 @@ func (c *Callback) clone() *Callback { } // Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { +// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*gorm.DB) { // // business logic // ... // @@ -90,7 +91,7 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { } // Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { +func (cp *CallbackProcessor) Register(callbackName string, callback func(*gorm.DB)) { if cp.kind == "row_query" { if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) @@ -107,7 +108,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + log.Printf("[info] removing callback `%v` from %v\n", callbackName, utils.FileWithLineNum()) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -115,12 +116,12 @@ func (cp *CallbackProcessor) Remove(callbackName string) { } // Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { +// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*gorm.DB) { // scope.SetColumn("Created", now) // scope.SetColumn("Updated", now) // }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) +func (cp *CallbackProcessor) Replace(callbackName string, callback func(*gorm.DB)) { + log.Printf("[info] replacing callback `%v` from %v\n", callbackName, utils.FileWithLineNum()) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -130,7 +131,7 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S // Get registered callback // db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { +func (cp *CallbackProcessor) Get(callbackName string) (callback func(*gorm.DB)) { for _, p := range cp.parent.processors { if p.name == callbackName && p.kind == cp.kind && !cp.remove { return *p.processor @@ -150,7 +151,7 @@ func getRIndex(strs []string, str string) int { } // sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { +func sortProcessors(cps []*CallbackProcessor) []*func(*gorm.DB) { var ( allNames, sortedNames []string sortCallbackProcessor func(c *CallbackProcessor) @@ -159,7 +160,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, utils.FileWithLineNum()) } allNames = append(allNames, cp.name) } @@ -203,7 +204,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { sortCallbackProcessor(cp) } - var sortedFuncs []*func(scope *Scope) + var sortedFuncs []*func(*gorm.DB) for _, name := range sortedNames { if index := getRIndex(allNames, name); !cps[index].remove { sortedFuncs = append(sortedFuncs, cps[index].processor) diff --git a/callbacks/callback_create.go b/callbacks/callback_create.go deleted file mode 100644 index e7fe6f86..00000000 --- a/callbacks/callback_create.go +++ /dev/null @@ -1,164 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := NowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v %v%v%v", - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", - scope.QuotedTableName(), - strings.Join(columns, ","), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - } - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/callbacks/callback_delete.go b/callbacks/callback_delete.go deleted file mode 100644 index 73d90880..00000000 --- a/callbacks/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/callbacks/callback_query.go b/callbacks/callback_query.go deleted file mode 100644 index 538693d3..00000000 --- a/callbacks/callback_query.go +++ /dev/null @@ -1,479 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettings["PRELOAD"]; ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - scope.New(elem.Addr().Interface()). - InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { - continue - } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } - - } -} diff --git a/callbacks/callback_row_query.go b/callbacks/callback_row_query.go deleted file mode 100644 index c2ff4a08..00000000 --- a/callbacks/callback_row_query.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import "database/sql" - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) - } - } -} diff --git a/callbacks/callback_save.go b/callbacks/callback_save.go deleted file mode 100644 index ef267141..00000000 --- a/callbacks/callback_save.go +++ /dev/null @@ -1,170 +0,0 @@ -package gorm - -import ( - "reflect" - "strings" -) - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { - checkTruth := func(value interface{}) bool { - if v, ok := value.(bool); ok && !v { - return false - } - - if v, ok := value.(string); ok { - v = strings.ToLower(v) - if v == "false" || v != "skip" { - return false - } - } - - return true - } - - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if r = field.Relationship; r != nil { - autoUpdate, autoCreate, saveReference = true, true, true - - if value, ok := scope.Get("gorm:save_associations"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - } - - if value, ok := scope.Get("gorm:association_autoupdate"); ok { - autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { - autoUpdate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_autocreate"); ok { - autoCreate = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { - autoCreate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_save_reference"); ok { - saveReference = checkTruth(value) - } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { - saveReference = checkTruth(value) - } - } - } - - return -} - -func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/callbacks/callback_update.go b/callbacks/callback_update.go deleted file mode 100644 index 373bd726..00000000 --- a/callbacks/callback_update.go +++ /dev/null @@ -1,119 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. - updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/model/errors.go b/dialects/common/sqlbuilder/errors.go similarity index 50% rename from model/errors.go rename to dialects/common/sqlbuilder/errors.go index 425f7e75..329ceb86 100644 --- a/model/errors.go +++ b/dialects/common/sqlbuilder/errors.go @@ -1,10 +1,8 @@ -package model +package sqlbuilder import "errors" var ( // ErrInvalidTable invalid table name ErrInvalidTable = errors.New("invalid table name") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") ) diff --git a/model/field.go b/dialects/common/sqlbuilder/fields.go similarity index 52% rename from model/field.go rename to dialects/common/sqlbuilder/fields.go index e3d3a0d2..35f47e8f 100644 --- a/model/field.go +++ b/dialects/common/sqlbuilder/fields.go @@ -1,80 +1,28 @@ -package model +package sqlbuilder import ( - "database/sql" - "errors" "fmt" "reflect" "sort" "strings" "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/builder" + "github.com/jinzhu/gorm/model" "github.com/jinzhu/gorm/schema" ) -// Field GORM model field -type Field struct { - *schema.Field - IsBlank bool - Value reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Value.IsValid() { - return errors.New("field value not valid") - } - - if !field.Value.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Value - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.StructField.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Value.Set(reflect.Zero(fieldValue.Type())) - } - - field.IsBlank = isBlank(field.Value) - return err -} - -// GetAssignments get assignments -func GetAssignments(tx *gorm.DB) chan [][]*Field { - fieldChan := make(chan [][]*Field) +// GetAssignmentFields get assignment fields +func GetAssignmentFields(tx *gorm.DB) chan [][]*model.Field { + fieldChan := make(chan [][]*model.Field) go func() { assignableChecker := generateAssignableChecker(selectAttrs(tx.Statement), omitAttrs(tx.Statement)) switch dest := tx.Statement.Dest.(type) { case map[string]interface{}: - fieldChan <- [][]*Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)} + fieldChan <- [][]*model.Field{mapToFields(dest, schema.Parse(tx.Statement.Table), assignableChecker)} case []map[string]interface{}: - fields := [][]*Field{} + fields := [][]*model.Field{} tableSchema := schema.Parse(tx.Statement.Table) for _, v := range dest { @@ -87,13 +35,13 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field { switch results.Kind() { case reflect.Slice: - fields := [][]*Field{} + fields := [][]*model.Field{} for i := 0; i < results.Len(); i++ { fields = append(fields, structToField(indirect(results.Index(i)), s, assignableChecker)) } fieldChan <- fields case reflect.Struct: - fieldChan <- [][]*Field{structToField(results, s, assignableChecker)} + fieldChan <- [][]*model.Field{structToField(results, s, assignableChecker)} } } } @@ -102,12 +50,12 @@ func GetAssignments(tx *gorm.DB) chan [][]*Field { return fieldChan } -func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) { +func mapToFields(value map[string]interface{}, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) { // TODO assign those value to dest for k, v := range value { if s != nil { if f := s.FieldByName(k); f != nil { - field := &Field{Field: f, Value: reflect.ValueOf(v)} + field := &model.Field{Field: f, Value: reflect.ValueOf(v)} if assignableChecker(field) { fields = append(fields, field) } @@ -115,7 +63,7 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck } } - field := &Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)} + field := &model.Field{Field: &schema.Field{DBName: k}, Value: reflect.ValueOf(v)} if assignableChecker(field) { fields = append(fields, field) } @@ -127,14 +75,14 @@ func mapToFields(value map[string]interface{}, s *schema.Schema, assignableCheck return } -func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*Field) bool) (fields []*Field) { +func structToField(value reflect.Value, s *schema.Schema, assignableChecker func(*model.Field) bool) (fields []*model.Field) { // TODO use Offset to replace FieldByName? for _, sf := range s.Fields { obj := value for _, bn := range sf.BindNames { obj = value.FieldByName(bn) } - field := &Field{Field: sf, Value: obj, IsBlank: isBlank(obj)} + field := &model.Field{Field: sf, Value: obj, IsBlank: model.IsBlank(obj)} if assignableChecker(field) { fields = append(fields, field) } @@ -143,8 +91,8 @@ func structToField(value reflect.Value, s *schema.Schema, assignableChecker func } // generateAssignableChecker generate checker to check if field is assignable or not -func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*Field) bool { - return func(field *Field) bool { +func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*model.Field) bool { + return func(field *model.Field) bool { if len(selectAttrs) > 0 { for _, attr := range selectAttrs { if field.Name == attr || field.DBName == attr { @@ -164,7 +112,7 @@ func generateAssignableChecker(selectAttrs []string, omitAttrs []string) func(*F } // omitAttrs return selected attributes of stmt -func selectAttrs(stmt *builder.Statement) []string { +func selectAttrs(stmt *gorm.Statement) []string { columns := stmt.Select.Columns for _, arg := range stmt.Select.Args { columns = append(columns, fmt.Sprint(arg)) @@ -173,6 +121,6 @@ func selectAttrs(stmt *builder.Statement) []string { } // omitAttrs return omitted attributes of stmt -func omitAttrs(stmt *builder.Statement) []string { +func omitAttrs(stmt *gorm.Statement) []string { return stmt.Omit } diff --git a/dialects/common/sqlbuilder/sqlbuilder.go b/dialects/common/sqlbuilder/sqlbuilder.go new file mode 100644 index 00000000..5debeaef --- /dev/null +++ b/dialects/common/sqlbuilder/sqlbuilder.go @@ -0,0 +1,63 @@ +package sqlbuilder + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/model" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/inflection" +) + +// GetTable get table name for current db operation +func GetTable(tx *gorm.DB) chan string { + tableChan := make(chan string) + + go func() { + var tableName string + if name, ok := tx.Statement.Table.(string); ok { + tableName = name + } else { + for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} { + if v != nil { + if t, ok := v.(tabler); ok { + tableName = t.TableName() + } else if t, ok := v.(dbTabler); ok { + tableName = t.TableName(tx) + } else if s := schema.Parse(v); s != nil { + if s.TableName != "" { + tableName = s.TableName + } else { + tableName = schema.ToDBName(s.ModelType.Name()) + if !tx.Config.SingularTable { + tableName = inflection.Plural(tableName) + } + } + } + + if tableName != "" { + break + } + } + } + } + + if tableName != "" { + if model.DefaultTableNameHandler != nil { + tableChan <- model.DefaultTableNameHandler(tx, tableName) + } else { + tableChan <- tableName + } + } else { + tx.AddError(ErrInvalidTable) + } + }() + + return tableChan +} + +type tabler interface { + TableName() string +} + +type dbTabler interface { + TableName(*gorm.DB) string +} diff --git a/dialects/common/sqlbuilder/utils.go b/dialects/common/sqlbuilder/utils.go new file mode 100644 index 00000000..ce71c421 --- /dev/null +++ b/dialects/common/sqlbuilder/utils.go @@ -0,0 +1,10 @@ +package sqlbuilder + +import "reflect" + +func indirect(reflectValue reflect.Value) reflect.Value { + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + return reflectValue +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index e0ff63bc..7c5e9e7a 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/model" + "github.com/jinzhu/gorm/dialects/common/destination" ) // Dialect Sqlite3 Dialect for GORM @@ -23,9 +23,9 @@ func (dialect Dialect) Quote(name string) string { func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { var ( args []interface{} - assignmentsChan = model.GetAssignments(tx) - tableNameChan = model.GetTable(tx) - primaryFields []*model.Field + assignmentsChan = destination.GetAssignments(tx) + tableNameChan = destination.GetTable(tx) + primaryFields []*destination.Field ) s := bytes.NewBufferString("INSERT INTO ") @@ -41,7 +41,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { valueBuffer := bytes.NewBufferString("VALUES ") for idx, fields := range assignments { - var primaryField *model.Field + var primaryField *destination.Field if idx != 0 { valueBuffer.WriteString(",") } diff --git a/logger/utils.go b/logger/utils.go new file mode 100644 index 00000000..d34f2cb3 --- /dev/null +++ b/logger/utils.go @@ -0,0 +1,21 @@ +package logger + +import ( + "fmt" + "regexp" + "runtime" +) + +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) + +// FileWithLineNum get filename with line num for logging +func FileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +} diff --git a/model/model.go b/model/model.go index dd2d7eb4..f8637058 100644 --- a/model/model.go +++ b/model/model.go @@ -1,9 +1,13 @@ package model import ( + "database/sql" + "errors" + "fmt" + "reflect" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/inflection" ) // DefaultTableNameHandler default table name handler @@ -12,57 +16,52 @@ import ( // } var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string -// GetTable get table name for current db operation -func GetTable(tx *gorm.DB) chan string { - tableChan := make(chan string) +// Field GORM model field +type Field struct { + *schema.Field + IsBlank bool + Value reflect.Value +} - go func() { - var tableName string - if name, ok := tx.Statement.Table.(string); ok { - tableName = name +// Set set a value to the field +func (field *Field) Set(value interface{}) (err error) { + if !field.Value.IsValid() { + return errors.New("field value not valid") + } + + if !field.Value.CanAddr() { + return gorm.ErrUnaddressable + } + + reflectValue, ok := value.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(value) + } + + fieldValue := field.Value + if reflectValue.IsValid() { + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else { - for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} { - if v != nil { - if t, ok := v.(tabler); ok { - tableName = t.TableName() - } else if t, ok := v.(dbTabler); ok { - tableName = t.TableName(tx) - } else if s := schema.Parse(v); s != nil { - if s.TableName != "" { - tableName = s.TableName - } else { - tableName = schema.ToDBName(s.ModelType.Name()) - if !tx.Config.SingularTable { - tableName = inflection.Plural(tableName) - } - } - } - - if tableName != "" { - break - } + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.StructField.Type.Elem())) } + fieldValue = fieldValue.Elem() } - } - if tableName != "" { - if DefaultTableNameHandler != nil { - tableChan <- DefaultTableNameHandler(tx, tableName) + if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { + fieldValue.Set(reflectValue.Convert(fieldValue.Type())) + } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err = scanner.Scan(reflectValue.Interface()) } else { - tableChan <- tableName + err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) } - } else { - tx.AddError(ErrInvalidTable) } - }() + } else { + field.Value.Set(reflect.Zero(fieldValue.Type())) + } - return tableChan -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*gorm.DB) string + field.IsBlank = IsBlank(field.Value) + return err } diff --git a/model/utils.go b/model/utils.go index da3d3dac..700f7087 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,32 +2,7 @@ package model import "reflect" -// ToSearchableMap convert attrs to searchable map -func ToSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func isBlank(value reflect.Value) bool { +func IsBlank(value reflect.Value) bool { switch value.Kind() { case reflect.String: return value.Len() == 0