From 5883490aa773ad8dbc13c901bb4ffec502417477 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 20 Jun 2020 17:21:01 +0800 Subject: [PATCH] Select, Omit, Preload supports clause.Associations --- callbacks/helper.go | 15 ++++++++++----- callbacks/query.go | 14 +++++++++++--- tests/create_test.go | 24 +++++++++++++++++++++--- tests/preload_test.go | 23 +++++++++++++++++++++++ 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 97c8ad35..3b0cca16 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -19,10 +19,11 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo for _, dbName := range stmt.Schema.DBNames { results[dbName] = true } - break - } - - if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true } else { results[column] = true @@ -31,7 +32,11 @@ func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate boo // omit columns for _, omit := range stmt.Omits { - if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false } else { results[omit] = false diff --git a/callbacks/query.go b/callbacks/query.go index e5557d4a..27d53a4d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -140,9 +140,17 @@ func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if name == clause.Associations { + for _, rel := range db.Statement.Schema.Relationships.Relations { + if rel.Schema == db.Statement.Schema { + preloadMap[rel.Name] = []string{rel.Name} + } + } + } else { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } } } diff --git a/tests/create_test.go b/tests/create_test.go index 351f02a3..4bf623b3 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -6,6 +6,7 @@ import ( "github.com/jinzhu/now" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -282,13 +283,30 @@ func TestOmitWithCreate(t *testing.T) { user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) - var user2 User - DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + var result User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result, user.ID) user.Birthday = nil user.Account = Account{} user.Toys = nil user.Manager = nil - CheckUser(t, user2, user) + CheckUser(t, result, user) + + user2 := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit(clause.Associations).Create(&user2) + + var result2 User + DB.Preload(clause.Associations).First(&result2, user2.ID) + + user2.Account = Account{} + user2.Toys = nil + user2.Manager = nil + user2.Company = Company{} + user2.Pets = nil + user2.Team = nil + user2.Languages = nil + user2.Friends = nil + + CheckUser(t, result2, user2) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 06e38f09..3caa17b4 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -9,6 +9,29 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestPreloadWithAssociations(t *testing.T) { + var user = *GetUser("preload_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + func TestNestedPreload(t *testing.T) { var user = *GetUser("nested_preload", Config{Pets: 2})