Refactor Scope

This commit is contained in:
Jinzhu 2014-09-02 19:03:01 +08:00
parent 9c7ff3d211
commit 953c347ba7
10 changed files with 87 additions and 84 deletions

View File

@ -28,27 +28,29 @@ func (association *Association) Find(value interface{}) *Association {
func (association *Association) Append(values ...interface{}) *Association { func (association *Association) Append(values ...interface{}) *Association {
scope := association.Scope scope := association.Scope
field := scope.IndirectValue().FieldByName(association.Column) field := association.Field
fieldType := field.Field.Type()
for _, value := range values { for _, value := range values {
reflectvalue := reflect.ValueOf(value) reflectvalue := reflect.ValueOf(value)
if reflectvalue.Kind() == reflect.Ptr { if reflectvalue.Kind() == reflect.Ptr {
if reflectvalue.Elem().Kind() == reflect.Struct { if reflectvalue.Elem().Kind() == reflect.Struct {
if field.Type().Elem().Kind() == reflect.Ptr { if fieldType.Elem().Kind() == reflect.Ptr {
field.Set(reflect.Append(field, reflectvalue)) field.Set(reflect.Append(field.Field, reflectvalue))
} else if field.Type().Elem().Kind() == reflect.Struct { } else if fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field, reflectvalue.Elem())) field.Set(reflect.Append(field.Field, reflectvalue.Elem()))
} }
} else if reflectvalue.Elem().Kind() == reflect.Slice { } else if reflectvalue.Elem().Kind() == reflect.Slice {
if field.Type().Elem().Kind() == reflect.Ptr { if fieldType.Elem().Kind() == reflect.Ptr {
field.Set(reflect.AppendSlice(field, reflectvalue)) field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else if field.Type().Elem().Kind() == reflect.Struct { } else if fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.AppendSlice(field, reflectvalue.Elem())) field.Set(reflect.AppendSlice(field.Field, reflectvalue.Elem()))
} }
} }
} else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct { } else if reflectvalue.Kind() == reflect.Struct && fieldType.Elem().Kind() == reflect.Struct {
field.Set(reflect.Append(field, reflectvalue)) field.Set(reflect.Append(field.Field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() { } else if reflectvalue.Kind() == reflect.Slice && fieldType.Elem() == reflectvalue.Type().Elem() {
field.Set(reflect.AppendSlice(field, reflectvalue)) field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else { } else {
association.err(errors.New("invalid association type")) association.err(errors.New("invalid association type"))
} }
@ -107,7 +109,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
relationship := association.Field.Relationship relationship := association.Field.Relationship
scope := association.Scope scope := association.Scope
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
field := scope.IndirectValue().FieldByName(association.Column) field := association.Field.Field
oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
association.Append(values...) association.Append(values...)
@ -154,7 +156,7 @@ func (association *Association) Count() int {
count := -1 count := -1
relationship := association.Field.Relationship relationship := association.Field.Relationship
scope := association.Scope scope := association.Scope
field := scope.IndirectValue().FieldByName(association.Column) field := association.Field.Field
fieldValue := field.Interface() fieldValue := field.Interface()
newScope := scope.New(fieldValue) newScope := scope.New(fieldValue)

View File

@ -158,10 +158,12 @@ func TestManyToMany(t *testing.T) {
languageA := Language{Name: "AA"} languageA := Language{Name: "AA"}
DB.Save(&languageA) DB.Save(&languageA)
DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA) DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)
languageC := Language{Name: "CC"} languageC := Language{Name: "CC"}
DB.Save(&languageC) DB.Save(&languageC)
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
DB.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}})
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}

View File

@ -28,7 +28,7 @@ func Create(scope *Scope) {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) { if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
columns = append(columns, scope.Quote(field.DBName)) columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Value)) sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
} }
} }

View File

@ -19,20 +19,18 @@ func SaveBeforeAssociations(scope *Scope) {
if !field.IsBlank && !field.IsIgnored { if !field.IsBlank && !field.IsIgnored {
relationship := field.Relationship relationship := field.Relationship
if relationship != nil && relationship.Kind == "belongs_to" { if relationship != nil && relationship.Kind == "belongs_to" {
value := reflect.ValueOf(field.Value) value := field.Field
newDB := scope.NewDB() newDB := scope.NewDB()
if value.CanAddr() { if !value.CanAddr() {
scope.Err(newDB.Save(value.Addr().Interface()).Error)
} else {
// If can't take address, then clone the value and set it back // If can't take address, then clone the value and set it back
value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem() value = reflect.New(value.Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() { for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
} }
scope.Err(newDB.Save(value.Addr().Interface()).Error)
scope.SetColumn(field.Name, value.Interface()) scope.SetColumn(field.Name, value.Interface())
} }
scope.Err(newDB.Save(value.Addr().Interface()).Error)
if relationship.ForeignKey != "" { if relationship.ForeignKey != "" {
scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
@ -48,7 +46,7 @@ func SaveAfterAssociations(scope *Scope) {
relationship := field.Relationship relationship := field.Relationship
if relationship != nil && if relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := reflect.ValueOf(field.Value) value := field.Field
switch value.Kind() { switch value.Kind() {
case reflect.Slice: case reflect.Slice:
@ -89,14 +87,14 @@ func SaveAfterAssociations(scope *Scope) {
newDB := scope.NewDB() newDB := scope.NewDB()
if value.CanAddr() { if value.CanAddr() {
if relationship.ForeignKey != "" { if relationship.ForeignKey != "" {
newDB.NewScope(field.Value).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue()) newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
} }
scope.Err(newDB.Save(field.Value).Error) scope.Err(newDB.Save(value.Addr().Interface()).Error)
} else { } else {
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() destValue := reflect.New(field.Field.Type()).Elem()
for _, f := range newDB.NewScope(field.Value).Fields() { for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value)) destValue.FieldByName(f.Name).Set(f.Field)
} }
elem := destValue.Addr().Interface() elem := destValue.Addr().Interface()

View File

@ -49,7 +49,7 @@ func Update(scope *Scope) {
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored { if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} }
} }

View File

@ -17,7 +17,6 @@ type Field struct {
Name string Name string
DBName string DBName string
Field reflect.Value Field reflect.Value
Value interface{}
Tag reflect.StructTag Tag reflect.StructTag
Relationship *relationship Relationship *relationship
IsNormal bool IsNormal bool
@ -26,12 +25,29 @@ type Field struct {
IsPrimaryKey bool IsPrimaryKey bool
} }
func (f *Field) IsScanner() bool { func (field *Field) IsScanner() bool {
_, isScanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner) _, isScanner := reflect.New(field.Field.Type()).Interface().(sql.Scanner)
return isScanner return isScanner
} }
func (f *Field) IsTime() bool { func (field *Field) IsTime() bool {
_, isTime := f.Value.(time.Time) _, isTime := field.Field.Interface().(time.Time)
return isTime return isTime
} }
func (field *Field) Set(value interface{}) (result bool) {
if field.Field.IsValid() && field.Field.CanAddr() {
result = true
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
} else {
result = false
}
}
if result {
field.IsBlank = isBlank(field.Field)
}
return
}

View File

@ -150,13 +150,17 @@ func (scope *Scope) FieldValueByName(name string) (interface{}, bool) {
} }
// SetColumn to set the column's value // SetColumn to set the column's value
func (scope *Scope) SetColumn(column string, value interface{}) bool { func (scope *Scope) SetColumn(column interface{}, value interface{}) bool {
if scope.Value == nil { if field, ok := column.(*Field); ok {
return false return field.Set(value)
} } else if str, ok := column.(string); ok {
for _, field := range scope.Fields() { if scope.Value == nil {
if field.Name == column || field.DBName == column { return false
return setFieldValue(field.Field, value) }
for _, field := range scope.Fields() {
if field.Name == str || field.DBName == str {
return field.Set(value)
}
} }
} }
return false return false
@ -267,11 +271,9 @@ func (scope *Scope) CombinedConditionSql() string {
} }
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
if scope.Value != nil { for _, field := range scope.Fields() {
if scope.IndirectValue().Kind() == reflect.Struct { if field.Name == name {
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok { return field, true
return scope.fieldFromStruct(f, true)[0], true
}
} }
} }
return nil, false return nil, false
@ -285,7 +287,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
value := scope.IndirectValue().FieldByName(fieldStruct.Name) value := scope.IndirectValue().FieldByName(fieldStruct.Name)
indirectValue := reflect.Indirect(value) indirectValue := reflect.Indirect(value)
field.Field = value field.Field = value
field.Value = value.Interface()
field.IsBlank = isBlank(value) field.IsBlank = isBlank(value)
// Search for primary key tag identifier // Search for primary key tag identifier
@ -416,9 +417,9 @@ func (scope *Scope) Fields(noRelations ...bool) map[string]*Field {
} }
} }
// if withRelation { if withRelation {
// scope.fields = fields scope.fields = fields
// } }
return fields return fields
} }

View File

@ -42,7 +42,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
var sqls []string var sqls []string
for _, field := range scope.New(value).Fields() { for _, field := range scope.New(value).Fields() {
if !field.IsBlank { if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -103,7 +103,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
var sqls []string var sqls []string
for _, field := range scope.New(value).Fields() { for _, field := range scope.New(value).Fields() {
if !field.IsBlank { if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
@ -264,17 +264,17 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
} }
for key, value := range values { for key, value := range values {
if field := data.FieldByName(SnakeToUpperCamel(key)); field.IsValid() { if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
func() { func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
hasUpdate = true hasUpdate = true
setFieldValue(field, value) field.Set(value)
} }
}() }()
if field.Interface() != value { if field.Field.Interface() != value {
switch field.Kind() { switch field.Field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
if s, ok := value.(string); ok { if s, ok := value.(string); ok {
i, err := strconv.Atoi(s) i, err := strconv.Atoi(s)
@ -283,13 +283,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
} }
} }
if field.Int() != reflect.ValueOf(value).Int() { if field.Field.Int() != reflect.ValueOf(value).Int() {
hasUpdate = true hasUpdate = true
setFieldValue(field, value) field.Set(value)
} }
default: default:
hasUpdate = true hasUpdate = true
setFieldValue(field, value) field.Set(value)
} }
} }
}() }()
@ -324,8 +324,8 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
additionalType = additionalType + "DEFAULT " + value additionalType = additionalType + "DEFAULT " + value
} }
value := field.Value value := field.Field.Interface()
reflectValue := reflect.ValueOf(value) reflectValue := field.Field
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice: case reflect.Slice:

View File

@ -68,7 +68,7 @@ func ToSnake(u string) string {
} }
s := strings.ToLower(buf.String()) s := strings.ToLower(buf.String())
go smap.Set(u, s) smap.Set(u, s)
return s return s
} }
@ -86,7 +86,7 @@ func SnakeToUpperCamel(s string) string {
} }
u := buf.String() u := buf.String()
go umap.Set(s, u) umap.Set(s, u)
return u return u
} }

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"database/sql"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -11,7 +10,7 @@ import (
) )
func fileWithLineNum() string { func fileWithLineNum() string {
for i := 1; i < 15; i++ { for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i) _, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
@ -20,21 +19,6 @@ func fileWithLineNum() string {
return "" return ""
} }
func setFieldValue(field reflect.Value, value interface{}) (result bool) {
result = false
if field.IsValid() && field.CanAddr() {
result = true
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Type()) {
field.Set(reflect.ValueOf(value).Convert(field.Type()))
} else {
result = false
}
}
return
}
func isBlank(value reflect.Value) bool { func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
} }
@ -82,7 +66,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
scope := Scope{Value: values} scope := Scope{Value: values}
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !field.IsBlank { if !field.IsBlank {
attrs[field.DBName] = field.Value attrs[field.DBName] = field.Field.Interface()
} }
} }
} }