diff --git a/schema/schema.go b/schema/schema.go index 93dda6dc..6df54961 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,6 +4,7 @@ import ( "database/sql" "go/ast" "reflect" + "sort" "time" ) @@ -48,6 +49,7 @@ func Parse(dest interface{}) *Schema { } schema.ModelType = reflectType + onConflictFields := map[string]int{} for i := 0; i < reflectType.NumField(); i++ { fieldStruct := reflectType.Field(i) @@ -148,9 +150,42 @@ func Parse(dest interface{}) *Schema { field.DBName = ToDBName(fieldStruct.Name) } + if _, ok := field.TagSettings["ON_EMBEDDED_CONFLICT"]; ok { + onConflictFields[field.Name] = len(schema.Fields) + } + schema.Fields = append(schema.Fields, field) } + if len(onConflictFields) > 0 { + removeIdx := []int{} + + for _, idx := range onConflictFields { + conflictField := schema.Fields[idx] + + for i, field := range schema.Fields { + if i != idx && conflictField.Name == field.Name { + switch conflictField.TagSettings["ON_EMBEDDED_CONFLICT"] { + case "replace": + removeIdx = append(removeIdx, i) + case "ignore": + removeIdx = append(removeIdx, idx) + case "update": + for key, value := range conflictField.TagSettings { + field.TagSettings[key] = value + } + removeIdx = append(removeIdx, idx) + } + } + } + } + + sort.Ints(removeIdx) + for i := len(removeIdx) - 1; i >= 0; i-- { + schema.Fields = append(schema.Fields[0:removeIdx[i]], schema.Fields[removeIdx[i]+1:]...) + } + } + if len(schema.PrimaryFields) == 0 { if field := getSchemaField("id", schema.Fields); field != nil { field.IsPrimaryKey = true diff --git a/schema/schema_test.go b/schema/schema_test.go index f788c1da..e8247334 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -8,6 +8,174 @@ import ( "time" ) +func TestParse(t *testing.T) { + type MyStruct struct { + ID int + Int uint + IntPointer *uint `gorm:"default:10"` + String string + StringPointer *string `gorm:"column:strp"` + Time time.Time + TimePointer *time.Time + NullInt64 sql.NullInt64 + } + + schema := Parse(&MyStruct{}) + + compareFields(schema.Fields, []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "int", Name: "Int", BindNames: []string{"Int"}, IsNormal: true}, + {DBName: "int_pointer", Name: "IntPointer", BindNames: []string{"IntPointer"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "10", TagSettings: map[string]string{"DEFAULT": "10"}}, + {DBName: "string", Name: "String", BindNames: []string{"String"}, IsNormal: true}, + {DBName: "strp", Name: "StringPointer", BindNames: []string{"StringPointer"}, IsNormal: true, TagSettings: map[string]string{"COLUMN": "strp"}}, + {DBName: "time", Name: "Time", BindNames: []string{"Time"}, IsNormal: true}, + {DBName: "time_pointer", Name: "TimePointer", BindNames: []string{"TimePointer"}, IsNormal: true}, + {DBName: "null_int64", Name: "NullInt64", BindNames: []string{"NullInt64"}, IsNormal: true}, + }, t) +} + +func TestEmbeddedStruct(t *testing.T) { + // Anonymous Embedded + type EmbedStruct struct { + Name string + Age string `gorm:"column:my_age"` + Role string `gorm:"default:guest"` + } + + type MyStruct struct { + ID string + EmbedStruct + } + + schema := Parse(&MyStruct{}) + expectedFields := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true}, + {DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"COLUMN": "Age"}}, + {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}}, + } + compareFields(schema.Fields, expectedFields, t) + + // Embedded with Tag + type MyStruct2 struct { + ID string + EmbedStruct EmbedStruct `gorm:"embedded"` + } + + schema2 := Parse(&MyStruct2{}) + expectedFields2 := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED"}}, + {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED"}}, + {DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "COLUMN": "Age"}}, + {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "COLUMN": "Role"}}, + } + compareFields(schema2.Fields, expectedFields2, t) + + // Embedded with prefix + type MyStruct3 struct { + ID string + EmbedStruct `gorm:"EMBEDDED_PREFIX:my_"` + } + + schema3 := Parse(&MyStruct3{}) + expectedFields3 := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"EMBEDDED_PREFIX": "my_"}}, + {DBName: "my_name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true, TagSettings: map[string]string{"EMBEDDED_PREFIX": "my_"}}, + {DBName: "my_my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"EMBEDDED_PREFIX": "my_", "COLUMN": "Age"}}, + {DBName: "my_role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"EMBEDDED_PREFIX": "my_", "COLUMN": "Role"}}, + } + compareFields(schema3.Fields, expectedFields3, t) +} + +func TestEmbeddedStructWithPrimaryKey(t *testing.T) { + type EmbedStruct struct { + ID string + Age string `gorm:"column:my_age"` + Role string `gorm:"default:guest"` + } + + type MyStruct struct { + Name string + EmbedStruct + } + + schema := Parse(&MyStruct{}) + expectedFields := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"EmbedStruct", "ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "name", Name: "Name", BindNames: []string{"Name"}, IsNormal: true}, + {DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"COLUMN": "Age"}}, + {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}}, + } + compareFields(schema.Fields, expectedFields, t) +} + +func TestOverwriteEmbeddedStructFields(t *testing.T) { + type EmbedStruct struct { + Name string + Age string `gorm:"column:my_age"` + Role string `gorm:"default:guest"` + } + + // on_embedded_conflict replace, ignore mode + type MyStruct struct { + ID string + EmbedStruct + Age string `gorm:"on_embedded_conflict:replace;column:my_age2"` + Name string `gorm:"on_embedded_conflict:ignore;column:my_name"` + } + + schema := Parse(&MyStruct{}) + expectedFields := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true}, + {DBName: "my_age2", Name: "Age", BindNames: []string{"Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "replace", "COLUMN": "my_age2"}}, + {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}}, + } + compareFields(schema.Fields, expectedFields, t) + + // on_embedded_conflict update mode, ignore mode w/o corresponding field + type MyStruct2 struct { + ID string + EmbedStruct + Age string `gorm:"on_embedded_conflict:update;column:my_age2"` + Name2 string `gorm:"on_embedded_conflict:ignore;column:my_name2"` + } + + schema2 := Parse(&MyStruct2{}) + expectedFields2 := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, + {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true}, + {DBName: "my_name2", Name: "Name2", BindNames: []string{"Name2"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "ignore", "COLUMN": "my_name2"}}, + {DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "update", "COLUMN": "my_age"}}, + {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}}, + } + compareFields(schema2.Fields, expectedFields2, t) +} + +func TestCustomizePrimaryKey(t *testing.T) { +} + +func TestCompositePrimaryKeys(t *testing.T) { +} + +//////////////////////////////////////////////////////////////////////////////// +// Test Helpers +//////////////////////////////////////////////////////////////////////////////// +func compareFields(fields []*Field, expectedFields []*Field, t *testing.T) { + if len(fields) != len(expectedFields) { + t.Errorf("expected has %v fields, but got %v", len(expectedFields), len(fields)) + } + + for _, expectedField := range expectedFields { + field := getSchemaField(expectedField.DBName, fields) + if field == nil { + t.Errorf("Field %#v is not found", expectedField.Name) + } else if err := fieldEqual(field, expectedField); err != nil { + t.Error(err) + } + } +} + func fieldEqual(got, expected *Field) error { if expected.DBName != got.DBName { return fmt.Errorf("field DBName should be %v, got %v", expected.DBName, got.DBName) @@ -50,56 +218,3 @@ func fieldEqual(got, expected *Field) error { } return nil } - -func compareFields(fields []*Field, expectedFields []*Field, t *testing.T) { - if len(fields) != len(expectedFields) { - t.Errorf("expected has %v fields, but got %v", len(expectedFields), len(fields)) - } - - for _, expectedField := range expectedFields { - field := getSchemaField(expectedField.DBName, fields) - if field == nil { - t.Errorf("Field %#v is not found", expectedField.Name) - } else if err := fieldEqual(field, expectedField); err != nil { - t.Error(err) - } - } -} - -func TestParse(t *testing.T) { - type MyStruct struct { - ID int - Int uint - IntPointer *uint `gorm:"default:10"` - String string - StringPointer *string `gorm:"column:strp"` - Time time.Time - TimePointer *time.Time - NullInt64 sql.NullInt64 - } - - schema := Parse(&MyStruct{}) - - compareFields(schema.Fields, []*Field{ - {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, - {DBName: "int", Name: "Int", BindNames: []string{"Int"}, IsNormal: true}, - {DBName: "int_pointer", Name: "IntPointer", BindNames: []string{"IntPointer"}, TagSettings: map[string]string{"DEFAULT": "10"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "10"}, - {DBName: "string", Name: "String", BindNames: []string{"String"}, IsNormal: true}, - {DBName: "strp", Name: "StringPointer", BindNames: []string{"StringPointer"}, TagSettings: map[string]string{"COLUMN": "strp"}, IsNormal: true}, - {DBName: "time", Name: "Time", BindNames: []string{"Time"}, IsNormal: true}, - {DBName: "time_pointer", Name: "TimePointer", BindNames: []string{"TimePointer"}, IsNormal: true}, - {DBName: "null_int64", Name: "NullInt64", BindNames: []string{"NullInt64"}, IsNormal: true}, - }, t) -} - -func TestEmbeddedStruct(t *testing.T) { -} - -func TestOverwriteEmbeddedStructFields(t *testing.T) { -} - -func TestCustomizePrimaryKey(t *testing.T) { -} - -func TestCompositePrimaryKeys(t *testing.T) { -}