This commit is contained in:
Paolo Galeone 2015-12-13 19:19:32 +01:00
commit fa2f068cbb
6 changed files with 50 additions and 28 deletions

View File

@ -249,6 +249,14 @@ func handleDefaultValue(scope *Scope, field *Field) string {
return escapeIfNeeded(scope, defaultValue) return escapeIfNeeded(scope, defaultValue)
} }
} }
/*
case reflect.String:
if defaultValue != field.Field.String() {
return escapeIfNeeded(scope, fmt.Sprintf("%s", field.Field.String()))
} else {
return escapeIfNeeded(scope, defaultValue)
}
*/
default: default:
return escapeIfNeeded(scope, defaultValue) return escapeIfNeeded(scope, defaultValue)
} }

View File

@ -71,7 +71,9 @@ func Query(scope *Scope) {
if field.Field.Kind() == reflect.Ptr { if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface() values[index] = field.Field.Addr().Interface()
} else { } else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
reflectValue.Elem().Set(field.Field.Addr())
values[index] = reflectValue.Interface()
} }
} else { } else {
var value interface{} var value interface{}

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"reflect" "reflect"
) )
@ -44,7 +45,7 @@ func (field *Field) Set(value interface{}) error {
if reflectValue.Type().ConvertibleTo(field.Field.Type()) { if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type())) field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else { } else {
return errors.New("could not convert argument") return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type())
} }
} }
@ -61,10 +62,12 @@ func (scope *Scope) Fields() map[string]*Field {
indirectValue := scope.IndirectValue() indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range modelStruct.StructFields { for _, structField := range modelStruct.StructFields {
if isStruct { if field, ok := fields[structField.DBName]; !ok || field.IsIgnored {
fields[structField.DBName] = getField(indirectValue, structField) if isStruct {
} else { fields[structField.DBName] = getField(indirectValue, structField)
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} } else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
} }
} }

View File

@ -10,7 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/qor/inflection" "github.com/jinzhu/inflection"
) )
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
@ -149,24 +149,25 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if fieldStruct.Tag.Get("sql") == "-" { if fieldStruct.Tag.Get("sql") == "-" {
field.IsIgnored = true field.IsIgnored = true
} else {
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if _, ok := sqlSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
}
if value, ok := gormSettings["COLUMN"]; ok {
field.DBName = value
} else {
field.DBName = ToDBName(fieldStruct.Name)
}
} }
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if _, ok := sqlSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
}
if value, ok := gormSettings["COLUMN"]; ok {
field.DBName = value
} else {
field.DBName = ToDBName(fieldStruct.Name)
}
fields = append(fields, field) fields = append(fields, field)
} }
} }

View File

@ -209,6 +209,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
} }
db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*")
preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value)
if len(conditions) > 0 { if len(conditions) > 0 {
@ -228,13 +229,15 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
fields := scope.New(elem.Addr().Interface()).Fields() fields := scope.New(elem.Addr().Interface()).Fields()
var foundFields = map[string]bool{}
for index, column := range columns { for index, column := range columns {
if field, ok := fields[column]; ok { if field, ok := fields[column]; ok && !foundFields[column] {
if field.Field.Kind() == reflect.Ptr { if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface() values[index] = field.Field.Addr().Interface()
} else { } else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
} }
foundFields[column] = true
} else { } else {
var i interface{} var i interface{}
values[index] = &i values[index] = &i
@ -245,14 +248,16 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
var sourceKey []interface{} var sourceKey []interface{}
var scannedFields = map[string]bool{}
for index, column := range columns { for index, column := range columns {
value := values[index] value := values[index]
if field, ok := fields[column]; ok { if field, ok := fields[column]; ok && !scannedFields[column] {
if field.Field.Kind() == reflect.Ptr { if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem()) field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v) field.Field.Set(v)
} }
scannedFields[column] = true
} else if strInSlice(column, sourceKeys) { } else if strInSlice(column, sourceKeys) {
sourceKey = append(sourceKey, *(value.(*interface{}))) sourceKey = append(sourceKey, *(value.(*interface{})))
} }

View File

@ -492,12 +492,13 @@ func (scope *Scope) createJoinTable(field *StructField) {
if !scope.Dialect().HasTable(scope, joinTable) { if !scope.Dialect().HasTable(scope, joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes []string var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames { for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok { if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type)) value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
} }
} }
@ -506,9 +507,11 @@ func (scope *Scope) createJoinTable(field *StructField) {
value := reflect.Indirect(reflect.New(field.Struct.Type)) value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
} }
} }
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error)
} }
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
} }