Refactor Association Mode
This commit is contained in:
parent
8d716be896
commit
41870191b0
185
association.go
185
association.go
@ -1,12 +1,11 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Association Association Mode contains some helper methods to handle relationship things easily.
|
||||||
type Association struct {
|
type Association struct {
|
||||||
Scope *Scope
|
Scope *Scope
|
||||||
Column string
|
Column string
|
||||||
@ -14,86 +13,13 @@ type Association struct {
|
|||||||
Field *Field
|
Field *Field
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) setErr(err error) *Association {
|
// Find find out all related associations
|
||||||
if err != nil {
|
|
||||||
association.Error = err
|
|
||||||
}
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Find(value interface{}) *Association {
|
func (association *Association) Find(value interface{}) *Association {
|
||||||
association.Scope.related(value, association.Column)
|
association.Scope.related(value, association.Column)
|
||||||
return association.setErr(association.Scope.db.Error)
|
return association.setErr(association.Scope.db.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
// Append append new associations for many2many, has_many, will replace current association for has_one, belongs_to
|
||||||
scope := association.Scope
|
|
||||||
field := association.Field
|
|
||||||
relationship := association.Field.Relationship
|
|
||||||
|
|
||||||
saveAssociation := func(reflectValue reflect.Value) {
|
|
||||||
// value has to been pointer
|
|
||||||
if reflectValue.Kind() != reflect.Ptr {
|
|
||||||
reflectPtr := reflect.New(reflectValue.Type())
|
|
||||||
reflectPtr.Elem().Set(reflectValue)
|
|
||||||
reflectValue = reflectPtr
|
|
||||||
}
|
|
||||||
|
|
||||||
// value has to been saved for many2many
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
|
||||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assign Fields
|
|
||||||
var fieldType = field.Field.Type()
|
|
||||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
|
||||||
if reflectValue.Type().AssignableTo(fieldType) {
|
|
||||||
field.Set(reflectValue)
|
|
||||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
|
||||||
// if field's type is struct, then need to set value back to argument after save
|
|
||||||
setFieldBackToValue = true
|
|
||||||
field.Set(reflectValue.Elem())
|
|
||||||
} else if fieldType.Kind() == reflect.Slice {
|
|
||||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
|
||||||
field.Set(reflect.Append(field.Field, reflectValue))
|
|
||||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
|
||||||
// if field's type is slice of struct, then need to set value back to argument after save
|
|
||||||
setSliceFieldBackToValue = true
|
|
||||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
|
||||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
|
||||||
} else {
|
|
||||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
|
||||||
|
|
||||||
if setFieldBackToValue {
|
|
||||||
reflectValue.Elem().Set(field.Field)
|
|
||||||
} else if setSliceFieldBackToValue {
|
|
||||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
reflectValue := reflect.ValueOf(value)
|
|
||||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
|
||||||
if indirectReflectValue.Kind() == reflect.Struct {
|
|
||||||
saveAssociation(reflectValue)
|
|
||||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
|
||||||
saveAssociation(indirectReflectValue.Index(i))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
association.setErr(errors.New("invalid value type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return association
|
|
||||||
}
|
|
||||||
|
|
||||||
func (association *Association) Append(values ...interface{}) *Association {
|
func (association *Association) Append(values ...interface{}) *Association {
|
||||||
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
if relationship := association.Field.Relationship; relationship.Kind == "has_one" {
|
||||||
return association.Replace(values...)
|
return association.Replace(values...)
|
||||||
@ -101,6 +27,7 @@ func (association *Association) Append(values ...interface{}) *Association {
|
|||||||
return association.saveAssociations(values...)
|
return association.saveAssociations(values...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace replace current associations with new one
|
||||||
func (association *Association) Replace(values ...interface{}) *Association {
|
func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
var (
|
var (
|
||||||
relationship = association.Field.Relationship
|
relationship = association.Field.Relationship
|
||||||
@ -115,7 +42,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||||||
|
|
||||||
// Belongs To
|
// Belongs To
|
||||||
if relationship.Kind == "belongs_to" {
|
if relationship.Kind == "belongs_to" {
|
||||||
// Set foreign key to be null only when clearing value
|
// Set foreign key to be null when clearing value (length equals 0)
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
// Set foreign key to be nil
|
// Set foreign key to be nil
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
@ -125,29 +52,21 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Relations
|
// Polymorphic Relations
|
||||||
if relationship.PolymorphicDBName != "" {
|
if relationship.PolymorphicDBName != "" {
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Relations except new created
|
// Relations except new created
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
var newPrimaryKeys [][]interface{}
|
|
||||||
var associationForeignFieldNames []string
|
var associationForeignFieldNames []string
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
// If many to many relations, get it from foreign key
|
|
||||||
associationForeignFieldNames = relationship.AssociationForeignFieldNames
|
associationForeignFieldNames = relationship.AssociationForeignFieldNames
|
||||||
} else {
|
} else {
|
||||||
// If other relations, get real primary keys
|
associationForeignFieldNames = relationship.AssociationForeignDBNames
|
||||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() {
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
newPrimaryKeys := association.getPrimaryKeys(associationForeignFieldNames, field.Interface())
|
||||||
|
|
||||||
if len(newPrimaryKeys) > 0 {
|
if len(newPrimaryKeys) > 0 {
|
||||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||||
@ -156,13 +75,11 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
if sourcePrimaryKeys := association.getPrimaryKeys(relationship.ForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||||
|
}
|
||||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
var foreignKeyMap = map[string]interface{}{}
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
@ -179,6 +96,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||||
func (association *Association) Delete(values ...interface{}) *Association {
|
func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
var (
|
var (
|
||||||
relationship = association.Field.Relationship
|
relationship = association.Field.Relationship
|
||||||
@ -292,10 +210,12 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||||||
return association
|
return association
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear remove relationship between source & current associations, won't delete those associations
|
||||||
func (association *Association) Clear() *Association {
|
func (association *Association) Clear() *Association {
|
||||||
return association.Replace()
|
return association.Replace()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count return the count of current associations
|
||||||
func (association *Association) Count() int {
|
func (association *Association) Count() int {
|
||||||
var (
|
var (
|
||||||
count = 0
|
count = 0
|
||||||
@ -333,78 +253,3 @@ func (association *Association) Count() int {
|
|||||||
|
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
|
||||||
scope := association.Scope
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
|
||||||
if reflectValue.Kind() == reflect.Slice {
|
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
|
||||||
primaryKeys := []interface{}{}
|
|
||||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
|
||||||
for _, column := range columns {
|
|
||||||
if field, ok := newScope.FieldByName(column); ok {
|
|
||||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
|
||||||
} else {
|
|
||||||
primaryKeys = append(primaryKeys, "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
results = append(results, primaryKeys)
|
|
||||||
}
|
|
||||||
} else if reflectValue.Kind() == reflect.Struct {
|
|
||||||
newScope := scope.New(value)
|
|
||||||
var primaryKeys []interface{}
|
|
||||||
for _, column := range columns {
|
|
||||||
if field, ok := newScope.FieldByName(column); ok {
|
|
||||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
|
||||||
} else {
|
|
||||||
primaryKeys = append(primaryKeys, "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
results = append(results, primaryKeys)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
|
||||||
var results []string
|
|
||||||
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
var marks []string
|
|
||||||
for _ = range primaryValue {
|
|
||||||
marks = append(marks, "?")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(marks) > 1 {
|
|
||||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
|
||||||
} else {
|
|
||||||
results = append(results, strings.Join(marks, ""))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(results, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryCondition(scope *Scope, columns []string) string {
|
|
||||||
var newColumns []string
|
|
||||||
for _, column := range columns {
|
|
||||||
newColumns = append(newColumns, scope.Quote(column))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) > 1 {
|
|
||||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
|
||||||
}
|
|
||||||
return strings.Join(newColumns, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
|
||||||
for _, primaryValue := range primaryValues {
|
|
||||||
for _, value := range primaryValue {
|
|
||||||
values = append(values, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
158
association_utils.go
Normal file
158
association_utils.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (association *Association) setErr(err error) *Association {
|
||||||
|
if err != nil {
|
||||||
|
association.Error = err
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||||
|
scope := association.Scope
|
||||||
|
field := association.Field
|
||||||
|
relationship := association.Field.Relationship
|
||||||
|
|
||||||
|
saveAssociation := func(reflectValue reflect.Value) {
|
||||||
|
// value has to been pointer
|
||||||
|
if reflectValue.Kind() != reflect.Ptr {
|
||||||
|
reflectPtr := reflect.New(reflectValue.Type())
|
||||||
|
reflectPtr.Elem().Set(reflectValue)
|
||||||
|
reflectValue = reflectPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
// value has to been saved for many2many
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||||
|
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign Fields
|
||||||
|
var fieldType = field.Field.Type()
|
||||||
|
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType) {
|
||||||
|
field.Set(reflectValue)
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||||
|
// if field's type is struct, then need to set value back to argument after save
|
||||||
|
setFieldBackToValue = true
|
||||||
|
field.Set(reflectValue.Elem())
|
||||||
|
} else if fieldType.Kind() == reflect.Slice {
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue))
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||||
|
// if field's type is slice of struct, then need to set value back to argument after save
|
||||||
|
setSliceFieldBackToValue = true
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||||
|
} else {
|
||||||
|
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||||
|
|
||||||
|
if setFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field)
|
||||||
|
} else if setSliceFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
reflectValue := reflect.ValueOf(value)
|
||||||
|
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||||
|
if indirectReflectValue.Kind() == reflect.Struct {
|
||||||
|
saveAssociation(reflectValue)
|
||||||
|
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||||
|
saveAssociation(indirectReflectValue.Index(i))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
association.setErr(errors.New("invalid value type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) {
|
||||||
|
scope := association.Scope
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
if reflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
|
primaryKeys := []interface{}{}
|
||||||
|
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||||
|
for _, column := range columns {
|
||||||
|
if field, ok := newScope.FieldByName(column); ok {
|
||||||
|
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||||
|
} else {
|
||||||
|
primaryKeys = append(primaryKeys, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results = append(results, primaryKeys)
|
||||||
|
}
|
||||||
|
} else if reflectValue.Kind() == reflect.Struct {
|
||||||
|
newScope := scope.New(value)
|
||||||
|
var primaryKeys []interface{}
|
||||||
|
for _, column := range columns {
|
||||||
|
if field, ok := newScope.FieldByName(column); ok {
|
||||||
|
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||||
|
} else {
|
||||||
|
primaryKeys = append(primaryKeys, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, primaryKeys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
var marks []string
|
||||||
|
for _ = range primaryValue {
|
||||||
|
marks = append(marks, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(marks) > 1 {
|
||||||
|
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||||
|
} else {
|
||||||
|
results = append(results, strings.Join(marks, ""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(results, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryCondition(scope *Scope, columns []string) string {
|
||||||
|
var newColumns []string
|
||||||
|
for _, column := range columns {
|
||||||
|
newColumns = append(newColumns, scope.Quote(column))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) > 1 {
|
||||||
|
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||||
|
}
|
||||||
|
return strings.Join(newColumns, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
for _, value := range primaryValue {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user