diff --git a/gorm.go b/gorm.go index 775cd3de..f07d9daa 100644 --- a/gorm.go +++ b/gorm.go @@ -50,6 +50,8 @@ type Config struct { CreateBatchSize int // TranslateError enabling error translation TranslateError bool + // NotFoundAsError set result is nil when no record found and result is ptr + NotFoundAsNilWhenPtr bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder diff --git a/scan.go b/scan.go index 415b9f0d..3a2a212c 100644 --- a/scan.go +++ b/scan.go @@ -342,5 +342,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) + if db.NotFoundAsNilWhenPtr && db.Statement.Dest != nil && reflect.ValueOf(db.Statement.Dest).Kind() == reflect.Ptr { + // reset dest to nil + reflect.ValueOf(db.Statement.Dest).Elem().Set(reflect.Zero(reflect.ValueOf(db.Statement.Dest).Elem().Type())) + } } } diff --git a/tests/query_test.go b/tests/query_test.go index c0259a14..bb673f94 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -206,6 +206,20 @@ func TestFind(t *testing.T) { } }) + t.Run("NotFoundAsNil", func(t *testing.T) { + var first *User + if err := DB.Where("name = ?", "find-not-found").First(&first).Error; err != nil { + AssertEqual(t, err, gorm.ErrRecordNotFound) + AssertEqual(t, first == nil, false) + } + + DB.Config.NotFoundAsNilWhenPtr = true + if err := DB.Where("name = ?", "find-not-found").First(&first).Error; err != nil { + AssertEqual(t, err, gorm.ErrRecordNotFound) + AssertEqual(t, first, nil) + } + }) + var models []User if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))