diff --git a/schema/relationship.go b/schema/relationship.go index a351c3d9..ce122194 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -24,7 +24,7 @@ func buildToOneRel(field *Field, sourceSchema *Schema) { // user belongs to profile, associationType is Profile, user use ProfileID as foreign key relationship = &Relationship{} associationType = sourceSchema.ModelType.Name() - destSchema = ParseSchema(reflect.New(field.StructField.Type).Interface()) + destSchema = Parse(reflect.New(field.StructField.Type).Interface()) tagForeignKeys, tagAssociationForeignKeys []string ) @@ -184,7 +184,7 @@ func buildToManyRel(field *Field, sourceSchema *Schema) { var ( relationship = &Relationship{} elemType = field.StructField.Type - destSchema = ParseSchema(reflect.New(elemType).Interface()) + destSchema = Parse(reflect.New(elemType).Interface()) foreignKeys, associationForeignKeys []string ) diff --git a/schema/relationship_test.go b/schema/relationship_test.go new file mode 100644 index 00000000..5b37a784 --- /dev/null +++ b/schema/relationship_test.go @@ -0,0 +1,65 @@ +package schema + +import "testing" + +type BelongsTo struct { + ID int + Name string +} + +type HasOne struct { + ID int + MyStructID uint +} + +type HasMany struct { + ID int + MyStructID uint + Name string +} + +type Many2Many struct { + ID int + Name string +} + +func TestBelongsToRel(t *testing.T) { + type MyStruct struct { + ID int + Name string + BelongsTo BelongsTo + } + + Parse(&MyStruct{}) +} + +func TestHasOneRel(t *testing.T) { + type MyStruct struct { + ID int + Name string + HasOne HasOne + } + + Parse(&MyStruct{}) +} + +func TestHasManyRel(t *testing.T) { + type MyStruct struct { + ID int + Name string + HasMany []HasMany + } + + Parse(&MyStruct{}) +} + +func TestManyToManyRel(t *testing.T) { + type MyStruct struct { + ID int + Name string + HasMany []HasMany + } + + Parse(&MyStruct{}) +} + diff --git a/schema/schema.go b/schema/schema.go index 075e360a..93dda6dc 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -33,8 +33,8 @@ type Field struct { Relationship *Relationship } -// ParseSchema parse struct and generate schema based on struct and tag definition -func ParseSchema(dest interface{}) *Schema { +// Parse parse struct and generate schema based on struct and tag definition +func Parse(dest interface{}) *Schema { schema := Schema{} // Get dest type @@ -104,7 +104,7 @@ func ParseSchema(dest interface{}) *Schema { field.IsNormal = true } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { // embedded struct - if subSchema := ParseSchema(fieldValue); subSchema != nil { + if subSchema := Parse(fieldValue); subSchema != nil { for _, subField := range subSchema.Fields { subField = subField.clone() subField.BindNames = append([]string{fieldStruct.Name}, subField.BindNames...) diff --git a/schema/schema_test.go b/schema/schema_test.go index 8f57a4a5..f788c1da 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -2,89 +2,102 @@ package schema import ( "database/sql" + "fmt" + "reflect" "testing" "time" ) -type MyStruct struct { - ID int - Int uint - IntPointer *uint - String string - StringPointer *string - Time time.Time - TimePointer *time.Time - NullInt64 sql.NullInt64 -} - -type BelongsTo struct { - ID int - Name string -} - -type HasOne struct { - ID int - MyStructID uint -} - -type HasMany struct { - ID int - MyStructID uint - Name string -} - -type Many2Many struct { - ID int - Name string -} - -func TestParseSchema(t *testing.T) { - ParseSchema(&MyStruct{}) -} - -func TestParseBelongsToRel(t *testing.T) { - type MyStruct struct { - ID int - Name string - BelongsTo BelongsTo +func fieldEqual(got, expected *Field) error { + if expected.DBName != got.DBName { + return fmt.Errorf("field DBName should be %v, got %v", expected.DBName, got.DBName) } - ParseSchema(&MyStruct{}) -} - -func TestParseHasOneRel(t *testing.T) { - type MyStruct struct { - ID int - Name string - HasOne HasOne + if expected.Name != got.Name { + return fmt.Errorf("field Name should be %v, got %v", expected.Name, got.Name) } - ParseSchema(&MyStruct{}) -} - -func TestParseHasManyRel(t *testing.T) { - type MyStruct struct { - ID int - Name string - HasMany []HasMany + if !reflect.DeepEqual(expected.BindNames, got.BindNames) { + return fmt.Errorf("field BindNames should be %#v, got %#v", expected.BindNames, got.BindNames) } - ParseSchema(&MyStruct{}) -} - -func TestParseManyToManyRel(t *testing.T) { - type MyStruct struct { - ID int - Name string - HasMany []HasMany + if (expected.TagSettings == nil && len(got.TagSettings) != 0) && !reflect.DeepEqual(expected.TagSettings, got.TagSettings) { + return fmt.Errorf("field TagSettings should be %#v, got %#v", expected.TagSettings, got.TagSettings) } - ParseSchema(&MyStruct{}) + if expected.IsNormal != got.IsNormal { + return fmt.Errorf("field IsNormal should be %v, got %v", expected.IsNormal, got.IsNormal) + } + + if expected.IsPrimaryKey != got.IsPrimaryKey { + return fmt.Errorf("field IsPrimaryKey should be %v, got %v", expected.IsPrimaryKey, got.IsPrimaryKey) + } + + if expected.IsIgnored != got.IsIgnored { + return fmt.Errorf("field IsIgnored should be %v, got %v", expected.IsIgnored, got.IsIgnored) + } + + if expected.IsForeignKey != got.IsForeignKey { + return fmt.Errorf("field IsForeignKey should be %v, got %v", expected.IsForeignKey, got.IsForeignKey) + } + + if expected.DefaultValue != got.DefaultValue { + return fmt.Errorf("field DefaultValue should be %v, got %v", expected.DefaultValue, got.DefaultValue) + } + + if expected.HasDefaultValue != got.HasDefaultValue { + return fmt.Errorf("field HasDefaultValue should be %v, got %v", expected.HasDefaultValue, got.HasDefaultValue) + } + return nil +} + +func compareFields(fields []*Field, expectedFields []*Field, t *testing.T) { + if len(fields) != len(expectedFields) { + t.Errorf("expected has %v fields, but got %v", len(expectedFields), len(fields)) + } + + for _, expectedField := range expectedFields { + field := getSchemaField(expectedField.DBName, fields) + if field == nil { + t.Errorf("Field %#v is not found", expectedField.Name) + } else if err := fieldEqual(field, expectedField); err != nil { + t.Error(err) + } + } +} + +func TestParse(t *testing.T) { + type MyStruct struct { + ID int + Int uint + IntPointer *uint `gorm:"default:10"` + String string + StringPointer *string `gorm:"column:strp"` + Time time.Time + TimePointer *time.Time + NullInt64 sql.NullInt64 + } + + schema := Parse(&MyStruct{}) + + compareFields(schema.Fields, []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "int", Name: "Int", BindNames: []string{"Int"}, IsNormal: true}, + {DBName: "int_pointer", Name: "IntPointer", BindNames: []string{"IntPointer"}, TagSettings: map[string]string{"DEFAULT": "10"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "10"}, + {DBName: "string", Name: "String", BindNames: []string{"String"}, IsNormal: true}, + {DBName: "strp", Name: "StringPointer", BindNames: []string{"StringPointer"}, TagSettings: map[string]string{"COLUMN": "strp"}, IsNormal: true}, + {DBName: "time", Name: "Time", BindNames: []string{"Time"}, IsNormal: true}, + {DBName: "time_pointer", Name: "TimePointer", BindNames: []string{"TimePointer"}, IsNormal: true}, + {DBName: "null_int64", Name: "NullInt64", BindNames: []string{"NullInt64"}, IsNormal: true}, + }, t) } func TestEmbeddedStruct(t *testing.T) { } +func TestOverwriteEmbeddedStructFields(t *testing.T) { +} + func TestCustomizePrimaryKey(t *testing.T) { } diff --git a/schema/utils.go b/schema/utils.go index d8446383..b75f5ad6 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -86,14 +86,16 @@ func getPrimaryPrimaryField(fields []*Field) *Field { func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k + if str != "" { + tags := strings.Split(str, ";") + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if len(v) >= 2 { + setting[k] = strings.Join(v[1:], ":") + } else { + setting[k] = k + } } } }