From c4ceaefa46416d85ce0b32167635228d40d46d23 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Tue, 25 Oct 2022 16:18:57 +0200 Subject: [PATCH] feat(Scopes): make Schema and Table accessible from scopes --- callbacks.go | 33 ++++++++++++---------- tests/schema_in_scopes_test.go | 51 ++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 15 deletions(-) create mode 100644 tests/schema_in_scopes_test.go diff --git a/callbacks.go b/callbacks.go index c060ea70..875150b8 100644 --- a/callbacks.go +++ b/callbacks.go @@ -73,15 +73,6 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) *DB { - // call scopes - for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } - } - var ( curTime = time.Now() stmt = db.Statement @@ -100,13 +91,12 @@ func (p *processor) Execute(db *DB) *DB { stmt.Dest = stmt.Model } + var modelParseError error // parse model values if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { - if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { - db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) - } else { - db.AddError(err) + if modelParseError = stmt.Parse(stmt.Model); modelParseError != nil && (!errors.Is(modelParseError, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if !errors.Is(modelParseError, schema.ErrUnsupportedDataType) || stmt.Table != "" || stmt.TableExpr != nil { + _ = db.AddError(modelParseError) } } } @@ -122,10 +112,23 @@ func (p *processor) Execute(db *DB) *DB { stmt.ReflectValue = stmt.ReflectValue.Elem() } if !stmt.ReflectValue.IsValid() { - db.AddError(ErrInvalidValue) + _ = db.AddError(ErrInvalidValue) } } + // call scopes + for len(db.Statement.scopes) > 0 { + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + } + + if stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0 { + _ = db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", modelParseError)) + } + for _, f := range p.fns { f(db) } diff --git a/tests/schema_in_scopes_test.go b/tests/schema_in_scopes_test.go new file mode 100644 index 00000000..54048e78 --- /dev/null +++ b/tests/schema_in_scopes_test.go @@ -0,0 +1,51 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +func TestSchemaAccessibleFromScopes(t *testing.T) { + users := []User{ + *GetUser("schema-scope-1", Config{}), + *GetUser("schema-scope-2", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + var schema *schema.Schema + var tableName string + scope := func(db *gorm.DB) *gorm.DB { + schema = db.Statement.Schema + tableName = db.Statement.Table + return db + } + + var results []User + if err := DB.Scopes(scope).Select("name", "age").Where("name like ?", "schema-scope-%").Find(&results).Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "schema-scope-1", Age: 18}, + {Name: "schema-scope-2", Age: 18}, + } + + if len(results) != 2 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + expectedTableName := "users" + if tableName != expectedTableName { + t.Errorf("invalid table name found, expects: %v, got %v", expectedTableName, tableName) + } + + if schema == nil { + t.Errorf("invalid schema found, expected non-nil schema") + } +}