From 2dd6140e7ab86304f460a41e8868bdf7972e1b38 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Fri, 4 Nov 2022 14:31:17 +0100 Subject: [PATCH] feat(Scopes): restored the ability to set the model from Scopes --- callbacks.go | 45 ++++++++++++++++++++++------------ tests/schema_in_scopes_test.go | 40 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/callbacks.go b/callbacks.go index 875150b8..8720b0fd 100644 --- a/callbacks.go +++ b/callbacks.go @@ -84,22 +84,7 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - // assign model values - if stmt.Model == nil { - stmt.Model = stmt.Dest - } else if stmt.Dest == nil { - stmt.Dest = stmt.Model - } - - var modelParseError error - // parse model values - if stmt.Model != nil { - if modelParseError = stmt.Parse(stmt.Model); modelParseError != nil && (!errors.Is(modelParseError, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { - if !errors.Is(modelParseError, schema.ErrUnsupportedDataType) || stmt.Table != "" || stmt.TableExpr != nil { - _ = db.AddError(modelParseError) - } - } - } + modelParseError := p.parseModelValue(db, stmt) // assign stmt.ReflectValue if stmt.Dest != nil { @@ -125,6 +110,11 @@ func (p *processor) Execute(db *DB) *DB { } } + // parse model again if the model has been set in scopes + if stmt.Schema == nil && (modelParseError == nil || errors.Is(modelParseError, schema.ErrUnsupportedDataType)) { + modelParseError = p.parseModelValue(db, stmt) + } + if stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0 { _ = db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", modelParseError)) } @@ -151,6 +141,29 @@ func (p *processor) Execute(db *DB) *DB { return db } +func (p *processor) parseModelValue(db *DB, stmt *Statement) error { + p.assignModelValues(stmt) + + var err error + if stmt.Model != nil { + err = stmt.Parse(stmt.Model) + if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if !errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table != "" || stmt.TableExpr != nil { + _ = db.AddError(err) + } + } + } + return err +} + +func (p *processor) assignModelValues(stmt *Statement) { + if stmt.Model == nil { + stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model + } +} + func (p *processor) Get(name string) func(*DB) { for i := len(p.callbacks) - 1; i >= 0; i-- { if v := p.callbacks[i]; v.name == name && !v.remove { diff --git a/tests/schema_in_scopes_test.go b/tests/schema_in_scopes_test.go index 54048e78..48de8481 100644 --- a/tests/schema_in_scopes_test.go +++ b/tests/schema_in_scopes_test.go @@ -49,3 +49,43 @@ func TestSchemaAccessibleFromScopes(t *testing.T) { t.Errorf("invalid schema found, expected non-nil schema") } } + +func TestSetModelInScope(t *testing.T) { + users := []User{ + *GetUser("model-scope-1", Config{}), + *GetUser("model-scope-2", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + scope := func(db *gorm.DB) *gorm.DB { + return db.Model(&User{}) + } + + var results []map[string]interface{} + tx := DB.Scopes(scope) + tx = tx.Select("name", "age").Where("name like ?", "model-scope-%").Find(&results) + if err := tx.Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "model-scope-1", Age: 18}, + {Name: "model-scope-2", Age: 18}, + } + + if len(results) != 2 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + expectedTableName := "users" + if tx.Statement.Table != expectedTableName { + t.Errorf("invalid table name found, expects: %v, got %v", expectedTableName, tx.Statement.Table) + } + + if tx.Statement.Schema == nil { + t.Errorf("invalid schema found, expected non-nil schema") + } +}