From 8ce25a5dd11272527a455a69e8a0c89408f3a5c3 Mon Sep 17 00:00:00 2001 From: mr-chenguang lcgash Date: Mon, 15 Apr 2024 03:08:37 +0000 Subject: [PATCH] fix: keep nil when dest is ptr & dest is nil[notfound] --- callbacks.go | 3 ++- scan.go | 2 +- statement.go | 1 + tests/query_test.go | 8 +------- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/callbacks.go b/callbacks.go index 50b5b0e9..e97254cf 100644 --- a/callbacks.go +++ b/callbacks.go @@ -115,7 +115,8 @@ func (p *processor) Execute(db *DB) *DB { if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { - if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { + stmt.DestIsNil = stmt.ReflectValue.IsNil() + if stmt.DestIsNil && stmt.ReflectValue.CanAddr() { stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) } diff --git a/scan.go b/scan.go index 3a2a212c..236f405a 100644 --- a/scan.go +++ b/scan.go @@ -342,7 +342,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) - if db.NotFoundAsNilWhenPtr && db.Statement.Dest != nil && reflect.ValueOf(db.Statement.Dest).Kind() == reflect.Ptr { + if db.Statement.DestIsNil { // reset dest to nil reflect.ValueOf(db.Statement.Dest).Elem().Set(reflect.Zero(reflect.ValueOf(db.Statement.Dest).Elem().Type())) } diff --git a/statement.go b/statement.go index ae79aa32..b2c187b9 100644 --- a/statement.go +++ b/statement.go @@ -26,6 +26,7 @@ type Statement struct { Model interface{} Unscoped bool Dest interface{} + DestIsNil bool ReflectValue reflect.Value Clauses map[string]clause.Clause BuildClauses []string diff --git a/tests/query_test.go b/tests/query_test.go index bb673f94..f3b00f8b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -208,13 +208,7 @@ func TestFind(t *testing.T) { t.Run("NotFoundAsNil", func(t *testing.T) { var first *User - if err := DB.Where("name = ?", "find-not-found").First(&first).Error; err != nil { - AssertEqual(t, err, gorm.ErrRecordNotFound) - AssertEqual(t, first == nil, false) - } - - DB.Config.NotFoundAsNilWhenPtr = true - if err := DB.Where("name = ?", "find-not-found").First(&first).Error; err != nil { + if err := DB.Where("name = ?", "find not found").First(&first).Error; err != nil { AssertEqual(t, err, gorm.ErrRecordNotFound) AssertEqual(t, first, nil) }