Test on_embedded_conflict mode with primary key
This commit is contained in:
		
							parent
							
								
									45df6277c4
								
							
						
					
					
						commit
						3c9dee889f
					
				@ -3,7 +3,7 @@ package utils
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
	"github.com/jinzhu/gorm/builder"
 | 
			
		||||
	"github.com/jinzhu/gorm/model"
 | 
			
		||||
	"github.com/jinzhu/gorm/schema"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DefaultTableNameHandler default table name handler
 | 
			
		||||
@ -12,7 +12,7 @@ var DefaultTableNameHandler = func(stmt *builder.Statement, tableName string) st
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetCreatingAssignments get creating assignments
 | 
			
		||||
func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []model.Field {
 | 
			
		||||
func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []schema.Field {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -160,6 +160,20 @@ func Parse(dest interface{}) *Schema {
 | 
			
		||||
	if len(onConflictFields) > 0 {
 | 
			
		||||
		removeIdx := []int{}
 | 
			
		||||
 | 
			
		||||
		updatePrimaryKey := func(field, conflictField *Field) {
 | 
			
		||||
			if field != nil && field.IsPrimaryKey {
 | 
			
		||||
				for i, p := range schema.PrimaryFields {
 | 
			
		||||
					if p == field {
 | 
			
		||||
						schema.PrimaryFields = append(schema.PrimaryFields[0:i], schema.PrimaryFields[i+1:]...)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if conflictField != nil && conflictField.IsPrimaryKey {
 | 
			
		||||
				schema.PrimaryFields = append(schema.PrimaryFields, conflictField)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, idx := range onConflictFields {
 | 
			
		||||
			conflictField := schema.Fields[idx]
 | 
			
		||||
 | 
			
		||||
@ -167,13 +181,33 @@ func Parse(dest interface{}) *Schema {
 | 
			
		||||
				if i != idx && conflictField.Name == field.Name {
 | 
			
		||||
					switch conflictField.TagSettings["ON_EMBEDDED_CONFLICT"] {
 | 
			
		||||
					case "replace":
 | 
			
		||||
						// if original field is primary key, delete origianl one
 | 
			
		||||
						// add conflicated one if it is primary key
 | 
			
		||||
						if field.IsPrimaryKey {
 | 
			
		||||
							updatePrimaryKey(field, conflictField)
 | 
			
		||||
						}
 | 
			
		||||
						removeIdx = append(removeIdx, i)
 | 
			
		||||
					case "ignore":
 | 
			
		||||
						// skip ignored field
 | 
			
		||||
						updatePrimaryKey(conflictField, nil)
 | 
			
		||||
						removeIdx = append(removeIdx, idx)
 | 
			
		||||
					case "update":
 | 
			
		||||
						for key, value := range conflictField.TagSettings {
 | 
			
		||||
							field.TagSettings[key] = value
 | 
			
		||||
						// if original field is primary key, delete origianl one
 | 
			
		||||
						// add conflicated one if it is primary key
 | 
			
		||||
						if field.IsPrimaryKey {
 | 
			
		||||
							updatePrimaryKey(field, conflictField)
 | 
			
		||||
						}
 | 
			
		||||
						for key, value := range field.TagSettings {
 | 
			
		||||
							if _, ok := conflictField.TagSettings[key]; !ok {
 | 
			
		||||
								conflictField.TagSettings[key] = value
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						conflictField.BindNames = field.BindNames
 | 
			
		||||
						if column, ok := conflictField.TagSettings["COLUMN"]; ok {
 | 
			
		||||
							conflictField.DBName = column
 | 
			
		||||
						}
 | 
			
		||||
						*field = *conflictField
 | 
			
		||||
						removeIdx = append(removeIdx, idx)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,32 @@ func TestParse(t *testing.T) {
 | 
			
		||||
	}, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomizePrimaryKey(t *testing.T) {
 | 
			
		||||
	// on_embedded_conflict replace, ignore mode
 | 
			
		||||
	type MyStruct struct {
 | 
			
		||||
		ID    string
 | 
			
		||||
		Name  string
 | 
			
		||||
		Email string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	schema := Parse(&MyStruct{})
 | 
			
		||||
	expectedFields := []*Field{
 | 
			
		||||
		{DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema.PrimaryFields, expectedFields, t)
 | 
			
		||||
 | 
			
		||||
	type MyStruct2 struct {
 | 
			
		||||
		ID    string
 | 
			
		||||
		Name  string
 | 
			
		||||
		Email string `gorm:"primary_key;"`
 | 
			
		||||
	}
 | 
			
		||||
	schema2 := Parse(&MyStruct2{})
 | 
			
		||||
	expectedFields2 := []*Field{
 | 
			
		||||
		{DBName: "email", Name: "Email", BindNames: []string{"Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema2.PrimaryFields, expectedFields2, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEmbeddedStruct(t *testing.T) {
 | 
			
		||||
	// Anonymous Embedded
 | 
			
		||||
	type EmbedStruct struct {
 | 
			
		||||
@ -146,16 +172,60 @@ func TestOverwriteEmbeddedStructFields(t *testing.T) {
 | 
			
		||||
		{DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true},
 | 
			
		||||
		{DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true},
 | 
			
		||||
		{DBName: "my_name2", Name: "Name2", BindNames: []string{"Name2"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "ignore", "COLUMN": "my_name2"}},
 | 
			
		||||
		{DBName: "my_age", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "update", "COLUMN": "my_age"}},
 | 
			
		||||
		{DBName: "my_age2", Name: "Age", BindNames: []string{"EmbedStruct", "Age"}, IsNormal: true, TagSettings: map[string]string{"ON_EMBEDDED_CONFLICT": "update", "COLUMN": "my_age2"}},
 | 
			
		||||
		{DBName: "role", Name: "Role", BindNames: []string{"EmbedStruct", "Role"}, IsNormal: true, HasDefaultValue: true, DefaultValue: "guest", TagSettings: map[string]string{"COLUMN": "Role"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema2.Fields, expectedFields2, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCustomizePrimaryKey(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
func TestOverwriteEmbeddedStructPrimaryFields(t *testing.T) {
 | 
			
		||||
	type EmbedStruct struct {
 | 
			
		||||
		Name  string `gorm:"primary_key"`
 | 
			
		||||
		Email string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
func TestCompositePrimaryKeys(t *testing.T) {
 | 
			
		||||
	// on_embedded_conflict replace, ignore mode
 | 
			
		||||
	type MyStruct struct {
 | 
			
		||||
		ID string
 | 
			
		||||
		EmbedStruct
 | 
			
		||||
		Name  string `gorm:"on_embedded_conflict:update;column:my_name"`
 | 
			
		||||
		Email string `gorm:"primary_key;on_embedded_conflict:ignore;column:my_email"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	schema := Parse(&MyStruct{})
 | 
			
		||||
	expectedFields := []*Field{
 | 
			
		||||
		{DBName: "id", Name: "ID", BindNames: []string{"ID"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema.PrimaryFields, expectedFields, t)
 | 
			
		||||
 | 
			
		||||
	// on_embedded_conflict update mode, ignore mode w/o corresponding field
 | 
			
		||||
	type MyStruct2 struct {
 | 
			
		||||
		ID string
 | 
			
		||||
		EmbedStruct
 | 
			
		||||
		Name  string `gorm:"on_embedded_conflict:ignore;column:my_name2"`
 | 
			
		||||
		Email string `gorm:"primary_key;on_embedded_conflict:update"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	schema2 := Parse(&MyStruct2{})
 | 
			
		||||
	expectedFields2 := []*Field{
 | 
			
		||||
		{DBName: "name", Name: "Name", BindNames: []string{"EmbedStruct", "Name"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY"}},
 | 
			
		||||
		{DBName: "email", Name: "Email", BindNames: []string{"EmbedStruct", "Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY", "ON_EMBEDDED_CONFLICT": "update"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema2.PrimaryFields, expectedFields2, t)
 | 
			
		||||
 | 
			
		||||
	// on_embedded_conflict update mode, ignore mode w/o corresponding field
 | 
			
		||||
	type MyStruct3 struct {
 | 
			
		||||
		ID string
 | 
			
		||||
		EmbedStruct
 | 
			
		||||
		Name  string `gorm:"on_embedded_conflict:replace;column:my_name2"`
 | 
			
		||||
		Email string `gorm:"primary_key;on_embedded_conflict:replace"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	schema3 := Parse(&MyStruct3{})
 | 
			
		||||
	expectedFields3 := []*Field{
 | 
			
		||||
		{DBName: "email", Name: "Email", BindNames: []string{"Email"}, IsNormal: true, IsPrimaryKey: true, TagSettings: map[string]string{"PRIMARY_KEY": "PRIMARY_KEY", "ON_EMBEDDED_CONFLICT": "update"}},
 | 
			
		||||
	}
 | 
			
		||||
	compareFields(schema3.PrimaryFields, expectedFields3, t)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
@ -163,7 +233,14 @@ func TestCompositePrimaryKeys(t *testing.T) {
 | 
			
		||||
////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
func compareFields(fields []*Field, expectedFields []*Field, t *testing.T) {
 | 
			
		||||
	if len(fields) != len(expectedFields) {
 | 
			
		||||
		t.Errorf("expected has %v fields, but got %v", len(expectedFields), len(fields))
 | 
			
		||||
		var exptectedNames, gotNames []string
 | 
			
		||||
		for _, field := range fields {
 | 
			
		||||
			gotNames = append(gotNames, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
		for _, field := range expectedFields {
 | 
			
		||||
			exptectedNames = append(exptectedNames, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
		t.Errorf("expected has %v (%#v) fields, but got %v (%#v)", len(expectedFields), exptectedNames, len(fields), gotNames)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, expectedField := range expectedFields {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user