Merge remote-tracking branch 'upstream/master' into f-table-comment-0411
This commit is contained in:
commit
b067c10de6
2
.github/workflows/invalid_question.yml
vendored
2
.github/workflows/invalid_question.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v7
|
uses: actions/stale@v8
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
|
2
.github/workflows/missing_playground.yml
vendored
2
.github/workflows/missing_playground.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v7
|
uses: actions/stale@v8
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
|
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v7
|
uses: actions/stale@v8
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
||||||
|
@ -35,9 +35,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
|
|
||||||
## Contributors
|
## Contributors
|
||||||
|
|
||||||
Thank you for contributing to the GORM framework!
|
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
|
||||||
|
|
||||||
[](https://github.com/go-gorm/gorm/graphs/contributors)
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
|
|||||||
func (p *processor) Execute(db *DB) *DB {
|
func (p *processor) Execute(db *DB) *DB {
|
||||||
// call scopes
|
// call scopes
|
||||||
for len(db.Statement.scopes) > 0 {
|
for len(db.Statement.scopes) > 0 {
|
||||||
scopes := db.Statement.scopes
|
db = db.executeScopes()
|
||||||
db.Statement.scopes = nil
|
|
||||||
for _, scope := range scopes {
|
|
||||||
db = scope(db)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||||
|
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||||
|
identityMap := map[string]bool{}
|
||||||
for i := 0; i < rValLen; i++ {
|
for i := 0; i < rValLen; i++ {
|
||||||
obj := db.Statement.ReflectValue.Index(i)
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||||
|
if !isPtr {
|
||||||
|
rv = rv.Addr()
|
||||||
|
}
|
||||||
objs = append(objs, obj)
|
objs = append(objs, obj)
|
||||||
if isPtr {
|
elems = reflect.Append(elems, rv)
|
||||||
elems = reflect.Append(elems, rv)
|
|
||||||
} else {
|
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||||
elems = reflect.Append(elems, rv.Addr())
|
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||||
|
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
|
||||||
|
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||||
|
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||||
|
if cacheKey != "" { // has primary fields
|
||||||
|
identityMap[cacheKey] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
distinctElems = reflect.Append(distinctElems, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package callbacks
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -10,6 +11,98 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// parsePreloadMap extracts nested preloads. e.g.
|
||||||
|
//
|
||||||
|
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||||
|
// parsePreloadMap(schema, map[string][]interface{}{
|
||||||
|
// clause.Associations: {"arg1"},
|
||||||
|
// "k1": {"arg2"},
|
||||||
|
// "k2.k3": {"arg3"},
|
||||||
|
// "k4.k5.k6": {"arg4"},
|
||||||
|
// })
|
||||||
|
// // preloadMap is
|
||||||
|
// map[string]map[string][]interface{}{
|
||||||
|
// "k0": {},
|
||||||
|
// "k7": {
|
||||||
|
// "k8": {},
|
||||||
|
// },
|
||||||
|
// "k1": {},
|
||||||
|
// "k2": {
|
||||||
|
// "k3": {"arg3"},
|
||||||
|
// },
|
||||||
|
// "k4": {
|
||||||
|
// "k5.k6": {"arg4"},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||||
|
preloadMap := map[string]map[string][]interface{}{}
|
||||||
|
setPreloadMap := func(name, value string, args []interface{}) {
|
||||||
|
if _, ok := preloadMap[name]; !ok {
|
||||||
|
preloadMap[name] = map[string][]interface{}{}
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
preloadMap[name][value] = args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, args := range preloads {
|
||||||
|
preloadFields := strings.Split(name, ".")
|
||||||
|
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||||
|
if preloadFields[0] == clause.Associations {
|
||||||
|
for _, relation := range s.Relationships.Relations {
|
||||||
|
if relation.Schema == s {
|
||||||
|
setPreloadMap(relation.Name, value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||||
|
for _, value := range embeddedValues(embeddedRelations) {
|
||||||
|
setPreloadMap(embedded, value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setPreloadMap(preloadFields[0], value, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return preloadMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||||
|
if embeddedRelations == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||||
|
for _, relation := range embeddedRelations.Relations {
|
||||||
|
// skip first struct name
|
||||||
|
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||||
|
}
|
||||||
|
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||||
|
names = append(names, embeddedValues(relations)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
||||||
|
if relationships == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
preloadMap := parsePreloadMap(s, preloads)
|
||||||
|
for name := range preloadMap {
|
||||||
|
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
||||||
|
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if rel := relationships.Relations[name]; rel != nil {
|
||||||
|
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||||
var (
|
var (
|
||||||
reflectValue = tx.Statement.ReflectValue
|
reflectValue = tx.Statement.ReflectValue
|
||||||
|
@ -267,32 +267,7 @@ func Preload(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadMap := map[string]map[string][]interface{}{}
|
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
||||||
for name := range db.Statement.Preloads {
|
|
||||||
preloadFields := strings.Split(name, ".")
|
|
||||||
if preloadFields[0] == clause.Associations {
|
|
||||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
|
||||||
if rel.Schema == db.Statement.Schema {
|
|
||||||
if _, ok := preloadMap[rel.Name]; !ok {
|
|
||||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
|
||||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
|
||||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
|
||||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadNames := make([]string, 0, len(preloadMap))
|
preloadNames := make([]string, 0, len(preloadMap))
|
||||||
for key := range preloadMap {
|
for key := range preloadMap {
|
||||||
preloadNames = append(preloadNames, key)
|
preloadNames = append(preloadNames, key)
|
||||||
@ -312,7 +287,9 @@ func Preload(db *gorm.DB) {
|
|||||||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
for _, name := range preloadNames {
|
||||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
||||||
|
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
||||||
|
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||||
} else {
|
} else {
|
||||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||||
|
@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
|||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) executeScopes() (tx *DB) {
|
||||||
|
tx = db.getInstance()
|
||||||
|
scopes := db.Statement.scopes
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
tx.Statement.scopes = nil
|
||||||
|
|
||||||
|
conditions := make([]clause.Interface, 0, 4)
|
||||||
|
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||||
|
conditions = append(conditions, cs.Expression.(clause.Interface))
|
||||||
|
cs.Expression = nil
|
||||||
|
tx.Statement.Clauses["WHERE"] = cs
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, scope := range scopes {
|
||||||
|
tx = scope(tx)
|
||||||
|
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||||
|
conditions = append(conditions, cs.Expression.(clause.Interface))
|
||||||
|
cs.Expression = nil
|
||||||
|
tx.Statement.Clauses["WHERE"] = cs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, condition := range conditions {
|
||||||
|
tx.Statement.AddClause(condition)
|
||||||
|
}
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
// Preload preload associations with given conditions
|
// Preload preload associations with given conditions
|
||||||
//
|
//
|
||||||
// // get all users, and preload all non-cancelled orders
|
// // get all users, and preload all non-cancelled orders
|
||||||
|
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
|
|||||||
clause.Name = ""
|
clause.Name = ""
|
||||||
|
|
||||||
if v, ok := clause.Expression.(Limit); ok {
|
if v, ok := clause.Expression.(Limit); ok {
|
||||||
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
|
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
|
||||||
limit.Limit = v.Limit
|
limit.Limit = v.Limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +28,10 @@ func TestLimit(t *testing.T) {
|
|||||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
||||||
"SELECT * FROM `users` LIMIT 0", nil,
|
"SELECT * FROM `users` LIMIT 0", nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
|
||||||
|
"SELECT * FROM `users` LIMIT 0", nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||||
"SELECT * FROM `users` OFFSET 20", nil,
|
"SELECT * FROM `users` OFFSET 20", nil,
|
||||||
|
@ -12,11 +12,7 @@ func (db *DB) Migrator() Migrator {
|
|||||||
|
|
||||||
// apply scopes to migrator
|
// apply scopes to migrator
|
||||||
for len(tx.Statement.scopes) > 0 {
|
for len(tx.Statement.scopes) > 0 {
|
||||||
scopes := tx.Statement.scopes
|
tx = tx.executeScopes()
|
||||||
tx.Statement.scopes = nil
|
|
||||||
for _, scope := range scopes {
|
|
||||||
tx = scope(tx)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Dialector.Migrator(tx.Session(&Session{}))
|
return tx.Dialector.Migrator(tx.Session(&Session{}))
|
||||||
|
@ -17,12 +17,12 @@ func (idx Index) Table() string {
|
|||||||
return idx.TableName
|
return idx.TableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name return the name of the index.
|
// Name return the name of the index.
|
||||||
func (idx Index) Name() string {
|
func (idx Index) Name() string {
|
||||||
return idx.NameValue
|
return idx.NameValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Columns return the columns fo the index
|
// Columns return the columns of the index
|
||||||
func (idx Index) Columns() []string {
|
func (idx Index) Columns() []string {
|
||||||
return idx.ColumnList
|
return idx.ColumnList
|
||||||
}
|
}
|
||||||
@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) {
|
|||||||
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option return the optional attribute fo the index
|
// Option return the optional attribute of the index
|
||||||
func (idx Index) Option() string {
|
func (idx Index) Option() string {
|
||||||
return idx.OptionValue
|
return idx.OptionValue
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -123,7 +123,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||||
)
|
)
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
field := stmt.Schema.FieldsByDBName[dbName]
|
|
||||||
var foundColumn gorm.ColumnType
|
var foundColumn gorm.ColumnType
|
||||||
|
|
||||||
for _, columnType := range columnTypes {
|
for _, columnType := range columnTypes {
|
||||||
@ -135,12 +134,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
|
|
||||||
if foundColumn == nil {
|
if foundColumn == nil {
|
||||||
// not found, add column
|
// not found, add column
|
||||||
if err := execTx.Migrator().AddColumn(value, dbName); err != nil {
|
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// found, smartly migrate
|
||||||
|
field := stmt.Schema.FieldsByDBName[dbName]
|
||||||
|
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
|
||||||
// found, smart migrate
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,7 +197,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
|
|||||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, false) {
|
for _, value := range m.ReorderModels(values, false) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||||
var (
|
var (
|
||||||
createTableSQL = "CREATE TABLE ? ("
|
createTableSQL = "CREATE TABLE ? ("
|
||||||
values = []interface{}{m.CurrentTable(stmt)}
|
values = []interface{}{m.CurrentTable(stmt)}
|
||||||
@ -214,7 +216,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
|
|
||||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||||
createTableSQL += "PRIMARY KEY ?,"
|
createTableSQL += "PRIMARY KEY ?,"
|
||||||
primaryKeys := []interface{}{}
|
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||||
}
|
}
|
||||||
@ -225,8 +227,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||||
if m.CreateIndexAfterCreateTable {
|
if m.CreateIndexAfterCreateTable {
|
||||||
defer func(value interface{}, name string) {
|
defer func(value interface{}, name string) {
|
||||||
if errr == nil {
|
if err == nil {
|
||||||
errr = tx.Migrator().CreateIndex(value, name)
|
err = tx.Migrator().CreateIndex(value, name)
|
||||||
}
|
}
|
||||||
}(value, idx.Name)
|
}(value, idx.Name)
|
||||||
} else {
|
} else {
|
||||||
@ -276,8 +278,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
createTableSQL += fmt.Sprint(tableOption)
|
createTableSQL += fmt.Sprint(tableOption)
|
||||||
}
|
}
|
||||||
|
|
||||||
errr = tx.Exec(createTableSQL, values...).Error
|
err = tx.Exec(createTableSQL, values...).Error
|
||||||
return errr
|
return err
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -498,7 +500,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||||
dv, dvNotNull := columnType.DefaultValue()
|
dv, dvNotNull := columnType.DefaultValue()
|
||||||
if dvNotNull && !currentDefaultNotNull {
|
if dvNotNull && !currentDefaultNotNull {
|
||||||
// defalut value -> null
|
// default value -> null
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
} else if !dvNotNull && currentDefaultNotNull {
|
} else if !dvNotNull && currentDefaultNotNull {
|
||||||
// null -> default value
|
// null -> default value
|
||||||
|
@ -89,6 +89,10 @@ type Field struct {
|
|||||||
NewValuePool FieldNewValuePool
|
NewValuePool FieldNewValuePool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (field *Field) BindName() string {
|
||||||
|
return strings.Join(field.BindNames, ".")
|
||||||
|
}
|
||||||
|
|
||||||
// ParseField parses reflect.StructField to Field
|
// ParseField parses reflect.StructField to Field
|
||||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var (
|
var (
|
||||||
@ -580,8 +584,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **bool:
|
case **bool:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).SetBool(**data)
|
field.ReflectValueOf(ctx, value).SetBool(**data)
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(ctx, value).SetBool(false)
|
|
||||||
}
|
}
|
||||||
case bool:
|
case bool:
|
||||||
field.ReflectValueOf(ctx, value).SetBool(data)
|
field.ReflectValueOf(ctx, value).SetBool(data)
|
||||||
@ -601,8 +603,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **int64:
|
case **int64:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).SetInt(**data)
|
field.ReflectValueOf(ctx, value).SetInt(**data)
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(ctx, value).SetInt(0)
|
|
||||||
}
|
}
|
||||||
case int64:
|
case int64:
|
||||||
field.ReflectValueOf(ctx, value).SetInt(data)
|
field.ReflectValueOf(ctx, value).SetInt(data)
|
||||||
@ -667,8 +667,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **uint64:
|
case **uint64:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).SetUint(**data)
|
field.ReflectValueOf(ctx, value).SetUint(**data)
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(ctx, value).SetUint(0)
|
|
||||||
}
|
}
|
||||||
case uint64:
|
case uint64:
|
||||||
field.ReflectValueOf(ctx, value).SetUint(data)
|
field.ReflectValueOf(ctx, value).SetUint(data)
|
||||||
@ -721,8 +719,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **float64:
|
case **float64:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(ctx, value).SetFloat(0)
|
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
field.ReflectValueOf(ctx, value).SetFloat(data)
|
field.ReflectValueOf(ctx, value).SetFloat(data)
|
||||||
@ -767,8 +763,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
case **string:
|
case **string:
|
||||||
if data != nil && *data != nil {
|
if data != nil && *data != nil {
|
||||||
field.ReflectValueOf(ctx, value).SetString(**data)
|
field.ReflectValueOf(ctx, value).SetString(**data)
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(ctx, value).SetString("")
|
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
field.ReflectValueOf(ctx, value).SetString(data)
|
field.ReflectValueOf(ctx, value).SetString(data)
|
||||||
|
@ -27,6 +27,8 @@ type Relationships struct {
|
|||||||
HasMany []*Relationship
|
HasMany []*Relationship
|
||||||
Many2Many []*Relationship
|
Many2Many []*Relationship
|
||||||
Relations map[string]*Relationship
|
Relations map[string]*Relationship
|
||||||
|
|
||||||
|
EmbeddedRelations map[string]*Relationships
|
||||||
}
|
}
|
||||||
|
|
||||||
type Relationship struct {
|
type Relationship struct {
|
||||||
@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if schema.err == nil {
|
if schema.err == nil {
|
||||||
schema.Relationships.Relations[relation.Name] = relation
|
schema.setRelation(relation)
|
||||||
switch relation.Type {
|
switch relation.Type {
|
||||||
case HasOne:
|
case HasOne:
|
||||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||||
@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
return relation
|
return relation
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (schema *Schema) setRelation(relation *Relationship) {
|
||||||
|
// set non-embedded relation
|
||||||
|
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||||
|
if len(rel.Field.BindNames) > 1 {
|
||||||
|
schema.Relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
schema.Relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
|
||||||
|
// set embedded relation
|
||||||
|
if len(relation.Field.BindNames) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
relationships := &schema.Relationships
|
||||||
|
for i, name := range relation.Field.BindNames {
|
||||||
|
if i < len(relation.Field.BindNames)-1 {
|
||||||
|
if relationships.EmbeddedRelations == nil {
|
||||||
|
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||||
|
}
|
||||||
|
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||||
|
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||||
|
}
|
||||||
|
relationships = relationships.EmbeddedRelations[name]
|
||||||
|
} else {
|
||||||
|
if relationships.Relations == nil {
|
||||||
|
relationships.Relations = map[string]*Relationship{}
|
||||||
|
}
|
||||||
|
relationships.Relations[relation.Name] = relation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||||
//
|
//
|
||||||
// type User struct {
|
// type User struct {
|
||||||
@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if primaryKeyField == nil {
|
||||||
|
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// use same data type for foreign keys
|
// use same data type for foreign keys
|
||||||
if copyableDataType(primaryKeyField.DataType) {
|
if copyableDataType(primaryKeyField.DataType) {
|
||||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||||
@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
primaryFields = primarySchema.PrimaryFields
|
primaryFields = primarySchema.PrimaryFields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
primaryFieldLoop:
|
||||||
for _, primaryField := range primaryFields {
|
for _, primaryField := range primaryFields {
|
||||||
lookUpName := primarySchemaName + primaryField.Name
|
lookUpName := primarySchemaName + primaryField.Name
|
||||||
if gl == guessBelongs {
|
if gl == guessBelongs {
|
||||||
@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, name := range lookUpNames {
|
||||||
|
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||||
|
foreignFields = append(foreignFields, f)
|
||||||
|
primaryFields = append(primaryFields, primaryField)
|
||||||
|
continue primaryFieldLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, name := range lookUpNames {
|
for _, name := range lookUpNames {
|
||||||
if f := foreignSchema.LookUpField(name); f != nil {
|
if f := foreignSchema.LookUpField(name); f != nil {
|
||||||
foreignFields = append(foreignFields, f)
|
foreignFields = append(foreignFields, f)
|
||||||
primaryFields = append(primaryFields, primaryField)
|
primaryFields = append(primaryFields, primaryField)
|
||||||
break
|
continue primaryFieldLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedHas(t *testing.T) {
|
||||||
|
type Toy struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
OwnerID int
|
||||||
|
OwnerType string
|
||||||
|
}
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
Cat struct {
|
||||||
|
Name string
|
||||||
|
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
} `gorm:"embedded;embeddedPrefix:cat_"`
|
||||||
|
Dog struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
UserID int
|
||||||
|
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||||
|
"Cat": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Toy": {
|
||||||
|
Name: "Toy",
|
||||||
|
Type: schema.HasOne,
|
||||||
|
Schema: "User",
|
||||||
|
FieldSchema: "Toy",
|
||||||
|
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||||
|
References: []Reference{
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Toys": {
|
||||||
|
Name: "Toys",
|
||||||
|
Type: schema.HasMany,
|
||||||
|
Schema: "User",
|
||||||
|
FieldSchema: "Toy",
|
||||||
|
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||||
|
References: []Reference{
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||||
|
type Country struct {
|
||||||
|
ID int `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type Address struct {
|
||||||
|
CountryID int
|
||||||
|
Country Country
|
||||||
|
}
|
||||||
|
type NestedAddress struct {
|
||||||
|
Address
|
||||||
|
}
|
||||||
|
type Org struct {
|
||||||
|
ID int
|
||||||
|
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||||
|
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||||
|
AddressID int
|
||||||
|
Address struct {
|
||||||
|
ID int
|
||||||
|
Address
|
||||||
|
}
|
||||||
|
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse schema, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||||
|
"PostalAddress": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"VisitingAddress": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"NestedAddress": {
|
||||||
|
EmbeddedRelations: map[string]EmbeddedRelations{
|
||||||
|
"Address": {
|
||||||
|
Relations: map[string]Relation{
|
||||||
|
"Country": {
|
||||||
|
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||||
|
References: []Reference{
|
||||||
|
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestVariableRelation(t *testing.T) {
|
func TestVariableRelation(t *testing.T) {
|
||||||
var result struct {
|
var result struct {
|
||||||
User
|
User
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -25,6 +26,7 @@ type Schema struct {
|
|||||||
PrimaryFieldDBNames []string
|
PrimaryFieldDBNames []string
|
||||||
Fields []*Field
|
Fields []*Field
|
||||||
FieldsByName map[string]*Field
|
FieldsByName map[string]*Field
|
||||||
|
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||||
FieldsByDBName map[string]*Field
|
FieldsByDBName map[string]*Field
|
||||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||||
Relationships Relationships
|
Relationships Relationships
|
||||||
@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||||
|
//
|
||||||
|
// type Struct struct {
|
||||||
|
// Embedded struct {
|
||||||
|
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||||
|
// }
|
||||||
|
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||||
|
// }
|
||||||
|
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||||
|
if len(bindNames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||||
|
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||||
|
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Tabler interface {
|
type Tabler interface {
|
||||||
TableName() string
|
TableName() string
|
||||||
}
|
}
|
||||||
@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
Name: modelType.Name(),
|
Name: modelType.Name(),
|
||||||
ModelType: modelType,
|
ModelType: modelType,
|
||||||
Table: tableName,
|
Table: tableName,
|
||||||
FieldsByName: map[string]*Field{},
|
FieldsByName: map[string]*Field{},
|
||||||
FieldsByDBName: map[string]*Field{},
|
FieldsByBindName: map[string]*Field{},
|
||||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
FieldsByDBName: map[string]*Field{},
|
||||||
cacheStore: cacheStore,
|
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||||
namer: namer,
|
cacheStore: cacheStore,
|
||||||
initialized: make(chan struct{}),
|
namer: namer,
|
||||||
|
initialized: make(chan struct{}),
|
||||||
}
|
}
|
||||||
// When the schema initialization is completed, the channel will be closed
|
// When the schema initialization is completed, the channel will be closed
|
||||||
defer close(schema.initialized)
|
defer close(schema.initialized)
|
||||||
@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bindName := field.BindName()
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
// nonexistence or shortest path or first appear prioritized if has permission
|
// nonexistence or shortest path or first appear prioritized if has permission
|
||||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||||
@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
schema.FieldsByDBName[field.DBName] = field
|
schema.FieldsByDBName[field.DBName] = field
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
|
schema.FieldsByBindName[bindName] = field
|
||||||
|
|
||||||
if v != nil && v.PrimaryKey {
|
if v != nil && v.PrimaryKey {
|
||||||
for idx, f := range schema.PrimaryFields {
|
for idx, f := range schema.PrimaryFields {
|
||||||
@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
}
|
}
|
||||||
|
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||||
|
schema.FieldsByBindName[bindName] = field
|
||||||
|
}
|
||||||
|
|
||||||
field.setupValuerAndSetter()
|
field.setupValuerAndSetter()
|
||||||
}
|
}
|
||||||
@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
return schema, schema.err
|
return schema, schema.err
|
||||||
} else {
|
} else {
|
||||||
schema.FieldsByName[field.Name] = field
|
schema.FieldsByName[field.Name] = field
|
||||||
|
schema.FieldsByBindName[field.BindName()] = field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddedRelations struct {
|
||||||
|
Relations map[string]Relation
|
||||||
|
EmbeddedRelations map[string]EmbeddedRelations
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
|
||||||
|
for name, relations := range actual {
|
||||||
|
rs := expected[name]
|
||||||
|
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
|
||||||
|
if len(relations.Relations) != len(rs.Relations) {
|
||||||
|
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
|
||||||
|
}
|
||||||
|
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
|
||||||
|
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
|
||||||
|
}
|
||||||
|
for n, rel := range relations.Relations {
|
||||||
|
if r, ok := rs.Relations[n]; !ok {
|
||||||
|
t.Errorf("failed to find relation by name %s", n)
|
||||||
|
} else {
|
||||||
|
checkSchemaRelation(t, &schema.Schema{
|
||||||
|
Relationships: schema.Relationships{
|
||||||
|
Relations: map[string]*schema.Relationship{n: rel},
|
||||||
|
},
|
||||||
|
}, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||||
for k, v := range values {
|
for k, v := range values {
|
||||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||||
|
12
statement.go
12
statement.go
@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
case clause.Expression:
|
case clause.Expression:
|
||||||
conds = append(conds, v)
|
conds = append(conds, v)
|
||||||
case *DB:
|
case *DB:
|
||||||
for _, scope := range v.Statement.scopes {
|
v.executeScopes()
|
||||||
v = scope(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
|
||||||
if where, ok := cs.Expression.(clause.Where); ok {
|
if where, ok := cs.Expression.(clause.Where); ok {
|
||||||
if len(where.Exprs) == 1 {
|
if len(where.Exprs) == 1 {
|
||||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||||
@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
conds = append(conds, clause.And(where.Exprs...))
|
conds = append(conds, clause.And(where.Exprs...))
|
||||||
} else if cs.Expression != nil {
|
} else {
|
||||||
conds = append(conds, cs.Expression)
|
conds = append(conds, cs.Expression)
|
||||||
}
|
}
|
||||||
|
if v.Statement == stmt {
|
||||||
|
cs.Expression = nil
|
||||||
|
stmt.Statement.Clauses["WHERE"] = cs
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case map[interface{}]interface{}:
|
case map[interface{}]interface{}:
|
||||||
for i, j := range v {
|
for i, j := range v {
|
||||||
|
@ -393,3 +393,33 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
|
|||||||
AssertEqual(t, err, nil)
|
AssertEqual(t, err, nil)
|
||||||
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
|
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
|
||||||
|
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
|
||||||
|
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test-company-1",
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
|
||||||
|
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
|
||||||
|
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test-company-1",
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
users := []*User{&user1, &user2}
|
||||||
|
var err error
|
||||||
|
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
var findUser1 User
|
||||||
|
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user1, findUser1)
|
||||||
|
|
||||||
|
var findUser2 User
|
||||||
|
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user2, findUser2)
|
||||||
|
}
|
||||||
|
@ -103,9 +103,16 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
|||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Author struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
type HNPost struct {
|
type HNPost struct {
|
||||||
*BasePost
|
*BasePost
|
||||||
Upvotes int32
|
Upvotes int32
|
||||||
|
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Migrator().DropTable(&HNPost{})
|
DB.Migrator().DropTable(&HNPost{})
|
||||||
@ -123,6 +130,10 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
|||||||
if hnPost.Title != "embedded_pointer_type" {
|
if hnPost.Title != "embedded_pointer_type" {
|
||||||
t.Errorf("Should find correct value for embedded pointer type")
|
t.Errorf("Should find correct value for embedded pointer type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hnPost.Author != nil {
|
||||||
|
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Content struct {
|
type Content struct {
|
||||||
|
@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
|
|||||||
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
||||||
CheckUserUnscoped(t, *user6, user)
|
CheckUserUnscoped(t, *user6, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbedPreload(t *testing.T) {
|
||||||
|
type Country struct {
|
||||||
|
ID int `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type EmbeddedAddress struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
CountryID *int
|
||||||
|
Country *Country
|
||||||
|
}
|
||||||
|
type NestedAddress struct {
|
||||||
|
EmbeddedAddress
|
||||||
|
}
|
||||||
|
type Org struct {
|
||||||
|
ID int
|
||||||
|
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||||
|
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||||
|
AddressID *int
|
||||||
|
Address *EmbeddedAddress
|
||||||
|
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||||
|
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||||
|
|
||||||
|
org := Org{
|
||||||
|
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
|
||||||
|
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
|
||||||
|
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
|
||||||
|
NestedAddress: NestedAddress{
|
||||||
|
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := DB.Create(&org).Error; err != nil {
|
||||||
|
t.Errorf("failed to create org, got err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
preloads map[string][]interface{}
|
||||||
|
expect Org
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "address country",
|
||||||
|
preloads: map[string][]interface{}{"Address.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: EmbeddedAddress{
|
||||||
|
ID: org.PostalAddress.ID,
|
||||||
|
Name: org.PostalAddress.Name,
|
||||||
|
CountryID: org.PostalAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: org.Address,
|
||||||
|
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||||
|
ID: org.NestedAddress.ID,
|
||||||
|
Name: org.NestedAddress.Name,
|
||||||
|
CountryID: org.NestedAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "postal address country",
|
||||||
|
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: org.PostalAddress,
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: nil,
|
||||||
|
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||||
|
ID: org.NestedAddress.ID,
|
||||||
|
Name: org.NestedAddress.Name,
|
||||||
|
CountryID: org.NestedAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "nested address country",
|
||||||
|
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
|
||||||
|
expect: Org{
|
||||||
|
ID: org.ID,
|
||||||
|
PostalAddress: EmbeddedAddress{
|
||||||
|
ID: org.PostalAddress.ID,
|
||||||
|
Name: org.PostalAddress.Name,
|
||||||
|
CountryID: org.PostalAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
VisitingAddress: EmbeddedAddress{
|
||||||
|
ID: org.VisitingAddress.ID,
|
||||||
|
Name: org.VisitingAddress.Name,
|
||||||
|
CountryID: org.VisitingAddress.CountryID,
|
||||||
|
Country: nil,
|
||||||
|
},
|
||||||
|
AddressID: org.AddressID,
|
||||||
|
Address: nil,
|
||||||
|
NestedAddress: org.NestedAddress,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "associations",
|
||||||
|
preloads: map[string][]interface{}{
|
||||||
|
clause.Associations: {},
|
||||||
|
// clause.Associations won’t preload nested associations
|
||||||
|
"Address.Country": {},
|
||||||
|
},
|
||||||
|
expect: org,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB = DB.Debug()
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
actual := Org{}
|
||||||
|
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
|
||||||
|
for name, args := range test.preloads {
|
||||||
|
tx = tx.Preload(name, args...)
|
||||||
|
}
|
||||||
|
if err := tx.Find(&actual).Error; err != nil {
|
||||||
|
t.Errorf("failed to find org, got err: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, actual, test.expect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -72,3 +72,54 @@ func TestScopes(t *testing.T) {
|
|||||||
t.Errorf("select max(id)")
|
t.Errorf("select max(id)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestComplexScopes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryFn func(tx *gorm.DB) *gorm.DB
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "depth_1",
|
||||||
|
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Scopes(
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
|
||||||
|
).Find(&Language{})
|
||||||
|
},
|
||||||
|
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
|
||||||
|
}, {
|
||||||
|
name: "depth_1_pre_cond",
|
||||||
|
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Where("z = 0").Scopes(
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
|
||||||
|
).Find(&Language{})
|
||||||
|
},
|
||||||
|
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
|
||||||
|
}, {
|
||||||
|
name: "depth_2",
|
||||||
|
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||||
|
return tx.Scopes(
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
|
||||||
|
func(d *gorm.DB) *gorm.DB {
|
||||||
|
return d.
|
||||||
|
Or(d.Scopes(
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
|
||||||
|
)).
|
||||||
|
Or("c = 3")
|
||||||
|
},
|
||||||
|
func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
|
||||||
|
).Find(&Language{})
|
||||||
|
},
|
||||||
|
expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user