Finish implement association support
This commit is contained in:
		
							parent
							
								
									20cb57b1ac
								
							
						
					
					
						commit
						0f21272c7f
					
				
							
								
								
									
										198
									
								
								association.go
									
									
									
									
									
								
							
							
						
						
									
										198
									
								
								association.go
									
									
									
									
									
								
							| @ -1,6 +1,7 @@ | |||||||
| package gorm | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 
 | 
 | ||||||
| @ -34,16 +35,119 @@ func (db *DB) Association(column string) *Association { | |||||||
| 
 | 
 | ||||||
| func (association *Association) Find(out interface{}, conds ...interface{}) error { | func (association *Association) Find(out interface{}, conds ...interface{}) error { | ||||||
| 	if association.Error == nil { | 	if association.Error == nil { | ||||||
|  | 		var ( | ||||||
|  | 			tx         = association.DB | ||||||
|  | 			queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		if association.Relationship.JoinTable != nil { | ||||||
|  | 			for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||||
|  | 				tx.Clauses(queryClause) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||||
|  | 				Table: clause.Table{Name: association.Relationship.JoinTable.Table}, | ||||||
|  | 				ON:    clause.Where{Exprs: queryConds}, | ||||||
|  | 			}}}) | ||||||
|  | 		} else { | ||||||
|  | 			tx.Clauses(clause.Where{Exprs: queryConds}) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		association.Error = tx.Find(out, conds...).Error | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return association.Error | 	return association.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (association *Association) Append(values ...interface{}) error { | func (association *Association) Append(values ...interface{}) error { | ||||||
|  | 	if association.Error == nil { | ||||||
|  | 		switch association.Relationship.Type { | ||||||
|  | 		case schema.HasOne, schema.BelongsTo: | ||||||
|  | 			if len(values) > 0 { | ||||||
|  | 				association.Error = association.Replace(values...) | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			association.saveAssociation(false, values...) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return association.Error | 	return association.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (association *Association) Replace(values ...interface{}) error { | func (association *Association) Replace(values ...interface{}) error { | ||||||
|  | 	if association.Error == nil { | ||||||
|  | 		association.saveAssociation(true, values...) | ||||||
|  | 		reflectValue := association.DB.Statement.ReflectValue | ||||||
|  | 		rel := association.Relationship | ||||||
|  | 
 | ||||||
|  | 		switch rel.Type { | ||||||
|  | 		case schema.HasOne, schema.HasMany: | ||||||
|  | 			var ( | ||||||
|  | 				primaryFields []*schema.Field | ||||||
|  | 				foreignKeys   []string | ||||||
|  | 				updateMap     = map[string]interface{}{} | ||||||
|  | 				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			for _, ref := range rel.References { | ||||||
|  | 				if ref.OwnPrimaryKey { | ||||||
|  | 					primaryFields = append(primaryFields, ref.PrimaryKey) | ||||||
|  | 				} else { | ||||||
|  | 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | ||||||
|  | 					updateMap[ref.ForeignKey.DBName] = nil | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) | ||||||
|  | 			column, queryValues := schema.ToQueryValues(foreignKeys, values) | ||||||
|  | 			association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) | ||||||
|  | 		case schema.Many2Many: | ||||||
|  | 			var primaryFields, relPrimaryFields []*schema.Field | ||||||
|  | 			var foreignKeys, relForeignKeys []string | ||||||
|  | 			modelValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 			conds := []clause.Expression{} | ||||||
|  | 
 | ||||||
|  | 			for _, ref := range rel.References { | ||||||
|  | 				if ref.OwnPrimaryKey { | ||||||
|  | 					primaryFields = append(primaryFields, ref.PrimaryKey) | ||||||
|  | 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | ||||||
|  | 				} else if ref.PrimaryValue != "" { | ||||||
|  | 					conds = append(conds, clause.Eq{ | ||||||
|  | 						Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||||
|  | 						Value:  ref.PrimaryValue, | ||||||
|  | 					}) | ||||||
|  | 				} else { | ||||||
|  | 					relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) | ||||||
|  | 					relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			generateConds := func(rv reflect.Value) { | ||||||
|  | 				_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) | ||||||
|  | 				column, queryValues := schema.ToQueryValues(foreignKeys, values) | ||||||
|  | 
 | ||||||
|  | 				relValue := rel.Field.ReflectValueOf(rv) | ||||||
|  | 				_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) | ||||||
|  | 				relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) | ||||||
|  | 
 | ||||||
|  | 				conds = append(conds, clause.And( | ||||||
|  | 					clause.IN{Column: column, Values: queryValues}, | ||||||
|  | 					clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), | ||||||
|  | 				)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			switch reflectValue.Kind() { | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				generateConds(reflectValue) | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 					generateConds(reflectValue.Index(i)) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			association.DB.Where(conds).Delete(modelValue) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	return association.Error | 	return association.Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -78,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) error { | |||||||
| 		column, values := schema.ToQueryValues(foreignKeys, relQueryValues) | 		column, values := schema.ToQueryValues(foreignKeys, relQueryValues) | ||||||
| 		tx.Where(clause.IN{Column: column, Values: values}) | 		tx.Where(clause.IN{Column: column, Values: values}) | ||||||
| 
 | 
 | ||||||
| 		switch association.Relationship.Type { | 		switch rel.Type { | ||||||
| 		case schema.HasOne, schema.HasMany: | 		case schema.HasOne, schema.HasMany: | ||||||
| 			modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | 			modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||||
| 			tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | 			tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | ||||||
| @ -164,3 +268,95 @@ func (association *Association) Count() (count int) { | |||||||
| 
 | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (association *Association) saveAssociation(clear bool, values ...interface{}) { | ||||||
|  | 	reflectValue := association.DB.Statement.ReflectValue | ||||||
|  | 
 | ||||||
|  | 	appendToRelations := func(source, rv reflect.Value, clear bool) { | ||||||
|  | 		switch association.Relationship.Type { | ||||||
|  | 		case schema.HasOne, schema.BelongsTo: | ||||||
|  | 			switch rv.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				if rv.Len() > 0 { | ||||||
|  | 					association.Error = association.Relationship.Field.Set(source, rv.Index(0)) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				association.Error = association.Relationship.Field.Set(source, rv) | ||||||
|  | 			} | ||||||
|  | 		case schema.HasMany, schema.Many2Many: | ||||||
|  | 			elemType := association.Relationship.Field.IndirectFieldType.Elem() | ||||||
|  | 			fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) | ||||||
|  | 			if clear { | ||||||
|  | 				fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			appendToFieldValues := func(ev reflect.Value) { | ||||||
|  | 				if ev.Type().AssignableTo(elemType) { | ||||||
|  | 					fieldValue = reflect.Append(fieldValue, ev) | ||||||
|  | 				} else if ev.Type().Elem().AssignableTo(elemType) { | ||||||
|  | 					fieldValue = reflect.Append(fieldValue, ev.Elem()) | ||||||
|  | 				} else { | ||||||
|  | 					association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			switch rv.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				for i := 0; i < rv.Len(); i++ { | ||||||
|  | 					appendToFieldValues(reflect.Indirect(rv.Index(i))) | ||||||
|  | 				} | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				appendToFieldValues(rv) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if association.Error == nil { | ||||||
|  | 				association.Error = association.Relationship.Field.Set(source, fieldValue) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	selectedColumns := []string{association.Relationship.Name} | ||||||
|  | 	hasZero := false | ||||||
|  | 	for _, ref := range association.Relationship.References { | ||||||
|  | 		if !ref.OwnPrimaryKey { | ||||||
|  | 			selectedColumns = append(selectedColumns, ref.ForeignKey.Name) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch reflectValue.Kind() { | ||||||
|  | 	case reflect.Slice, reflect.Array: | ||||||
|  | 		if len(values) != reflectValue.Len() { | ||||||
|  | 			if clear && len(values) == 0 { | ||||||
|  | 				for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) | ||||||
|  | 				} | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			association.Error = errors.New("invalid association values, length doesn't match") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) | ||||||
|  | 
 | ||||||
|  | 			if !hasZero { | ||||||
|  | 				_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	case reflect.Struct: | ||||||
|  | 		if clear && len(values) == 0 { | ||||||
|  | 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for idx, value := range values { | ||||||
|  | 			appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if hasZero { | ||||||
|  | 		association.DB.Save(reflectValue.Interface()) | ||||||
|  | 	} else { | ||||||
|  | 		association.DB.Select(selectedColumns).Save(reflectValue.Interface()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice: | 			case reflect.Slice, reflect.Array: | ||||||
| 				var ( | 				var ( | ||||||
| 					objs      []reflect.Value | 					objs      []reflect.Value | ||||||
| 					fieldType = rel.Field.FieldType | 					fieldType = rel.Field.FieldType | ||||||
| @ -92,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice: | 			case reflect.Slice, reflect.Array: | ||||||
| 				var ( | 				var ( | ||||||
| 					fieldType = rel.Field.FieldType | 					fieldType = rel.Field.FieldType | ||||||
| 					isPtr     = fieldType.Kind() == reflect.Ptr | 					isPtr     = fieldType.Kind() == reflect.Ptr | ||||||
| @ -193,7 +193,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					appendToElems(db.Statement.ReflectValue.Index(i)) | 					appendToElems(db.Statement.ReflectValue.Index(i)) | ||||||
| 				} | 				} | ||||||
| @ -260,7 +260,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			switch db.Statement.ReflectValue.Kind() { | 			switch db.Statement.ReflectValue.Kind() { | ||||||
| 			case reflect.Slice: | 			case reflect.Slice, reflect.Array: | ||||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 					appendToElems(db.Statement.ReflectValue.Index(i)) | 					appendToElems(db.Statement.ReflectValue.Index(i)) | ||||||
| 				} | 				} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu