From ef35e82691d04d1332e80d488549e2bb3a6d4770 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 26 Feb 2018 21:30:30 +0800 Subject: [PATCH] Add MainPrimaryField method for schema --- schema/relationship.go | 6 +++--- schema/schema.go | 13 +++++++++++++ schema/utils.go | 9 --------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 87aba5be..d907eba8 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -105,7 +105,7 @@ func buildToOneRel(field *Field, sourceSchema *Schema) { } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{getPrimaryPrimaryField(sourceSchema.PrimaryFields).DBName} + associationForeignKeys = []string{sourceSchema.MainPrimaryField().DBName} } } else if len(foreignKeys) != len(associationForeignKeys) { sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length")) @@ -165,7 +165,7 @@ func buildToOneRel(field *Field, sourceSchema *Schema) { } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{getPrimaryPrimaryField(destSchema.PrimaryFields).DBName} + associationForeignKeys = []string{destSchema.MainPrimaryField().DBName} } } else if len(foreignKeys) != len(associationForeignKeys) { sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length")) @@ -332,7 +332,7 @@ func buildToManyRel(field *Field, sourceSchema *Schema) { } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{getPrimaryPrimaryField(sourceSchema.PrimaryFields).DBName} + associationForeignKeys = []string{sourceSchema.MainPrimaryField().DBName} } } else if len(foreignKeys) != len(associationForeignKeys) { sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length")) diff --git a/schema/schema.go b/schema/schema.go index c01b67cc..ee5f5c54 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -242,6 +242,19 @@ func Parse(dest interface{}) *Schema { return &schema } +// MainPrimaryField returns main primary field, usually the field with db name "id" or the first primary field +func (schema *Schema) MainPrimaryField() *Field { + for _, field := range schema.PrimaryFields { + if field.DBName == "id" { + return field + } + } + if len(schema.PrimaryFields) > 0 { + return schema.PrimaryFields[0] + } + return nil +} + func (schemaField *Field) clone() *Field { clone := *schemaField diff --git a/schema/utils.go b/schema/utils.go index 1e11a516..5ae703a0 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -74,15 +74,6 @@ func getSchemaField(name string, fields []*Field) *Field { return nil } -func getPrimaryPrimaryField(fields []*Field) *Field { - for _, field := range fields { - if field.DBName == "id" { - return field - } - } - return fields[0] -} - func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {