feat(Scopes): make Schema and Table accessible from scopes

This commit is contained in:
SystemGlitch 2022-10-25 16:18:57 +02:00
parent f82e9cfdbe
commit 3df271759b
No known key found for this signature in database
GPG Key ID: E09ED82897AEE53B
2 changed files with 69 additions and 15 deletions

View File

@ -73,15 +73,6 @@ func (cs *callbacks) Raw() *processor {
}
func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
var (
curTime = time.Now()
stmt = db.Statement
@ -100,13 +91,12 @@ func (p *processor) Execute(db *DB) *DB {
stmt.Dest = stmt.Model
}
var modelParseError error
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
if modelParseError = stmt.Parse(stmt.Model); modelParseError != nil && (!errors.Is(modelParseError, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if !errors.Is(modelParseError, schema.ErrUnsupportedDataType) || stmt.Table != "" || stmt.TableExpr != nil {
_ = db.AddError(modelParseError)
}
}
}
@ -122,10 +112,23 @@ func (p *processor) Execute(db *DB) *DB {
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(ErrInvalidValue)
_ = db.AddError(ErrInvalidValue)
}
}
// call scopes
for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
}
if stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0 {
_ = db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", modelParseError))
}
for _, f := range p.fns {
f(db)
}

View File

@ -0,0 +1,51 @@
package tests_test
import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests"
)
func TestSchemaAccessibleFromScopes(t *testing.T) {
users := []User{
*GetUser("schema-scope-1", Config{}),
*GetUser("schema-scope-2", Config{}),
}
if err := DB.Create(&users).Error; err != nil {
t.Fatalf("errors happened when create users: %v", err)
}
var schema *schema.Schema
var tableName string
scope := func(db *gorm.DB) *gorm.DB {
schema = db.Statement.Schema
tableName = db.Statement.Table
return db
}
var results []User
if err := DB.Scopes(scope).Select("name", "age").Where("name like ?", "schema-scope-%").Find(&results).Error; err != nil {
t.Errorf("failed to query users, got error: %v", err)
}
expects := []User{
{Name: "schema-scope-1", Age: 18},
{Name: "schema-scope-2", Age: 18},
}
if len(results) != 2 {
t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results))
}
expectedTableName := "users"
if tableName != expectedTableName {
t.Errorf("invalid table name found, expects: %v, got %v", expectedTableName, tableName)
}
if schema == nil {
t.Errorf("invalid schema found, expected non-nil schema")
}
}