Merge remote-tracking branch 'upstream/master' into f-table-comment-0411

This commit is contained in:
John Mai 2023-04-11 23:10:41 +08:00
commit b067c10de6
24 changed files with 664 additions and 92 deletions

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v7 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v7 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v7 uses: actions/stale@v8
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"

View File

@ -35,9 +35,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Contributors ## Contributors
Thank you for contributing to the GORM framework! [Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors)
## License ## License

View File

@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) *DB { func (p *processor) Execute(db *DB) *DB {
// call scopes // call scopes
for len(db.Statement.scopes) > 0 { for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes db = db.executeScopes()
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
} }
var ( var (

View File

@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ { for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct { if reflect.Indirect(obj).Kind() != reflect.Struct {
break break
} }
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj) objs = append(objs, obj)
if isPtr { elems = reflect.Append(elems, rv)
elems = reflect.Append(elems, rv)
} else { relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
elems = reflect.Append(elems, rv.Addr()) for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
} }
} }
} }
if elems.Len() > 0 { if elems.Len() > 0 {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i)) setupReferences(objs[i], elems.Index(i))
} }

View File

@ -3,6 +3,7 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -10,6 +11,98 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
if relationships == nil {
return nil
}
preloadMap := parsePreloadMap(s, preloads)
for name := range preloadMap {
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
return err
}
} else {
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
}
}
return nil
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = tx.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue

View File

@ -267,32 +267,7 @@ func Preload(db *gorm.DB) {
return return
} }
preloadMap := map[string]map[string][]interface{}{} preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
}
preloadNames := make([]string, 0, len(preloadMap)) preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap { for key := range preloadMap {
preloadNames = append(preloadNames, key) preloadNames = append(preloadNames, key)
@ -312,7 +287,9 @@ func Preload(db *gorm.DB) {
preloadDB.Statement.Unscoped = db.Statement.Unscoped preloadDB.Statement.Unscoped = db.Statement.Unscoped
for _, name := range preloadNames { for _, name := range preloadNames {
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else { } else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))

View File

@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
return tx return tx
} }
func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
if len(scopes) == 0 {
return tx
}
tx.Statement.scopes = nil
conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}
for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}
// Preload preload associations with given conditions // Preload preload associations with given conditions
// //
// // get all users, and preload all non-cancelled orders // // get all users, and preload all non-cancelled orders

View File

@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit limit.Limit = v.Limit
} }

View File

@ -28,6 +28,10 @@ func TestLimit(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 0", nil, "SELECT * FROM `users` LIMIT 0", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil, "SELECT * FROM `users` OFFSET 20", nil,

View File

@ -12,11 +12,7 @@ func (db *DB) Migrator() Migrator {
// apply scopes to migrator // apply scopes to migrator
for len(tx.Statement.scopes) > 0 { for len(tx.Statement.scopes) > 0 {
scopes := tx.Statement.scopes tx = tx.executeScopes()
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
} }
return tx.Dialector.Migrator(tx.Session(&Session{})) return tx.Dialector.Migrator(tx.Session(&Session{}))

View File

@ -17,12 +17,12 @@ func (idx Index) Table() string {
return idx.TableName return idx.TableName
} }
// Name return the name of the index. // Name return the name of the index.
func (idx Index) Name() string { func (idx Index) Name() string {
return idx.NameValue return idx.NameValue
} }
// Columns return the columns fo the index // Columns return the columns of the index
func (idx Index) Columns() []string { func (idx Index) Columns() []string {
return idx.ColumnList return idx.ColumnList
} }
@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid return idx.UniqueValue.Bool, idx.UniqueValue.Valid
} }
// Option return the optional attribute fo the index // Option return the optional attribute of the index
func (idx Index) Option() string { func (idx Index) Option() string {
return idx.OptionValue return idx.OptionValue
} }

View File

@ -113,7 +113,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return err return err
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, err := queryTx.Migrator().ColumnTypes(value) columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil { if err != nil {
return err return err
@ -123,7 +123,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
parseCheckConstraints = stmt.Schema.ParseCheckConstraints() parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
) )
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
@ -135,12 +134,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil { if foundColumn == nil {
// not found, add column // not found, add column
if err := execTx.Migrator().AddColumn(value, dbName); err != nil { if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err return err
} }
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
} }
} }
@ -195,7 +197,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -214,7 +216,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?," createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{} primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
} }
@ -225,8 +227,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable { if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) { defer func(value interface{}, name string) {
if errr == nil { if err == nil {
errr = tx.Migrator().CreateIndex(value, name) err = tx.Migrator().CreateIndex(value, name)
} }
}(value, idx.Name) }(value, idx.Name)
} else { } else {
@ -276,8 +278,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += fmt.Sprint(tableOption) createTableSQL += fmt.Sprint(tableOption)
} }
errr = tx.Exec(createTableSQL, values...).Error err = tx.Exec(createTableSQL, values...).Error
return errr return err
}); err != nil { }); err != nil {
return err return err
} }
@ -498,7 +500,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue() dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull { if dvNotNull && !currentDefaultNotNull {
// defalut value -> null // default value -> null
alterColumn = true alterColumn = true
} else if !dvNotNull && currentDefaultNotNull { } else if !dvNotNull && currentDefaultNotNull {
// null -> default value // null -> default value

View File

@ -89,6 +89,10 @@ type Field struct {
NewValuePool FieldNewValuePool NewValuePool FieldNewValuePool
} }
func (field *Field) BindName() string {
return strings.Join(field.BindNames, ".")
}
// ParseField parses reflect.StructField to Field // ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var ( var (
@ -580,8 +584,6 @@ func (field *Field) setupValuerAndSetter() {
case **bool: case **bool:
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data) field.ReflectValueOf(ctx, value).SetBool(**data)
} else {
field.ReflectValueOf(ctx, value).SetBool(false)
} }
case bool: case bool:
field.ReflectValueOf(ctx, value).SetBool(data) field.ReflectValueOf(ctx, value).SetBool(data)
@ -601,8 +603,6 @@ func (field *Field) setupValuerAndSetter() {
case **int64: case **int64:
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data) field.ReflectValueOf(ctx, value).SetInt(**data)
} else {
field.ReflectValueOf(ctx, value).SetInt(0)
} }
case int64: case int64:
field.ReflectValueOf(ctx, value).SetInt(data) field.ReflectValueOf(ctx, value).SetInt(data)
@ -667,8 +667,6 @@ func (field *Field) setupValuerAndSetter() {
case **uint64: case **uint64:
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data) field.ReflectValueOf(ctx, value).SetUint(**data)
} else {
field.ReflectValueOf(ctx, value).SetUint(0)
} }
case uint64: case uint64:
field.ReflectValueOf(ctx, value).SetUint(data) field.ReflectValueOf(ctx, value).SetUint(data)
@ -721,8 +719,6 @@ func (field *Field) setupValuerAndSetter() {
case **float64: case **float64:
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data) field.ReflectValueOf(ctx, value).SetFloat(**data)
} else {
field.ReflectValueOf(ctx, value).SetFloat(0)
} }
case float64: case float64:
field.ReflectValueOf(ctx, value).SetFloat(data) field.ReflectValueOf(ctx, value).SetFloat(data)
@ -767,8 +763,6 @@ func (field *Field) setupValuerAndSetter() {
case **string: case **string:
if data != nil && *data != nil { if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data) field.ReflectValueOf(ctx, value).SetString(**data)
} else {
field.ReflectValueOf(ctx, value).SetString("")
} }
case string: case string:
field.ReflectValueOf(ctx, value).SetString(data) field.ReflectValueOf(ctx, value).SetString(data)

View File

@ -27,6 +27,8 @@ type Relationships struct {
HasMany []*Relationship HasMany []*Relationship
Many2Many []*Relationship Many2Many []*Relationship
Relations map[string]*Relationship Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
} }
type Relationship struct { type Relationship struct {
@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
} }
if schema.err == nil { if schema.err == nil {
schema.Relationships.Relations[relation.Name] = relation schema.setRelation(relation)
switch relation.Type { switch relation.Type {
case HasOne: case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation return relation
} }
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
// //
// type User struct { // type User struct {
@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
} }
} }
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys // use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) { if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
primaryFields = primarySchema.PrimaryFields primaryFields = primarySchema.PrimaryFields
} }
primaryFieldLoop:
for _, primaryField := range primaryFields { for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs { if gl == guessBelongs {
@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
} }
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames { for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil { if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f) foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField) primaryFields = append(primaryFields, primaryField)
break continue primaryFieldLoop
} }
} }
} }

View File

@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) {
} }
} }
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
EmbeddedRelations: map[string]EmbeddedRelations{
"Address": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) { func TestVariableRelation(t *testing.T) {
var result struct { var result struct {
User User

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"reflect" "reflect"
"strings"
"sync" "sync"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -25,6 +26,7 @@ type Schema struct {
PrimaryFieldDBNames []string PrimaryFieldDBNames []string
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field {
return nil return nil
} }
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface { type Tabler interface {
TableName() string TableName() string
} }
@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
ModelType: modelType, ModelType: modelType,
Table: tableName, Table: tableName,
FieldsByName: map[string]*Field{}, FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{}, FieldsByBindName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}}, FieldsByDBName: map[string]*Field{},
cacheStore: cacheStore, Relationships: Relationships{Relations: map[string]*Relationship{}},
namer: namer, cacheStore: cacheStore,
initialized: make(chan struct{}), namer: namer,
initialized: make(chan struct{}),
} }
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.DBName = namer.ColumnName(schema.Table, field.Name) field.DBName = namer.ColumnName(schema.Table, field.Name)
} }
bindName := field.BindName()
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema.FieldsByDBName[field.DBName] = field schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey { if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields { for idx, f := range schema.PrimaryFields {
@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
} }
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter() field.setupValuerAndSetter()
} }
@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err return schema, schema.err
} else { } else {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
} }
} }

View File

@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
}) })
} }
type EmbeddedRelations struct {
Relations map[string]Relation
EmbeddedRelations map[string]EmbeddedRelations
}
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
for name, relations := range actual {
rs := expected[name]
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
if len(relations.Relations) != len(rs.Relations) {
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
}
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
}
for n, rel := range relations.Relations {
if r, ok := rs.Relations[n]; !ok {
t.Errorf("failed to find relation by name %s", n)
} else {
checkSchemaRelation(t, &schema.Schema{
Relationships: schema.Relationships{
Relations: map[string]*schema.Relationship{n: rel},
},
}, r)
}
}
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
})
}
}
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {

View File

@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB: case *DB:
for _, scope := range v.Statement.scopes { v.executeScopes()
v = scope(v)
}
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok { if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 { if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
} }
} }
conds = append(conds, clause.And(where.Exprs...)) conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil { } else {
conds = append(conds, cs.Expression) conds = append(conds, cs.Expression)
} }
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
} }
case map[interface{}]interface{}: case map[interface{}]interface{}:
for i, j := range v { for i, j := range v {

View File

@ -393,3 +393,33 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
AssertEqual(t, err, nil) AssertEqual(t, err, nil)
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
} }
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}

View File

@ -103,9 +103,16 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
URL string URL string
} }
type Author struct {
ID string
Name string
Email string
}
type HNPost struct { type HNPost struct {
*BasePost *BasePost
Upvotes int32 Upvotes int32
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
} }
DB.Migrator().DropTable(&HNPost{}) DB.Migrator().DropTable(&HNPost{})
@ -123,6 +130,10 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
if hnPost.Title != "embedded_pointer_type" { if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type") t.Errorf("Should find correct value for embedded pointer type")
} }
if hnPost.Author != nil {
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
}
} }
type Content struct { type Content struct {

View File

@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
CheckUserUnscoped(t, *user6, user) CheckUserUnscoped(t, *user6, user)
} }
func TestEmbedPreload(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type EmbeddedAddress struct {
ID int
Name string
CountryID *int
Country *Country
}
type NestedAddress struct {
EmbeddedAddress
}
type Org struct {
ID int
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID *int
Address *EmbeddedAddress
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
}
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
org := Org{
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
NestedAddress: NestedAddress{
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
},
}
if err := DB.Create(&org).Error; err != nil {
t.Errorf("failed to create org, got err: %v", err)
}
tests := []struct {
name string
preloads map[string][]interface{}
expect Org
}{
{
name: "address country",
preloads: map[string][]interface{}{"Address.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: EmbeddedAddress{
ID: org.PostalAddress.ID,
Name: org.PostalAddress.Name,
CountryID: org.PostalAddress.CountryID,
Country: nil,
},
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: org.Address,
NestedAddress: NestedAddress{EmbeddedAddress{
ID: org.NestedAddress.ID,
Name: org.NestedAddress.Name,
CountryID: org.NestedAddress.CountryID,
Country: nil,
}},
},
}, {
name: "postal address country",
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: org.PostalAddress,
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: nil,
NestedAddress: NestedAddress{EmbeddedAddress{
ID: org.NestedAddress.ID,
Name: org.NestedAddress.Name,
CountryID: org.NestedAddress.CountryID,
Country: nil,
}},
},
}, {
name: "nested address country",
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: EmbeddedAddress{
ID: org.PostalAddress.ID,
Name: org.PostalAddress.Name,
CountryID: org.PostalAddress.CountryID,
Country: nil,
},
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: nil,
NestedAddress: org.NestedAddress,
},
}, {
name: "associations",
preloads: map[string][]interface{}{
clause.Associations: {},
// clause.Associations wont preload nested associations
"Address.Country": {},
},
expect: org,
},
}
DB = DB.Debug()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := Org{}
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
for name, args := range test.preloads {
tx = tx.Preload(name, args...)
}
if err := tx.Find(&actual).Error; err != nil {
t.Errorf("failed to find org, got err: %v", err)
}
AssertEqual(t, actual, test.expect)
})
}
}

View File

@ -72,3 +72,54 @@ func TestScopes(t *testing.T) {
t.Errorf("select max(id)") t.Errorf("select max(id)")
} }
} }
func TestComplexScopes(t *testing.T) {
tests := []struct {
name string
queryFn func(tx *gorm.DB) *gorm.DB
expected string
}{
{
name: "depth_1",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
}, {
name: "depth_1_pre_cond",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Where("z = 0").Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
}, {
name: "depth_2",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
func(d *gorm.DB) *gorm.DB {
return d.
Or(d.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
)).
Or("c = 3")
},
func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
})
}
}