Finish implement association support
This commit is contained in:
		
							parent
							
								
									20cb57b1ac
								
							
						
					
					
						commit
						0f21272c7f
					
				
							
								
								
									
										198
									
								
								association.go
									
									
									
									
									
								
							
							
						
						
									
										198
									
								
								association.go
									
									
									
									
									
								
							| @ -1,6 +1,7 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 
 | ||||
| @ -34,16 +35,119 @@ func (db *DB) Association(column string) *Association { | ||||
| 
 | ||||
| func (association *Association) Find(out interface{}, conds ...interface{}) error { | ||||
| 	if association.Error == nil { | ||||
| 		var ( | ||||
| 			tx         = association.DB | ||||
| 			queryConds = association.Relationship.ToQueryConditions(tx.Statement.ReflectValue) | ||||
| 		) | ||||
| 
 | ||||
| 		if association.Relationship.JoinTable != nil { | ||||
| 			for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||
| 				tx.Clauses(queryClause) | ||||
| 			} | ||||
| 
 | ||||
| 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||
| 				Table: clause.Table{Name: association.Relationship.JoinTable.Table}, | ||||
| 				ON:    clause.Where{Exprs: queryConds}, | ||||
| 			}}}) | ||||
| 		} else { | ||||
| 			tx.Clauses(clause.Where{Exprs: queryConds}) | ||||
| 		} | ||||
| 
 | ||||
| 		association.Error = tx.Find(out, conds...).Error | ||||
| 	} | ||||
| 
 | ||||
| 	return association.Error | ||||
| } | ||||
| 
 | ||||
| func (association *Association) Append(values ...interface{}) error { | ||||
| 	if association.Error == nil { | ||||
| 		switch association.Relationship.Type { | ||||
| 		case schema.HasOne, schema.BelongsTo: | ||||
| 			if len(values) > 0 { | ||||
| 				association.Error = association.Replace(values...) | ||||
| 			} | ||||
| 		default: | ||||
| 			association.saveAssociation(false, values...) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return association.Error | ||||
| } | ||||
| 
 | ||||
| func (association *Association) Replace(values ...interface{}) error { | ||||
| 	if association.Error == nil { | ||||
| 		association.saveAssociation(true, values...) | ||||
| 		reflectValue := association.DB.Statement.ReflectValue | ||||
| 		rel := association.Relationship | ||||
| 
 | ||||
| 		switch rel.Type { | ||||
| 		case schema.HasOne, schema.HasMany: | ||||
| 			var ( | ||||
| 				primaryFields []*schema.Field | ||||
| 				foreignKeys   []string | ||||
| 				updateMap     = map[string]interface{}{} | ||||
| 				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface() | ||||
| 			) | ||||
| 
 | ||||
| 			for _, ref := range rel.References { | ||||
| 				if ref.OwnPrimaryKey { | ||||
| 					primaryFields = append(primaryFields, ref.PrimaryKey) | ||||
| 				} else { | ||||
| 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | ||||
| 					updateMap[ref.ForeignKey.DBName] = nil | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			_, values := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) | ||||
| 			column, queryValues := schema.ToQueryValues(foreignKeys, values) | ||||
| 			association.DB.Model(modelValue).Where(clause.IN{Column: column, Values: queryValues}).UpdateColumns(updateMap) | ||||
| 		case schema.Many2Many: | ||||
| 			var primaryFields, relPrimaryFields []*schema.Field | ||||
| 			var foreignKeys, relForeignKeys []string | ||||
| 			modelValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 			conds := []clause.Expression{} | ||||
| 
 | ||||
| 			for _, ref := range rel.References { | ||||
| 				if ref.OwnPrimaryKey { | ||||
| 					primaryFields = append(primaryFields, ref.PrimaryKey) | ||||
| 					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) | ||||
| 				} else if ref.PrimaryValue != "" { | ||||
| 					conds = append(conds, clause.Eq{ | ||||
| 						Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, | ||||
| 						Value:  ref.PrimaryValue, | ||||
| 					}) | ||||
| 				} else { | ||||
| 					relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) | ||||
| 					relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			generateConds := func(rv reflect.Value) { | ||||
| 				_, values := schema.GetIdentityFieldValuesMap(rv, primaryFields) | ||||
| 				column, queryValues := schema.ToQueryValues(foreignKeys, values) | ||||
| 
 | ||||
| 				relValue := rel.Field.ReflectValueOf(rv) | ||||
| 				_, relValues := schema.GetIdentityFieldValuesMap(relValue, relPrimaryFields) | ||||
| 				relColumn, relQueryValues := schema.ToQueryValues(relForeignKeys, relValues) | ||||
| 
 | ||||
| 				conds = append(conds, clause.And( | ||||
| 					clause.IN{Column: column, Values: queryValues}, | ||||
| 					clause.Not(clause.IN{Column: relColumn, Values: relQueryValues}), | ||||
| 				)) | ||||
| 			} | ||||
| 
 | ||||
| 			switch reflectValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				generateConds(reflectValue) | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				for i := 0; i < reflectValue.Len(); i++ { | ||||
| 					generateConds(reflectValue.Index(i)) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			association.DB.Where(conds).Delete(modelValue) | ||||
| 		} | ||||
| 	} | ||||
| 	return association.Error | ||||
| } | ||||
| 
 | ||||
| @ -78,7 +182,7 @@ func (association *Association) Delete(values ...interface{}) error { | ||||
| 		column, values := schema.ToQueryValues(foreignKeys, relQueryValues) | ||||
| 		tx.Where(clause.IN{Column: column, Values: values}) | ||||
| 
 | ||||
| 		switch association.Relationship.Type { | ||||
| 		switch rel.Type { | ||||
| 		case schema.HasOne, schema.HasMany: | ||||
| 			modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() | ||||
| 			tx.Model(modelValue).Clauses(clause.Where{Exprs: conds}).UpdateColumns(updateAttrs) | ||||
| @ -164,3 +268,95 @@ func (association *Association) Count() (count int) { | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (association *Association) saveAssociation(clear bool, values ...interface{}) { | ||||
| 	reflectValue := association.DB.Statement.ReflectValue | ||||
| 
 | ||||
| 	appendToRelations := func(source, rv reflect.Value, clear bool) { | ||||
| 		switch association.Relationship.Type { | ||||
| 		case schema.HasOne, schema.BelongsTo: | ||||
| 			switch rv.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if rv.Len() > 0 { | ||||
| 					association.Error = association.Relationship.Field.Set(source, rv.Index(0)) | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				association.Error = association.Relationship.Field.Set(source, rv) | ||||
| 			} | ||||
| 		case schema.HasMany, schema.Many2Many: | ||||
| 			elemType := association.Relationship.Field.IndirectFieldType.Elem() | ||||
| 			fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(reflectValue)) | ||||
| 			if clear { | ||||
| 				fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType) | ||||
| 			} | ||||
| 
 | ||||
| 			appendToFieldValues := func(ev reflect.Value) { | ||||
| 				if ev.Type().AssignableTo(elemType) { | ||||
| 					fieldValue = reflect.Append(fieldValue, ev) | ||||
| 				} else if ev.Type().Elem().AssignableTo(elemType) { | ||||
| 					fieldValue = reflect.Append(fieldValue, ev.Elem()) | ||||
| 				} else { | ||||
| 					association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			switch rv.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				for i := 0; i < rv.Len(); i++ { | ||||
| 					appendToFieldValues(reflect.Indirect(rv.Index(i))) | ||||
| 				} | ||||
| 			case reflect.Struct: | ||||
| 				appendToFieldValues(rv) | ||||
| 			} | ||||
| 
 | ||||
| 			if association.Error == nil { | ||||
| 				association.Error = association.Relationship.Field.Set(source, fieldValue) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	selectedColumns := []string{association.Relationship.Name} | ||||
| 	hasZero := false | ||||
| 	for _, ref := range association.Relationship.References { | ||||
| 		if !ref.OwnPrimaryKey { | ||||
| 			selectedColumns = append(selectedColumns, ref.ForeignKey.Name) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	switch reflectValue.Kind() { | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 		if len(values) != reflectValue.Len() { | ||||
| 			if clear && len(values) == 0 { | ||||
| 				for i := 0; i < reflectValue.Len(); i++ { | ||||
| 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType)) | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| 			association.Error = errors.New("invalid association values, length doesn't match") | ||||
| 		} | ||||
| 
 | ||||
| 		for i := 0; i < reflectValue.Len(); i++ { | ||||
| 			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) | ||||
| 
 | ||||
| 			if !hasZero { | ||||
| 				_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue.Index(i)) | ||||
| 			} | ||||
| 		} | ||||
| 	case reflect.Struct: | ||||
| 		if clear && len(values) == 0 { | ||||
| 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType)) | ||||
| 		} | ||||
| 
 | ||||
| 		for idx, value := range values { | ||||
| 			appendToRelations(reflectValue, reflect.Indirect(reflect.ValueOf(value)), clear && idx == 0) | ||||
| 		} | ||||
| 
 | ||||
| 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | ||||
| 	} | ||||
| 
 | ||||
| 	if hasZero { | ||||
| 		association.DB.Save(reflectValue.Interface()) | ||||
| 	} else { | ||||
| 		association.DB.Select(selectedColumns).Save(reflectValue.Interface()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -28,7 +28,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice: | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				var ( | ||||
| 					objs      []reflect.Value | ||||
| 					fieldType = rel.Field.FieldType | ||||
| @ -92,7 +92,7 @@ func SaveAfterAssociations(db *gorm.DB) { | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice: | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				var ( | ||||
| 					fieldType = rel.Field.FieldType | ||||
| 					isPtr     = fieldType.Kind() == reflect.Ptr | ||||
| @ -193,7 +193,7 @@ func SaveAfterAssociations(db *gorm.DB) { | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice: | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 					appendToElems(db.Statement.ReflectValue.Index(i)) | ||||
| 				} | ||||
| @ -260,7 +260,7 @@ func SaveAfterAssociations(db *gorm.DB) { | ||||
| 			} | ||||
| 
 | ||||
| 			switch db.Statement.ReflectValue.Kind() { | ||||
| 			case reflect.Slice: | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 					appendToElems(db.Statement.ReflectValue.Index(i)) | ||||
| 				} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu