Add MainPrimaryField method for schema

This commit is contained in:
Jinzhu 2018-02-26 21:30:30 +08:00
parent 26371b617c
commit ef35e82691
3 changed files with 16 additions and 12 deletions

View File

@ -105,7 +105,7 @@ func buildToOneRel(field *Field, sourceSchema *Schema) {
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{getPrimaryPrimaryField(sourceSchema.PrimaryFields).DBName}
associationForeignKeys = []string{sourceSchema.MainPrimaryField().DBName}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length"))
@ -165,7 +165,7 @@ func buildToOneRel(field *Field, sourceSchema *Schema) {
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{getPrimaryPrimaryField(destSchema.PrimaryFields).DBName}
associationForeignKeys = []string{destSchema.MainPrimaryField().DBName}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length"))
@ -332,7 +332,7 @@ func buildToManyRel(field *Field, sourceSchema *Schema) {
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{getPrimaryPrimaryField(sourceSchema.PrimaryFields).DBName}
associationForeignKeys = []string{sourceSchema.MainPrimaryField().DBName}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
sourceSchema.ParseErrors = append(sourceSchema.ParseErrors, errors.New("invalid foreign keys, should have same length"))

View File

@ -242,6 +242,19 @@ func Parse(dest interface{}) *Schema {
return &schema
}
// MainPrimaryField returns main primary field, usually the field with db name "id" or the first primary field
func (schema *Schema) MainPrimaryField() *Field {
for _, field := range schema.PrimaryFields {
if field.DBName == "id" {
return field
}
}
if len(schema.PrimaryFields) > 0 {
return schema.PrimaryFields[0]
}
return nil
}
func (schemaField *Field) clone() *Field {
clone := *schemaField

View File

@ -74,15 +74,6 @@ func getSchemaField(name string, fields []*Field) *Field {
return nil
}
func getPrimaryPrimaryField(fields []*Field) *Field {
for _, field := range fields {
if field.DBName == "id" {
return field
}
}
return fields[0]
}
func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {