Merge remote-tracking branch 'upstream/master' into f-table-comment-0411

This commit is contained in:
John Mai 2023-04-11 23:10:41 +08:00
commit b067c10de6
24 changed files with 664 additions and 92 deletions

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v7
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"

View File

@ -35,9 +35,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Contributors
Thank you for contributing to the GORM framework!
[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors)
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License

View File

@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
db = db.executeScopes()
}
var (

View File

@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
break
}
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
elems = reflect.Append(elems, rv)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
}
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}

View File

@ -3,6 +3,7 @@ package callbacks
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -10,6 +11,98 @@ import (
"gorm.io/gorm/utils"
)
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
if relationships == nil {
return nil
}
preloadMap := parsePreloadMap(s, preloads)
for name := range preloadMap {
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
return err
}
} else {
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
}
}
return nil
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = tx.Statement.ReflectValue

View File

@ -267,32 +267,7 @@ func Preload(db *gorm.DB) {
return
}
preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
}
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
@ -312,7 +287,9 @@ func Preload(db *gorm.DB) {
preloadDB.Statement.Unscoped = db.Statement.Unscoped
for _, name := range preloadNames {
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))

View File

@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
return tx
}
func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
if len(scopes) == 0 {
return tx
}
tx.Statement.scopes = nil
conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}
for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}
// Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders

View File

@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit
}

View File

@ -28,6 +28,10 @@ func TestLimit(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil,

View File

@ -12,11 +12,7 @@ func (db *DB) Migrator() Migrator {
// apply scopes to migrator
for len(tx.Statement.scopes) > 0 {
scopes := tx.Statement.scopes
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
tx = tx.executeScopes()
}
return tx.Dialector.Migrator(tx.Session(&Session{}))

View File

@ -17,12 +17,12 @@ func (idx Index) Table() string {
return idx.TableName
}
// Name return the name of the index.
// Name return the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns return the columns fo the index
// Columns return the columns of the index
func (idx Index) Columns() []string {
return idx.ColumnList
}
@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option return the optional attribute fo the index
// Option return the optional attribute of the index
func (idx Index) Option() string {
return idx.OptionValue
}

View File

@ -113,7 +113,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
@ -123,7 +123,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
var foundColumn gorm.ColumnType
for _, columnType := range columnTypes {
@ -135,12 +134,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil {
// not found, add column
if err := execTx.Migrator().AddColumn(value, dbName); err != nil {
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err
}
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
}
}
@ -195,7 +197,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
@ -214,7 +216,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
@ -225,8 +227,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) {
if errr == nil {
errr = tx.Migrator().CreateIndex(value, name)
if err == nil {
err = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
} else {
@ -276,8 +278,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += fmt.Sprint(tableOption)
}
errr = tx.Exec(createTableSQL, values...).Error
return errr
err = tx.Exec(createTableSQL, values...).Error
return err
}); err != nil {
return err
}
@ -498,7 +500,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// defalut value -> null
// default value -> null
alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value

View File

@ -89,6 +89,10 @@ type Field struct {
NewValuePool FieldNewValuePool
}
func (field *Field) BindName() string {
return strings.Join(field.BindNames, ".")
}
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var (
@ -580,8 +584,6 @@ func (field *Field) setupValuerAndSetter() {
case **bool:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data)
} else {
field.ReflectValueOf(ctx, value).SetBool(false)
}
case bool:
field.ReflectValueOf(ctx, value).SetBool(data)
@ -601,8 +603,6 @@ func (field *Field) setupValuerAndSetter() {
case **int64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data)
} else {
field.ReflectValueOf(ctx, value).SetInt(0)
}
case int64:
field.ReflectValueOf(ctx, value).SetInt(data)
@ -667,8 +667,6 @@ func (field *Field) setupValuerAndSetter() {
case **uint64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data)
} else {
field.ReflectValueOf(ctx, value).SetUint(0)
}
case uint64:
field.ReflectValueOf(ctx, value).SetUint(data)
@ -721,8 +719,6 @@ func (field *Field) setupValuerAndSetter() {
case **float64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data)
} else {
field.ReflectValueOf(ctx, value).SetFloat(0)
}
case float64:
field.ReflectValueOf(ctx, value).SetFloat(data)
@ -767,8 +763,6 @@ func (field *Field) setupValuerAndSetter() {
case **string:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data)
} else {
field.ReflectValueOf(ctx, value).SetString("")
}
case string:
field.ReflectValueOf(ctx, value).SetString(data)

View File

@ -27,6 +27,8 @@ type Relationships struct {
HasMany []*Relationship
Many2Many []*Relationship
Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
}
type Relationship struct {
@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
}
if schema.err == nil {
schema.Relationships.Relations[relation.Name] = relation
schema.setRelation(relation)
switch relation.Type {
case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct {
@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
primaryFields = primarySchema.PrimaryFields
}
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
break
continue primaryFieldLoop
}
}
}

View File

@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) {
}
}
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
EmbeddedRelations: map[string]EmbeddedRelations{
"Address": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) {
var result struct {
User

View File

@ -6,6 +6,7 @@ import (
"fmt"
"go/ast"
"reflect"
"strings"
"sync"
"gorm.io/gorm/clause"
@ -25,6 +26,7 @@ type Schema struct {
PrimaryFieldDBNames []string
Fields []*Field
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field {
return nil
}
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface {
TableName() string
}
@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter()
}
@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
}
}

View File

@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
})
}
type EmbeddedRelations struct {
Relations map[string]Relation
EmbeddedRelations map[string]EmbeddedRelations
}
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
for name, relations := range actual {
rs := expected[name]
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
if len(relations.Relations) != len(rs.Relations) {
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
}
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
}
for n, rel := range relations.Relations {
if r, ok := rs.Relations[n]; !ok {
t.Errorf("failed to find relation by name %s", n)
} else {
checkSchemaRelation(t, &schema.Schema{
Relationships: schema.Relationships{
Relations: map[string]*schema.Relationship{n: rel},
},
}, r)
}
}
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
})
}
}
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {

View File

@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression:
conds = append(conds, v)
case *DB:
for _, scope := range v.Statement.scopes {
v = scope(v)
}
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
} else {
conds = append(conds, cs.Expression)
}
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
}
case map[interface{}]interface{}:
for i, j := range v {

View File

@ -393,3 +393,33 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
AssertEqual(t, err, nil)
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
}
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}

View File

@ -103,9 +103,16 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
URL string
}
type Author struct {
ID string
Name string
Email string
}
type HNPost struct {
*BasePost
Upvotes int32
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
}
DB.Migrator().DropTable(&HNPost{})
@ -123,6 +130,10 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
if hnPost.Title != "embedded_pointer_type" {
t.Errorf("Should find correct value for embedded pointer type")
}
if hnPost.Author != nil {
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
}
}
type Content struct {

View File

@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
CheckUserUnscoped(t, *user6, user)
}
func TestEmbedPreload(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type EmbeddedAddress struct {
ID int
Name string
CountryID *int
Country *Country
}
type NestedAddress struct {
EmbeddedAddress
}
type Org struct {
ID int
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID *int
Address *EmbeddedAddress
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
}
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
org := Org{
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
NestedAddress: NestedAddress{
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
},
}
if err := DB.Create(&org).Error; err != nil {
t.Errorf("failed to create org, got err: %v", err)
}
tests := []struct {
name string
preloads map[string][]interface{}
expect Org
}{
{
name: "address country",
preloads: map[string][]interface{}{"Address.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: EmbeddedAddress{
ID: org.PostalAddress.ID,
Name: org.PostalAddress.Name,
CountryID: org.PostalAddress.CountryID,
Country: nil,
},
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: org.Address,
NestedAddress: NestedAddress{EmbeddedAddress{
ID: org.NestedAddress.ID,
Name: org.NestedAddress.Name,
CountryID: org.NestedAddress.CountryID,
Country: nil,
}},
},
}, {
name: "postal address country",
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: org.PostalAddress,
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: nil,
NestedAddress: NestedAddress{EmbeddedAddress{
ID: org.NestedAddress.ID,
Name: org.NestedAddress.Name,
CountryID: org.NestedAddress.CountryID,
Country: nil,
}},
},
}, {
name: "nested address country",
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: EmbeddedAddress{
ID: org.PostalAddress.ID,
Name: org.PostalAddress.Name,
CountryID: org.PostalAddress.CountryID,
Country: nil,
},
VisitingAddress: EmbeddedAddress{
ID: org.VisitingAddress.ID,
Name: org.VisitingAddress.Name,
CountryID: org.VisitingAddress.CountryID,
Country: nil,
},
AddressID: org.AddressID,
Address: nil,
NestedAddress: org.NestedAddress,
},
}, {
name: "associations",
preloads: map[string][]interface{}{
clause.Associations: {},
// clause.Associations wont preload nested associations
"Address.Country": {},
},
expect: org,
},
}
DB = DB.Debug()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := Org{}
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
for name, args := range test.preloads {
tx = tx.Preload(name, args...)
}
if err := tx.Find(&actual).Error; err != nil {
t.Errorf("failed to find org, got err: %v", err)
}
AssertEqual(t, actual, test.expect)
})
}
}

View File

@ -72,3 +72,54 @@ func TestScopes(t *testing.T) {
t.Errorf("select max(id)")
}
}
func TestComplexScopes(t *testing.T) {
tests := []struct {
name string
queryFn func(tx *gorm.DB) *gorm.DB
expected string
}{
{
name: "depth_1",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
}, {
name: "depth_1_pre_cond",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Where("z = 0").Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
}, {
name: "depth_2",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
func(d *gorm.DB) *gorm.DB {
return d.
Or(d.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
)).
Or("c = 3")
},
func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
})
}
}