From 4a301464df54108a973e186b8c548239a585209a Mon Sep 17 00:00:00 2001 From: Roy Reznik Date: Wed, 6 May 2020 08:48:46 +0300 Subject: [PATCH] Fixed implementation & tests --- callback_save.go | 2 +- callback_update.go | 2 +- callbacks_test.go | 4 ++-- main_test.go | 20 +++++++++++++++++++- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/callback_save.go b/callback_save.go index 3b4e0589..892a2541 100644 --- a/callback_save.go +++ b/callback_save.go @@ -129,7 +129,7 @@ func saveAfterAssociationsCallback(scope *Scope) { scope.Err(newDB.Save(elem).Error) } } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) + scope.Err(newScope.DB().Updates(elem).Error) } if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { diff --git a/callback_update.go b/callback_update.go index 4a2b6682..699e534b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -75,7 +75,7 @@ func updateCallback(scope *Scope) { } else { for _, field := range scope.Fields() { if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal && (!field.IsBlank || field.HasDefaultValue || field.IsForeignKey) { + if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } diff --git a/callbacks_test.go b/callbacks_test.go index 62cb1660..bebd0e38 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -98,12 +98,12 @@ func TestRunCallbacks(t *testing.T) { } DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 0, 1, 0, 0, 0, 0, 2}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) } DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 0, 1, 0, 0, 1, 1, 2}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) } diff --git a/main_test.go b/main_test.go index c5e19c59..6585fc3f 100644 --- a/main_test.go +++ b/main_test.go @@ -841,9 +841,27 @@ func TestJoinsWithSelect(t *testing.T) { } DB.Save(&user) + validateEmails := func(results []result, emails []string) bool { + if len(results) != len(emails) { + return false + } + for _, r := range results { + containsEmail := false + for _, email := range emails { + if email == r.Email { + containsEmail = true + } + } + if !containsEmail { + return false + } + } + return true + } + var results []result DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { + if len(results) != 2 || !validateEmails(results, []string{"join1@example.com", "join2@example.com"}) { t.Errorf("Should find all two emails with Join select") } }