diff --git a/dialects/common/utils/statement.go b/dialects/common/utils/statement.go index 487a90d6..319c5b31 100644 --- a/dialects/common/utils/statement.go +++ b/dialects/common/utils/statement.go @@ -3,7 +3,7 @@ package utils import ( "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/builder" - "github.com/jinzhu/gorm/model" + "github.com/jinzhu/gorm/schema" ) // DefaultTableNameHandler default table name handler @@ -12,7 +12,7 @@ var DefaultTableNameHandler = func(stmt *builder.Statement, tableName string) st } // GetCreatingAssignments get creating assignments -func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []model.Field { +func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []schema.Field { return nil } diff --git a/schema/schema.go b/schema/schema.go index 6df54961..1420d997 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -160,6 +160,20 @@ func Parse(dest interface{}) *Schema { if len(onConflictFields) > 0 { removeIdx := []int{} + updatePrimaryKey := func(field, conflictField *Field) { + if field != nil && field.IsPrimaryKey { + for i, p := range schema.PrimaryFields { + if p == field { + schema.PrimaryFields = append(schema.PrimaryFields[0:i], schema.PrimaryFields[i+1:]...) + } + } + } + + if conflictField != nil && conflictField.IsPrimaryKey { + schema.PrimaryFields = append(schema.PrimaryFields, conflictField) + } + } + for _, idx := range onConflictFields { conflictField := schema.Fields[idx] @@ -167,13 +181,33 @@ func Parse(dest interface{}) *Schema { if i != idx && conflictField.Name == field.Name { switch conflictField.TagSettings["ON_EMBEDDED_CONFLICT"] { case "replace": + // if original field is primary key, delete origianl one + // add conflicated one if it is primary key + if field.IsPrimaryKey { + updatePrimaryKey(field, conflictField) + } removeIdx = append(removeIdx, i) case "ignore": + // skip ignored field + updatePrimaryKey(conflictField, nil) removeIdx = append(removeIdx, idx) case "update": - for key, value := range conflictField.TagSettings { - field.TagSettings[key] = value + // if original field is primary key, delete origianl one + // add conflicated one if it is primary key + if field.IsPrimaryKey { + updatePrimaryKey(field, conflictField) } + for key, value := range field.TagSettings { + if _, ok := conflictField.TagSettings[key]; !ok { + conflictField.TagSettings[key] = value + } + } + + conflictField.BindNames = field.BindNames + if column, ok := conflictField.TagSettings["COLUMN"]; ok { + conflictField.DBName = column + } + *field = *conflictField removeIdx = append(removeIdx, idx) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index e8247334..a2f1d24c 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -34,6 +34,32 @@ func TestParse(t *testing.T) { }, t) } +func TestCustomizePrimaryKey(t *testing.T) { + // on_embedded_conflict replace, ignore mode + type MyStruct struct { + ID string + Name string + Email string + } + + schema := Parse(&MyStruct{}) + expectedFields := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}}, + } + compareFields(schema.PrimaryFields, expectedFields, t) + + type MyStruct2 struct { + ID string + Name string + Email string `gorm:"primary_key;"` + } + schema2 := Parse(&MyStruct2{}) + expectedFields2 := []*Field{ + {DBName: "email", Name: "Email", BindNames: []string{"Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}}, + } + compareFields(schema2.PrimaryFields, expectedFields2, t) +} + func TestEmbeddedStruct(t *testing.T) { // Anonymous Embedded type EmbedStruct struct { @@ -146,16 +172,60 @@ func TestOverwriteEmbeddedStructFields(t *testing.T) { {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true}, {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true}, {DBName: "my_name2", Name: "Name2", BindNames: []string{"Name2"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "ignore", "COLUMN": "my_name2"}}, - {DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "update", "COLUMN": "my_age"}}, + {DBName: "my_age2", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "update", "COLUMN": "my_age2"}}, {DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}}, } compareFields(schema2.Fields, expectedFields2, t) } -func TestCustomizePrimaryKey(t *testing.T) { -} +func TestOverwriteEmbeddedStructPrimaryFields(t *testing.T) { + type EmbedStruct struct { + Name string `gorm:"primary_key"` + Email string + } -func TestCompositePrimaryKeys(t *testing.T) { + // on_embedded_conflict replace, ignore mode + type MyStruct struct { + ID string + EmbedStruct + Name string `gorm:"on_embedded_conflict:update;column:my_name"` + Email string `gorm:"primary_key;on_embedded_conflict:ignore;column:my_email"` + } + + schema := Parse(&MyStruct{}) + expectedFields := []*Field{ + {DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}}, + } + compareFields(schema.PrimaryFields, expectedFields, t) + + // on_embedded_conflict update mode, ignore mode w/o corresponding field + type MyStruct2 struct { + ID string + EmbedStruct + Name string `gorm:"on_embedded_conflict:ignore;column:my_name2"` + Email string `gorm:"primary_key;on_embedded_conflict:update"` + } + + schema2 := Parse(&MyStruct2{}) + expectedFields2 := []*Field{ + {DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}}, + {DBName: "email", Name: "Email", BindNames: []string{"EmbedStruct", "Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY", "ON_EMBEDDED_CONFLICT": "update"}}, + } + compareFields(schema2.PrimaryFields, expectedFields2, t) + + // on_embedded_conflict update mode, ignore mode w/o corresponding field + type MyStruct3 struct { + ID string + EmbedStruct + Name string `gorm:"on_embedded_conflict:replace;column:my_name2"` + Email string `gorm:"primary_key;on_embedded_conflict:replace"` + } + + schema3 := Parse(&MyStruct3{}) + expectedFields3 := []*Field{ + {DBName: "email", Name: "Email", BindNames: []string{"Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY", "ON_EMBEDDED_CONFLICT": "update"}}, + } + compareFields(schema3.PrimaryFields, expectedFields3, t) } //////////////////////////////////////////////////////////////////////////////// @@ -163,7 +233,14 @@ func TestCompositePrimaryKeys(t *testing.T) { //////////////////////////////////////////////////////////////////////////////// func compareFields(fields []*Field, expectedFields []*Field, t *testing.T) { if len(fields) != len(expectedFields) { - t.Errorf("expected has %v fields, but got %v", len(expectedFields), len(fields)) + var exptectedNames, gotNames []string + for _, field := range fields { + gotNames = append(gotNames, field.Name) + } + for _, field := range expectedFields { + exptectedNames = append(exptectedNames, field.Name) + } + t.Errorf("expected has %v (%#v) fields, but got %v (%#v)", len(expectedFields), exptectedNames, len(fields), gotNames) } for _, expectedField := range expectedFields {