From c354b0fb067fcc8714df7e9f79278ad3f6b463c2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Nov 2013 21:38:35 +0800 Subject: [PATCH] Refact field struct --- do.go | 8 +++--- field.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++-------- model.go | 64 +++++----------------------------------------- 3 files changed, 79 insertions(+), 71 deletions(-) diff --git a/do.go b/do.go index b84c5f0c..81f0e87f 100644 --- a/do.go +++ b/do.go @@ -640,8 +640,8 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string for _, field := range s.model.fields("migration") { - if len(field.SqlType()) > 0 { - sqls = append(sqls, field.DbName+" "+field.SqlType()) + if len(field.sqlTag()) > 0 { + sqls = append(sqls, field.DbName+" "+field.sqlTag()) } } @@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do { s.sqlVars = []interface{}{} // If column doesn't exist - if len(column_name) == 0 && len(field.SqlType()) > 0 { - s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType()) + if len(column_name) == 0 && len(field.sqlTag()) > 0 { + s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag()) s.exec() } } diff --git a/field.go b/field.go index 01e9d22e..1c5d399c 100644 --- a/field.go +++ b/field.go @@ -3,6 +3,7 @@ package gorm import ( "database/sql" "database/sql/driver" + "time" "strconv" @@ -12,22 +13,50 @@ import ( ) type Field struct { - Name string - Value interface{} - DbName string - AutoCreateTime bool - AutoUpdateTime bool - IsPrimaryKey bool - IsBlank bool - structField reflect.StructField - + Name string + Value interface{} + DbName string + AutoCreateTime bool + AutoUpdateTime bool + IsPrimaryKey bool + structField reflect.StructField + modelValue reflect.Value beforeAssociation bool afterAssociation bool foreignKey string model *Model } -func (f *Field) SqlType() string { +func (f *Field) isBlank() bool { + value := reflect.ValueOf(f.Value) + switch value.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + return value.Int() == 0 + case reflect.String: + return value.String() == "" + case reflect.Slice: + return value.Len() == 0 + case reflect.Struct: + time_value, is_time := f.Value.(time.Time) + if is_time { + return time_value.IsZero() + } else { + _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner) + if is_scanner { + return !value.FieldByName("Valid").Interface().(bool) + } else { + m := &Model{data: value.Interface(), do: f.model.do} + fields := m.columnsHasValue("other") + if len(fields) == 0 { + return true + } + } + } + } + return false +} + +func (f *Field) sqlTag() string { column := getInterfaceValue(f.Value) field_value := reflect.ValueOf(f.Value) switch field_value.Kind() { @@ -61,6 +90,35 @@ func (f *Field) SqlType() string { return typ } +func (f *Field) parseAssociation() { + field_value := reflect.ValueOf(f.Value) + + switch field_value.Kind() { + case reflect.Slice: + foreign_key := f.model.typeName() + "Id" + if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { + f.foreignKey = foreign_key + } + f.afterAssociation = true + case reflect.Struct: + _, is_time := f.Value.(time.Time) + _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) + + if !is_scanner && !is_time { + if f.modelValue.FieldByName(f.Name + "Id").IsValid() { + f.foreignKey = f.Name + "Id" + f.beforeAssociation = true + } else { + foreign_key := f.model.typeName() + "Id" + if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { + f.foreignKey = foreign_key + } + f.afterAssociation = true + } + } + } +} + func parseSqlTag(str string) (typ string, addational_typ string, size int) { if str == "-" { typ = str diff --git a/model.go b/model.go index 7fa58950..d4484fcc 100644 --- a/model.go +++ b/model.go @@ -1,7 +1,6 @@ package gorm import ( - "database/sql" "errors" "go/ast" "reflect" @@ -73,31 +72,7 @@ func (m *Model) fields(operation string) (fields []*Field) { value := indirect_value.FieldByName(p.Name) time_value, is_time := value.Interface().(time.Time) field.model = m - - switch value.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - field.IsBlank = value.Int() == 0 - case reflect.String: - field.IsBlank = value.String() == "" - case reflect.Slice: - field.IsBlank = value.Len() == 0 - case reflect.Struct: - if is_time { - field.IsBlank = time_value.IsZero() - } else { - _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner) - - if is_scanner { - field.IsBlank = !value.FieldByName("Valid").Interface().(bool) - } else { - m := &Model{data: value.Interface(), do: m.do} - fields := m.columnsHasValue("other") - if len(fields) == 0 { - field.IsBlank = true - } - } - } - } + field.modelValue = indirect_value if is_time { field.AutoCreateTime = "created_at" == field.DbName @@ -113,37 +88,10 @@ func (m *Model) fields(operation string) (fields []*Field) { value.Set(reflect.ValueOf(time.Now())) } } - } else { - field_value := reflect.Indirect(value) - - switch field_value.Kind() { - case reflect.Slice: - foreign_key := typ.Name() + "Id" - if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { - field.foreignKey = foreign_key - } - field.afterAssociation = true - case reflect.Struct: - _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) - - if !is_scanner { - if indirect_value.FieldByName(p.Name + "Id").IsValid() { - field.foreignKey = p.Name + "Id" - field.beforeAssociation = true - } else { - foreign_key := typ.Name() + "Id" - if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { - field.foreignKey = foreign_key - } - field.afterAssociation = true - } - } - } } field.structField = p field.Value = value.Interface() - fields = append(fields, &field) } } @@ -157,7 +105,7 @@ func (m *Model) fields(operation string) (fields []*Field) { func (m *Model) columnsHasValue(operation string) (fields []*Field) { for _, field := range m.fields(operation) { - if !field.IsBlank { + if !field.isBlank() { fields = append(fields, field) } } @@ -199,7 +147,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { if m.data != nil { for _, field := range m.fields(operation) { - if !field.IsPrimaryKey && (len(field.SqlType()) > 0) { + if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) { results[field.DbName] = field.Value } } @@ -297,7 +245,8 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{} func (m *Model) beforeAssociations() (fields []*Field) { for _, field := range m.fields("null") { - if field.beforeAssociation && !field.IsBlank { + field.parseAssociation() + if field.beforeAssociation && !field.isBlank() { fields = append(fields, field) } } @@ -306,7 +255,8 @@ func (m *Model) beforeAssociations() (fields []*Field) { func (m *Model) afterAssociations() (fields []*Field) { for _, field := range m.fields("null") { - if field.afterAssociation && !field.IsBlank { + field.parseAssociation() + if field.afterAssociation && !field.isBlank() { fields = append(fields, field) } }