Make Preload works with has_one, has_many, belongs_to relations
This commit is contained in:
		
							parent
							
								
									dd635a17c2
								
							
						
					
					
						commit
						f97e2c088e
					
				@ -1093,8 +1093,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
 | 
				
			|||||||
  DefaultTimeZone, R/W Splitting, Validation
 | 
					  DefaultTimeZone, R/W Splitting, Validation
 | 
				
			||||||
* Github Pages
 | 
					* Github Pages
 | 
				
			||||||
* AlertColumn, DropColumn
 | 
					* AlertColumn, DropColumn
 | 
				
			||||||
* db.Preload("Addresses.Map", "active = ?", true).Preload("Profile").Find(&users)
 | 
					 | 
				
			||||||
* db.Find(&users).Related(&users)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Author
 | 
					# Author
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										9
									
								
								field.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								field.go
									
									
									
									
									
								
							@ -15,6 +15,15 @@ type relationship struct {
 | 
				
			|||||||
	Kind                  string
 | 
						Kind                  string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FIXME
 | 
				
			||||||
 | 
					func (r relationship) ForeignDBName() string {
 | 
				
			||||||
 | 
						return ToSnake(r.ForeignKey)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r relationship) AssociationForeignDBName(name string) string {
 | 
				
			||||||
 | 
						return ToSnake(r.AssociationForeignKey)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Field struct {
 | 
					type Field struct {
 | 
				
			||||||
	Name         string
 | 
						Name         string
 | 
				
			||||||
	DBName       string
 | 
						DBName       string
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										76
									
								
								preload.go
									
									
									
									
									
								
							
							
						
						
									
										76
									
								
								preload.go
									
									
									
									
									
								
							@ -1,15 +1,30 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql/driver"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getFieldValue(value reflect.Value, field string) interface{} {
 | 
				
			||||||
 | 
						result := value.FieldByName(field).Interface()
 | 
				
			||||||
 | 
						if r, ok := result.(driver.Valuer); ok {
 | 
				
			||||||
 | 
							result, _ = r.Value()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func equalAsString(a interface{}, b interface{}) bool {
 | 
				
			||||||
 | 
						return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Preload(scope *Scope) {
 | 
					func Preload(scope *Scope) {
 | 
				
			||||||
	// Get Fields
 | 
						// Get Fields
 | 
				
			||||||
	var fields map[string]*Field
 | 
						var fields map[string]*Field
 | 
				
			||||||
 | 
						var isSlice bool
 | 
				
			||||||
	if scope.IndirectValue().Kind() == reflect.Slice {
 | 
						if scope.IndirectValue().Kind() == reflect.Slice {
 | 
				
			||||||
 | 
							isSlice = true
 | 
				
			||||||
		elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
 | 
							elem := reflect.New(scope.IndirectValue().Type().Elem()).Elem()
 | 
				
			||||||
		fields = scope.New(elem.Addr().Interface()).Fields()
 | 
							fields = scope.New(elem.Addr().Interface()).Fields()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@ -23,16 +38,69 @@ func Preload(scope *Scope) {
 | 
				
			|||||||
					results := makeSlice(field.Field)
 | 
										results := makeSlice(field.Field)
 | 
				
			||||||
					relation := field.Relationship
 | 
										relation := field.Relationship
 | 
				
			||||||
					primaryName := scope.PrimaryKeyField().Name
 | 
										primaryName := scope.PrimaryKeyField().Name
 | 
				
			||||||
 | 
										associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					switch relation.Kind {
 | 
										switch relation.Kind {
 | 
				
			||||||
					case "has_one":
 | 
										case "has_one":
 | 
				
			||||||
						sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
 | 
											condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
 | 
				
			||||||
						scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
 | 
											scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
											resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
											for i := 0; i < resultValues.Len(); i++ {
 | 
				
			||||||
 | 
												result := resultValues.Index(i)
 | 
				
			||||||
 | 
												if isSlice {
 | 
				
			||||||
 | 
													value := getFieldValue(result, relation.ForeignKey)
 | 
				
			||||||
 | 
													objects := scope.IndirectValue()
 | 
				
			||||||
 | 
													for j := 0; j < objects.Len(); j++ {
 | 
				
			||||||
 | 
														if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
 | 
				
			||||||
 | 
															objects.Index(j).FieldByName(field.Name).Set(result)
 | 
				
			||||||
 | 
															break
 | 
				
			||||||
 | 
														}
 | 
				
			||||||
 | 
													}
 | 
				
			||||||
 | 
												} else {
 | 
				
			||||||
 | 
													scope.SetColumn(field, result)
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					case "has_many":
 | 
										case "has_many":
 | 
				
			||||||
						sql := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignKey))
 | 
											condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName()))
 | 
				
			||||||
						scope.NewDB().Find(results, sql, scope.getColumnAsArray(primaryName))
 | 
											scope.NewDB().Find(results, condition, scope.getColumnAsArray(primaryName))
 | 
				
			||||||
 | 
											resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
											if isSlice {
 | 
				
			||||||
 | 
												for i := 0; i < resultValues.Len(); i++ {
 | 
				
			||||||
 | 
													result := resultValues.Index(i)
 | 
				
			||||||
 | 
													value := getFieldValue(result, relation.ForeignKey)
 | 
				
			||||||
 | 
													objects := scope.IndirectValue()
 | 
				
			||||||
 | 
													for j := 0; j < objects.Len(); j++ {
 | 
				
			||||||
 | 
														object := objects.Index(j)
 | 
				
			||||||
 | 
														if equalAsString(getFieldValue(object, primaryName), value) {
 | 
				
			||||||
 | 
															f := object.FieldByName(field.Name)
 | 
				
			||||||
 | 
															f.Set(reflect.Append(f, result))
 | 
				
			||||||
 | 
															break
 | 
				
			||||||
 | 
														}
 | 
				
			||||||
 | 
													}
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
 | 
												scope.SetColumn(field, resultValues)
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					case "belongs_to":
 | 
										case "belongs_to":
 | 
				
			||||||
						scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
 | 
											scope.NewDB().Find(results, scope.getColumnAsArray(relation.ForeignKey))
 | 
				
			||||||
 | 
											resultValues := reflect.Indirect(reflect.ValueOf(results))
 | 
				
			||||||
 | 
											for i := 0; i < resultValues.Len(); i++ {
 | 
				
			||||||
 | 
												result := resultValues.Index(i)
 | 
				
			||||||
 | 
												if isSlice {
 | 
				
			||||||
 | 
													value := getFieldValue(result, associationPrimaryKey)
 | 
				
			||||||
 | 
													objects := scope.IndirectValue()
 | 
				
			||||||
 | 
													for j := 0; j < objects.Len(); j++ {
 | 
				
			||||||
 | 
														object := objects.Index(j)
 | 
				
			||||||
 | 
														if equalAsString(getFieldValue(object, relation.ForeignKey), value) {
 | 
				
			||||||
 | 
															object.FieldByName(field.Name).Set(result)
 | 
				
			||||||
 | 
															break
 | 
				
			||||||
 | 
														}
 | 
				
			||||||
 | 
													}
 | 
				
			||||||
 | 
												} else {
 | 
				
			||||||
 | 
													scope.SetColumn(field, result)
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					case "many_to_many":
 | 
										case "many_to_many":
 | 
				
			||||||
						scope.Err(errors.New("not supported relation"))
 | 
											scope.Err(errors.New("not supported relation"))
 | 
				
			||||||
					default:
 | 
										default:
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,79 @@
 | 
				
			|||||||
package gorm_test
 | 
					package gorm_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "testing"
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getPreloadUser(name string) User {
 | 
				
			||||||
 | 
						return User{
 | 
				
			||||||
 | 
							Name:            name,
 | 
				
			||||||
 | 
							Role:            Role{"Preload"},
 | 
				
			||||||
 | 
							BillingAddress:  Address{Address1: fmt.Sprintf("Billing Address %v", name)},
 | 
				
			||||||
 | 
							ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
 | 
				
			||||||
 | 
							CreditCard:      CreditCard{Number: fmt.Sprintf("123456%v", name)},
 | 
				
			||||||
 | 
							Emails: []Email{
 | 
				
			||||||
 | 
								{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							Languages: []Language{
 | 
				
			||||||
 | 
								{Name: fmt.Sprintf("lang_1_%v", name)},
 | 
				
			||||||
 | 
								{Name: fmt.Sprintf("lang_2_%v", name)},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkUserHasPreloadData(user User, t *testing.T) {
 | 
				
			||||||
 | 
						u := getPreloadUser(user.Name)
 | 
				
			||||||
 | 
						if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
 | 
				
			||||||
 | 
							t.Error("Failed to preload user's BillingAddress")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
 | 
				
			||||||
 | 
							t.Error("Failed to preload user's ShippingAddress")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if user.CreditCard.Number != u.CreditCard.Number {
 | 
				
			||||||
 | 
							t.Error("Failed to preload user's CreditCard")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(user.Emails) != len(u.Emails) {
 | 
				
			||||||
 | 
							t.Error("Failed to preload user's Emails")
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							var found int
 | 
				
			||||||
 | 
							for _, e1 := range u.Emails {
 | 
				
			||||||
 | 
								for _, e2 := range user.Emails {
 | 
				
			||||||
 | 
									if e1.Email == e2.Email {
 | 
				
			||||||
 | 
										found++
 | 
				
			||||||
 | 
										break
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if found != len(u.Emails) {
 | 
				
			||||||
 | 
								t.Error("Failed to preload user's email details")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestPreload(t *testing.T) {
 | 
					func TestPreload(t *testing.T) {
 | 
				
			||||||
	user := User{Name: "PreloadUser", BillingAddress: Address{Address1: "Billing Address"}, ShippingAddress: Address{Address1: "Shipping Address"}, Languages: []Language{{Name: "Preload L1"}, {Name: "Preload L2"}}}
 | 
						user1 := getPreloadUser("user1")
 | 
				
			||||||
	DB.Save(&user)
 | 
						DB.Save(&user1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var user User
 | 
				
			||||||
 | 
						DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
 | 
				
			||||||
 | 
							Preload("CreditCard").Preload("Emails").Find(&user)
 | 
				
			||||||
 | 
						checkUserHasPreloadData(user, t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						user2 := getPreloadUser("user2")
 | 
				
			||||||
 | 
						DB.Save(&user2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						user3 := getPreloadUser("user3")
 | 
				
			||||||
 | 
						DB.Save(&user3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var users []User
 | 
						var users []User
 | 
				
			||||||
	DB.Preload("BillingAddress").Preload("ShippingAddress").Preload("Languages").Find(&users)
 | 
						DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
 | 
				
			||||||
 | 
							Preload("CreditCard").Preload("Emails").Find(&users)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, user := range users {
 | 
				
			||||||
 | 
							checkUserHasPreloadData(user, t)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user