From 84c6b46011b5b146782affd77dcf5ff95e255c50 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Dec 2015 15:18:42 +0800 Subject: [PATCH 1/6] Update inflection address --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index 4ea821d6..c4eb5827 100644 --- a/model_struct.go +++ b/model_struct.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/qor/inflection" + "github.com/jinzhu/inflection" ) var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { From 4c1a78bab7496c6db3c40ebb5d73cdd549bef35b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Dec 2015 18:41:28 +0800 Subject: [PATCH 2/6] Don't query all columns out from database after create, but only those has default value --- callback_create.go | 11 ++++++++--- callback_query.go | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/callback_create.go b/callback_create.go index 71db4ef0..d13a71be 100644 --- a/callback_create.go +++ b/callback_create.go @@ -33,7 +33,12 @@ func Create(scope *Scope) { columns = append(columns, scope.Quote(field.DBName)) sqls = append(sqls, scope.AddToVars(field.Field.Interface())) } else if field.HasDefaultValue { - scope.InstanceSet("gorm:force_reload_after_create", true) + var hasDefaultValueColumns []string + if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { + hasDefaultValueColumns = oldHasDefaultValueColumns.([]string) + } + hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName) + scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns) } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -98,8 +103,8 @@ func Create(scope *Scope) { } func ForceReloadAfterCreate(scope *Scope) { - if _, ok := scope.InstanceGet("gorm:force_reload_after_create"); ok { - scope.DB().New().First(scope.Value) + if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { + scope.DB().New().Select(columns.([]string)).First(scope.Value) } } diff --git a/callback_query.go b/callback_query.go index 387e813d..5473f232 100644 --- a/callback_query.go +++ b/callback_query.go @@ -71,7 +71,9 @@ func Query(scope *Scope) { if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { - values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() + reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) + reflectValue.Elem().Set(field.Field.Addr()) + values[index] = reflectValue.Interface() } } else { var value interface{} From fc42a1bbf35c9fee8900d7d0fa9f61ab9eff8b26 Mon Sep 17 00:00:00 2001 From: Luke Cowell Date: Fri, 4 Dec 2015 14:56:21 -0800 Subject: [PATCH 3/6] provide user with more descriptive error message --- field.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/field.go b/field.go index a56dbe51..9e6891c8 100644 --- a/field.go +++ b/field.go @@ -3,6 +3,7 @@ package gorm import ( "database/sql" "errors" + "fmt" "reflect" ) @@ -44,7 +45,7 @@ func (field *Field) Set(value interface{}) error { if reflectValue.Type().ConvertibleTo(field.Field.Type()) { field.Field.Set(reflectValue.Convert(field.Field.Type())) } else { - return errors.New("could not convert argument") + return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) } } From 807ed63cfe6ce71368fd9b1c47ad3c79af0ff76f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Dec 2015 10:40:12 +0800 Subject: [PATCH 4/6] Fix pollute model's fields with join table's values --- preload.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/preload.go b/preload.go index 2d1aed2f..417a0d8c 100644 --- a/preload.go +++ b/preload.go @@ -209,6 +209,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") + preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) if len(conditions) > 0 { @@ -228,13 +229,15 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface fields := scope.New(elem.Addr().Interface()).Fields() + var foundFields = map[string]bool{} for index, column := range columns { - if field, ok := fields[column]; ok { + if field, ok := fields[column]; ok && !foundFields[column] { if field.Field.Kind() == reflect.Ptr { values[index] = field.Field.Addr().Interface() } else { values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() } + foundFields[column] = true } else { var i interface{} values[index] = &i @@ -245,14 +248,16 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface var sourceKey []interface{} + var scannedFields = map[string]bool{} for index, column := range columns { value := values[index] - if field, ok := fields[column]; ok { + if field, ok := fields[column]; ok && !scannedFields[column] { if field.Field.Kind() == reflect.Ptr { field.Field.Set(reflect.ValueOf(value).Elem()) } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { field.Field.Set(v) } + scannedFields[column] = true } else if strInSlice(column, sourceKeys) { sourceKey = append(sourceKey, *(value.(*interface{}))) } From 341703ed5d437e38e80e3ceaf7a6588766ae1fcd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Dec 2015 11:45:22 +0800 Subject: [PATCH 5/6] Scan value into ignored fields if there is no ambiguity --- field.go | 10 ++++++---- model_struct.go | 35 ++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/field.go b/field.go index 9e6891c8..db1fdd8f 100644 --- a/field.go +++ b/field.go @@ -62,10 +62,12 @@ func (scope *Scope) Fields() map[string]*Field { indirectValue := scope.IndirectValue() isStruct := indirectValue.Kind() == reflect.Struct for _, structField := range modelStruct.StructFields { - if isStruct { - fields[structField.DBName] = getField(indirectValue, structField) - } else { - fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} + if field, ok := fields[structField.DBName]; !ok || field.IsIgnored { + if isStruct { + fields[structField.DBName] = getField(indirectValue, structField) + } else { + fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} + } } } diff --git a/model_struct.go b/model_struct.go index c4eb5827..97d145b6 100644 --- a/model_struct.go +++ b/model_struct.go @@ -149,24 +149,25 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if fieldStruct.Tag.Get("sql") == "-" { field.IsIgnored = true - } else { - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - if _, ok := gormSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := sqlSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - if value, ok := gormSettings["COLUMN"]; ok { - field.DBName = value - } else { - field.DBName = ToDBName(fieldStruct.Name) - } } + + sqlSettings := parseTagSetting(field.Tag.Get("sql")) + gormSettings := parseTagSetting(field.Tag.Get("gorm")) + if _, ok := gormSettings["PRIMARY_KEY"]; ok { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } + + if _, ok := sqlSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + } + + if value, ok := gormSettings["COLUMN"]; ok { + field.DBName = value + } else { + field.DBName = ToDBName(fieldStruct.Name) + } + fields = append(fields, field) } } From ba694926d032afd3dd1e649896a62bf802464069 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Dec 2015 12:22:09 +0800 Subject: [PATCH 6/6] Create composite primary key for join table --- scope_private.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scope_private.go b/scope_private.go index d5aacfdb..4fd7149d 100644 --- a/scope_private.go +++ b/scope_private.go @@ -492,12 +492,13 @@ func (scope *Scope) createJoinTable(field *StructField) { if !scope.Dialect().HasTable(scope, joinTable) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - var sqlTypes []string + var sqlTypes, primaryKeys []string for idx, fieldName := range relationship.ForeignFieldNames { if field, ok := scope.Fields()[fieldName]; ok { value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) + primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } } @@ -506,9 +507,11 @@ func (scope *Scope) createJoinTable(field *StructField) { value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) + primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error) + + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) }