From 8f7f3ad3153c2bbcd6a74f6758ac819260ad7189 Mon Sep 17 00:00:00 2001 From: Paras Waykole Date: Wed, 5 May 2021 05:27:54 +0530 Subject: [PATCH 1/7] fixed belongs_to & has_one reversed if field same (#4343) --- schema/relationship.go | 10 +++++++--- schema/relationship_test.go | 19 +++++++++++++++++++ utils/utils.go | 12 ++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index fee96cbd..b2d485de 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -7,6 +7,7 @@ import ( "github.com/jinzhu/inflection" "gorm.io/gorm/clause" + "gorm.io/gorm/utils" ) // RelationshipType relationship type @@ -404,11 +405,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + ff := foreignSchema.LookUpField(foreignKey) + pf := primarySchema.LookUpField(foreignKey) + isKeySame := utils.ExistsIn(foreignKey, &relation.primaryKeys) + if ff == nil || (pf != nil && ff != nil && schema == primarySchema && primarySchema != foreignSchema && !isKeySame) { reguessOrErr() return + } else { + foreignFields = append(foreignFields, ff) } } } else { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2971698c..391e3a25 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -482,3 +482,22 @@ func TestSameForeignKey(t *testing.T) { }, ) } + +func TestBelongsToWithSameForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + ProfileRefer int + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + }) +} diff --git a/utils/utils.go b/utils/utils.go index ecba7fb9..ce6f35df 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -111,3 +111,15 @@ func ToString(value interface{}) string { } return "" } + +func ExistsIn(a string, list *[]string) bool { + if list == nil { + return false + } + for _, b := range *list { + if b == a { + return true + } + } + return false +} From 3f359eab9bcfb77f47500c826027a28f714f1954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=91=E7=9A=84=E6=88=91=E7=9A=84?= <67250607+guzzsek@users.noreply.github.com> Date: Wed, 5 May 2021 08:14:40 +0800 Subject: [PATCH 2/7] slim trace if depth (#4346) Co-authored-by: gogs --- logger/logger.go | 53 +++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index f14748c1..381199d5 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -139,31 +139,34 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - if l.LogLevel > Silent { - elapsed := time.Since(begin) - switch { - case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: - sql, rows := fc() - slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) - if rows == -1 { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case l.LogLevel == Info: - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) - } + + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } } From 2aca96d1474967da11bac81a58db9c97bd7bdcac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 5 May 2021 08:26:15 +0800 Subject: [PATCH 3/7] test ignore migration, close #4314, #4315 --- schema/field_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/schema/field_test.go b/schema/field_test.go index 64f4a909..00f8cd42 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,6 +235,7 @@ type UserWithPermissionControl struct { Name5 string `gorm:"<-:update"` Name6 string `gorm:"<-:create,update"` Name7 string `gorm:"->:false;<-:create,update"` + Name8 string `gorm:"->;-:migration"` } func TestParseFieldWithPermission(t *testing.T) { @@ -252,6 +253,7 @@ func TestParseFieldWithPermission(t *testing.T) { {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"<->"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, } for _, f := range fields { From 6b7abc54a2a02ac0604a580571732e1c73bc42bf Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 May 2021 13:06:31 +0800 Subject: [PATCH 4/7] Fix tests --- schema/field_test.go | 2 +- tests/go.mod | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/schema/field_test.go b/schema/field_test.go index 00f8cd42..4be3e5ab 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -253,7 +253,7 @@ func TestParseFieldWithPermission(t *testing.T) { {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, - {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"<->"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, + {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, } for _, f := range fields { diff --git a/tests/go.mod b/tests/go.mod index d4b0c975..643b72c7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,10 +8,10 @@ require ( github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.5 - gorm.io/driver/postgres v1.0.8 + gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.7 - gorm.io/gorm v1.21.4 + gorm.io/gorm v1.21.9 ) replace gorm.io/gorm => ../ From a480bd85450d444d2526a309f2ecae07cac814c0 Mon Sep 17 00:00:00 2001 From: Chen Quan Date: Mon, 10 May 2021 09:51:50 +0800 Subject: [PATCH 5/7] Update Optimize schema (#4364) --- schema/schema.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index d08842e6..1ce88fa5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,7 +71,7 @@ type Tabler interface { TableName() string } -// get data type from dialector +// Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) @@ -91,6 +91,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if v, ok := cacheStore.Load(modelType); ok { s := v.(*Schema) + // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } @@ -115,6 +116,15 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) namer: namer, initialized: make(chan struct{}), } + // When the schema initialization is completed, the channel will be closed + defer close(schema.initialized) + + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } defer func() { if schema.err != nil { @@ -223,13 +233,6 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - <-s.initialized - return s, s.err - } - - defer close(schema.initialized) if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { From 92c3ba9dccd65f41652baa4d51e4c82af5496eec Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 May 2021 15:34:24 +0800 Subject: [PATCH 6/7] Fix create new db sessions in scopes --- callbacks.go | 4 +++- finisher_api.go | 51 +++++++++++++++++--------------------------- tests/scopes_test.go | 8 +++++++ 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/callbacks.go b/callbacks.go index 01d9ed30..26e9c40d 100644 --- a/callbacks.go +++ b/callbacks.go @@ -72,7 +72,7 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } -func (p *processor) Execute(db *DB) { +func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { scopes := db.Statement.scopes @@ -142,6 +142,8 @@ func (p *processor) Execute(db *DB) { if resetBuildClauses { stmt.BuildClauses = nil } + + return db } func (p *processor) Get(name string) func(*DB) { diff --git a/finisher_api.go b/finisher_api.go index b5cbfaa6..c3941784 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,8 +21,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) - return + return tx.callbacks.Create().Execute(tx) } // CreateInBatches insert the value in batches into database @@ -64,7 +63,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { default: tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) + tx = tx.callbacks.Create().Execute(tx) } return } @@ -80,13 +79,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) - return + return tx.callbacks.Create().Execute(tx) } } } @@ -99,7 +97,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx.callbacks.Update().Execute(tx) + tx = tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() @@ -124,8 +122,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Take return a record that match given conditions, the order will depend on the database implementation @@ -138,8 +135,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Last find last record that match given conditions, order by primary key @@ -155,8 +151,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Find find records that match given conditions @@ -168,8 +163,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { } } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // FindInBatches find records in batches @@ -334,32 +328,28 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition @@ -371,8 +361,7 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { } } tx.Statement.Dest = value - tx.callbacks.Delete().Execute(tx) - return + return tx.callbacks.Delete().Execute(tx) } func (db *DB) Count(count *int64) (tx *DB) { @@ -428,7 +417,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.Dest = count - tx.callbacks.Query().Execute(tx) + tx = tx.callbacks.Query().Execute(tx) if tx.RowsAffected != 1 { *count = tx.RowsAffected } @@ -437,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) - tx.callbacks.Row().Execute(tx) + tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) @@ -447,7 +436,7 @@ func (db *DB) Row() *sql.Row { func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) - tx.callbacks.Row().Execute(tx) + tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { tx.Error = ErrDryRunModeUnsupported @@ -505,8 +494,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { }) } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { @@ -644,6 +632,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) } - tx.callbacks.Raw().Execute(tx) - return + return tx.callbacks.Raw().Execute(tx) } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 9836c41e..0ec4783b 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -54,4 +54,12 @@ func TestScopes(t *testing.T) { if db.Find(&User{}).Statement.Table != "custom_table" { t.Errorf("failed to call Scopes") } + + result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { + return tx.Session(&gorm.Session{}) + }).Find(&users1) + + if result.RowsAffected != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) + } } From cf93b16730e405ac611b7e3571f5fc92682efe7a Mon Sep 17 00:00:00 2001 From: Atreya <44151328+atreya2011@users.noreply.github.com> Date: Mon, 17 May 2021 16:53:48 +0900 Subject: [PATCH 7/7] Fix ErrInvalidTransaction error message (#4380) --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index f1f6c137..145614d9 100644 --- a/errors.go +++ b/errors.go @@ -10,7 +10,7 @@ var ( // ErrRecordNotFound record not found error ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") + ErrInvalidTransaction = errors.New("invalid transaction") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause