Refact field struct

This commit is contained in:
Jinzhu 2013-11-14 21:38:35 +08:00
parent e4612bde9c
commit c354b0fb06
3 changed files with 79 additions and 71 deletions

8
do.go
View File

@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("migration") { for _, field := range s.model.fields("migration") {
if len(field.SqlType()) > 0 { if len(field.sqlTag()) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType()) sqls = append(sqls, field.DbName+" "+field.sqlTag())
} }
} }
@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
s.sqlVars = []interface{}{} s.sqlVars = []interface{}{}
// If column doesn't exist // If column doesn't exist
if len(column_name) == 0 && len(field.SqlType()) > 0 { if len(column_name) == 0 && len(field.sqlTag()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType()) s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag())
s.exec() s.exec()
} }
} }

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"time" "time"
"strconv" "strconv"
@ -12,22 +13,50 @@ import (
) )
type Field struct { type Field struct {
Name string Name string
Value interface{} Value interface{}
DbName string DbName string
AutoCreateTime bool AutoCreateTime bool
AutoUpdateTime bool AutoUpdateTime bool
IsPrimaryKey bool IsPrimaryKey bool
IsBlank bool structField reflect.StructField
structField reflect.StructField modelValue reflect.Value
beforeAssociation bool beforeAssociation bool
afterAssociation bool afterAssociation bool
foreignKey string foreignKey string
model *Model model *Model
} }
func (f *Field) SqlType() string { func (f *Field) isBlank() bool {
value := reflect.ValueOf(f.Value)
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int() == 0
case reflect.String:
return value.String() == ""
case reflect.Slice:
return value.Len() == 0
case reflect.Struct:
time_value, is_time := f.Value.(time.Time)
if is_time {
return time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
return !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: f.model.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
return true
}
}
}
}
return false
}
func (f *Field) sqlTag() string {
column := getInterfaceValue(f.Value) column := getInterfaceValue(f.Value)
field_value := reflect.ValueOf(f.Value) field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() { switch field_value.Kind() {
@ -61,6 +90,35 @@ func (f *Field) SqlType() string {
return typ return typ
} }
func (f *Field) parseAssociation() {
field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() {
case reflect.Slice:
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
case reflect.Struct:
_, is_time := f.Value.(time.Time)
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if !is_scanner && !is_time {
if f.modelValue.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
} else {
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
}
}
}
func parseSqlTag(str string) (typ string, addational_typ string, size int) { func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" { if str == "-" {
typ = str typ = str

View File

@ -1,7 +1,6 @@
package gorm package gorm
import ( import (
"database/sql"
"errors" "errors"
"go/ast" "go/ast"
"reflect" "reflect"
@ -73,31 +72,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
value := indirect_value.FieldByName(p.Name) value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time) time_value, is_time := value.Interface().(time.Time)
field.model = m field.model = m
field.modelValue = indirect_value
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
field.IsBlank = value.Int() == 0
case reflect.String:
field.IsBlank = value.String() == ""
case reflect.Slice:
field.IsBlank = value.Len() == 0
case reflect.Struct:
if is_time {
field.IsBlank = time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
}
}
}
if is_time { if is_time {
field.AutoCreateTime = "created_at" == field.DbName field.AutoCreateTime = "created_at" == field.DbName
@ -113,37 +88,10 @@ func (m *Model) fields(operation string) (fields []*Field) {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
} }
} else {
field_value := reflect.Indirect(value)
switch field_value.Kind() {
case reflect.Slice:
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if !is_scanner {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
}
}
}
} }
field.structField = p field.structField = p
field.Value = value.Interface() field.Value = value.Interface()
fields = append(fields, &field) fields = append(fields, &field)
} }
} }
@ -157,7 +105,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
func (m *Model) columnsHasValue(operation string) (fields []*Field) { func (m *Model) columnsHasValue(operation string) (fields []*Field) {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsBlank { if !field.isBlank() {
fields = append(fields, field) fields = append(fields, field)
} }
} }
@ -199,7 +147,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
if m.data != nil { if m.data != nil {
for _, field := range m.fields(operation) { for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) { if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) {
results[field.DbName] = field.Value results[field.DbName] = field.Value
} }
} }
@ -297,7 +245,8 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}
func (m *Model) beforeAssociations() (fields []*Field) { func (m *Model) beforeAssociations() (fields []*Field) {
for _, field := range m.fields("null") { for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank { field.parseAssociation()
if field.beforeAssociation && !field.isBlank() {
fields = append(fields, field) fields = append(fields, field)
} }
} }
@ -306,7 +255,8 @@ func (m *Model) beforeAssociations() (fields []*Field) {
func (m *Model) afterAssociations() (fields []*Field) { func (m *Model) afterAssociations() (fields []*Field) {
for _, field := range m.fields("null") { for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank { field.parseAssociation()
if field.afterAssociation && !field.isBlank() {
fields = append(fields, field) fields = append(fields, field)
} }
} }