Merge branch 'go-gorm:master' into master

This commit is contained in:
Paras Waykole 2021-05-17 23:55:06 +05:30
commit ad1c581de7
8 changed files with 74 additions and 69 deletions

View File

@ -72,7 +72,7 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"] return cs.processors["raw"]
} }
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) *DB {
// call scopes // call scopes
for len(db.Statement.scopes) > 0 { for len(db.Statement.scopes) > 0 {
scopes := db.Statement.scopes scopes := db.Statement.scopes
@ -142,6 +142,8 @@ func (p *processor) Execute(db *DB) {
if resetBuildClauses { if resetBuildClauses {
stmt.BuildClauses = nil stmt.BuildClauses = nil
} }
return db
} }
func (p *processor) Get(name string) func(*DB) { func (p *processor) Get(name string) func(*DB) {

View File

@ -10,7 +10,7 @@ var (
// ErrRecordNotFound record not found error // ErrRecordNotFound record not found error
ErrRecordNotFound = logger.ErrRecordNotFound ErrRecordNotFound = logger.ErrRecordNotFound
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction") ErrInvalidTransaction = errors.New("invalid transaction")
// ErrNotImplemented not implemented // ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented") ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause // ErrMissingWhereClause missing where clause

View File

@ -21,8 +21,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
return
} }
// CreateInBatches insert the value in batches into database // CreateInBatches insert the value in batches into database
@ -64,7 +63,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
default: default:
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx) tx = tx.callbacks.Create().Execute(tx)
} }
return return
} }
@ -80,13 +79,12 @@ func (db *DB) Save(value interface{}) (tx *DB) {
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
} }
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
case reflect.Struct: case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields { for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero { if _, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
return
} }
} }
} }
@ -99,7 +97,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, "*") tx.Statement.Selects = append(tx.Statement.Selects, "*")
} }
tx.callbacks.Update().Execute(tx) tx = tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface() result := reflect.New(tx.Statement.Schema.ModelType).Interface()
@ -124,8 +122,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take return a record that match given conditions, the order will depend on the database implementation
@ -138,8 +135,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Last find last record that match given conditions, order by primary key // Last find last record that match given conditions, order by primary key
@ -155,8 +151,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
} }
tx.Statement.RaiseErrorOnNotFound = true tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// Find find records that match given conditions // Find find records that match given conditions
@ -168,8 +163,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
} }
} }
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
// FindInBatches find records in batches // FindInBatches find records in batches
@ -334,32 +328,28 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.SkipHooks = true tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
func (db *DB) UpdateColumns(values interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
tx.Statement.SkipHooks = true tx.Statement.SkipHooks = true
tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
return
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
@ -371,8 +361,7 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
} }
} }
tx.Statement.Dest = value tx.Statement.Dest = value
tx.callbacks.Delete().Execute(tx) return tx.callbacks.Delete().Execute(tx)
return
} }
func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) {
@ -428,7 +417,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
} }
tx.Statement.Dest = count tx.Statement.Dest = count
tx.callbacks.Query().Execute(tx) tx = tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 { if tx.RowsAffected != 1 {
*count = tx.RowsAffected *count = tx.RowsAffected
} }
@ -437,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
tx := db.getInstance().InstanceSet("rows", false) tx := db.getInstance().InstanceSet("rows", false)
tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row) row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun { if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
@ -447,7 +436,7 @@ func (db *DB) Row() *sql.Row {
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().InstanceSet("rows", true) tx := db.getInstance().InstanceSet("rows", true)
tx.callbacks.Row().Execute(tx) tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows) rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil { if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported tx.Error = ErrDryRunModeUnsupported
@ -505,8 +494,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
}) })
} }
tx.Statement.Dest = dest tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
return
} }
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
@ -644,6 +632,5 @@ func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
} }
tx.callbacks.Raw().Execute(tx) return tx.callbacks.Raw().Execute(tx)
return
} }

View File

@ -139,7 +139,11 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
// Trace print sql message // Trace print sql message
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel > Silent {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin) elapsed := time.Since(begin)
switch { switch {
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
@ -165,7 +169,6 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
} }
} }
}
} }
type traceRecorder struct { type traceRecorder struct {

View File

@ -235,6 +235,7 @@ type UserWithPermissionControl struct {
Name5 string `gorm:"<-:update"` Name5 string `gorm:"<-:update"`
Name6 string `gorm:"<-:create,update"` Name6 string `gorm:"<-:create,update"`
Name7 string `gorm:"->:false;<-:create,update"` Name7 string `gorm:"->:false;<-:create,update"`
Name8 string `gorm:"->;-:migration"`
} }
func TestParseFieldWithPermission(t *testing.T) { func TestParseFieldWithPermission(t *testing.T) {
@ -252,6 +253,7 @@ func TestParseFieldWithPermission(t *testing.T) {
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
{Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true},
} }
for _, f := range fields { for _, f := range fields {

View File

@ -71,7 +71,7 @@ type Tabler interface {
TableName() string TableName() string
} }
// get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
if dest == nil { if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
@ -91,6 +91,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if v, ok := cacheStore.Load(modelType); ok { if v, ok := cacheStore.Load(modelType); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
return s, s.err return s, s.err
} }
@ -115,6 +116,15 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
namer: namer, namer: namer,
initialized: make(chan struct{}), initialized: make(chan struct{}),
} }
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() { defer func() {
if schema.err != nil { if schema.err != nil {
@ -223,13 +233,6 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
} }
} }
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
s := v.(*Schema)
<-s.initialized
return s, s.err
}
defer close(schema.initialized)
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {

View File

@ -8,10 +8,10 @@ require (
github.com/lib/pq v1.6.0 github.com/lib/pq v1.6.0
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1
gorm.io/driver/mysql v1.0.5 gorm.io/driver/mysql v1.0.5
gorm.io/driver/postgres v1.0.8 gorm.io/driver/postgres v1.1.0
gorm.io/driver/sqlite v1.1.4 gorm.io/driver/sqlite v1.1.4
gorm.io/driver/sqlserver v1.0.7 gorm.io/driver/sqlserver v1.0.7
gorm.io/gorm v1.21.4 gorm.io/gorm v1.21.9
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -54,4 +54,12 @@ func TestScopes(t *testing.T) {
if db.Find(&User{}).Statement.Table != "custom_table" { if db.Find(&User{}).Statement.Table != "custom_table" {
t.Errorf("failed to call Scopes") t.Errorf("failed to call Scopes")
} }
result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB {
return tx.Session(&gorm.Session{})
}).Find(&users1)
if result.RowsAffected != 2 {
t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected)
}
} }