diff --git a/callbacks.go b/callbacks.go index 01d9ed30..26e9c40d 100644 --- a/callbacks.go +++ b/callbacks.go @@ -72,7 +72,7 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } -func (p *processor) Execute(db *DB) { +func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { scopes := db.Statement.scopes @@ -142,6 +142,8 @@ func (p *processor) Execute(db *DB) { if resetBuildClauses { stmt.BuildClauses = nil } + + return db } func (p *processor) Get(name string) func(*DB) { diff --git a/errors.go b/errors.go index f1f6c137..145614d9 100644 --- a/errors.go +++ b/errors.go @@ -10,7 +10,7 @@ var ( // ErrRecordNotFound record not found error ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") + ErrInvalidTransaction = errors.New("invalid transaction") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause diff --git a/finisher_api.go b/finisher_api.go index b5cbfaa6..c3941784 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,8 +21,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) - return + return tx.callbacks.Create().Execute(tx) } // CreateInBatches insert the value in batches into database @@ -64,7 +63,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { default: tx = db.getInstance() tx.Statement.Dest = value - tx.callbacks.Create().Execute(tx) + tx = tx.callbacks.Create().Execute(tx) } return } @@ -80,13 +79,12 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(reflectValue); isZero { - tx.callbacks.Create().Execute(tx) - return + return tx.callbacks.Create().Execute(tx) } } } @@ -99,7 +97,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx.callbacks.Update().Execute(tx) + tx = tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() @@ -124,8 +122,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Take return a record that match given conditions, the order will depend on the database implementation @@ -138,8 +135,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Last find last record that match given conditions, order by primary key @@ -155,8 +151,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // Find find records that match given conditions @@ -168,8 +163,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { } } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } // FindInBatches find records in batches @@ -334,32 +328,28 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values tx.Statement.SkipHooks = true - tx.callbacks.Update().Execute(tx) - return + return tx.callbacks.Update().Execute(tx) } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition @@ -371,8 +361,7 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { } } tx.Statement.Dest = value - tx.callbacks.Delete().Execute(tx) - return + return tx.callbacks.Delete().Execute(tx) } func (db *DB) Count(count *int64) (tx *DB) { @@ -428,7 +417,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.Dest = count - tx.callbacks.Query().Execute(tx) + tx = tx.callbacks.Query().Execute(tx) if tx.RowsAffected != 1 { *count = tx.RowsAffected } @@ -437,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) - tx.callbacks.Row().Execute(tx) + tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) @@ -447,7 +436,7 @@ func (db *DB) Row() *sql.Row { func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) - tx.callbacks.Row().Execute(tx) + tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { tx.Error = ErrDryRunModeUnsupported @@ -505,8 +494,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { }) } tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) - return + return tx.callbacks.Query().Execute(tx) } func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { @@ -644,6 +632,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) } - tx.callbacks.Raw().Execute(tx) - return + return tx.callbacks.Raw().Execute(tx) } diff --git a/logger/logger.go b/logger/logger.go index f14748c1..381199d5 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -139,31 +139,34 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - if l.LogLevel > Silent { - elapsed := time.Since(begin) - switch { - case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: - sql, rows := fc() - slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) - if rows == -1 { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) - } - case l.LogLevel == Info: - sql, rows := fc() - if rows == -1 { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) - } else { - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) - } + + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } } diff --git a/schema/field_test.go b/schema/field_test.go index 64f4a909..4be3e5ab 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,6 +235,7 @@ type UserWithPermissionControl struct { Name5 string `gorm:"<-:update"` Name6 string `gorm:"<-:create,update"` Name7 string `gorm:"->:false;<-:create,update"` + Name8 string `gorm:"->;-:migration"` } func TestParseFieldWithPermission(t *testing.T) { @@ -252,6 +253,7 @@ func TestParseFieldWithPermission(t *testing.T) { {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, } for _, f := range fields { diff --git a/schema/schema.go b/schema/schema.go index d08842e6..1ce88fa5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,7 +71,7 @@ type Tabler interface { TableName() string } -// get data type from dialector +// Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) @@ -91,6 +91,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if v, ok := cacheStore.Load(modelType); ok { s := v.(*Schema) + // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } @@ -115,6 +116,15 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) namer: namer, initialized: make(chan struct{}), } + // When the schema initialization is completed, the channel will be closed + defer close(schema.initialized) + + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } defer func() { if schema.err != nil { @@ -223,13 +233,6 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - <-s.initialized - return s, s.err - } - - defer close(schema.initialized) if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { diff --git a/tests/go.mod b/tests/go.mod index d4b0c975..643b72c7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,10 +8,10 @@ require ( github.com/lib/pq v1.6.0 github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.5 - gorm.io/driver/postgres v1.0.8 + gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlserver v1.0.7 - gorm.io/gorm v1.21.4 + gorm.io/gorm v1.21.9 ) replace gorm.io/gorm => ../ diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 9836c41e..0ec4783b 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -54,4 +54,12 @@ func TestScopes(t *testing.T) { if db.Find(&User{}).Statement.Table != "custom_table" { t.Errorf("failed to call Scopes") } + + result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { + return tx.Session(&gorm.Session{}) + }).Find(&users1) + + if result.RowsAffected != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) + } }