Merge 904b1289a4e883b6a43dc667c49ebd0178b57a7f into 4e34a6d21b63e9a9b701a70be9759e5539bf26e9
This commit is contained in:
commit
56bb02d628
@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
PrimaryKey string = "~~~py~~~" // primary key
|
PrimaryKey string = "~~~py~~~" // primary key
|
||||||
|
PrimaryKeys string = "~~~ps~~~" // primary keys
|
||||||
CurrentTable string = "~~~ct~~~" // current table
|
CurrentTable string = "~~~ct~~~" // current table
|
||||||
Associations string = "~~~as~~~" // associations
|
Associations string = "~~~as~~~" // associations
|
||||||
)
|
)
|
||||||
|
@ -119,7 +119,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
// First finds the first record ordered by primary key, matching given conditions conds
|
// First finds the first record ordered by primary key, matching given conditions conds
|
||||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
|
||||||
})
|
})
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||||
@ -147,7 +147,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
// Last finds the last record ordered by primary key, matching given conditions conds
|
// Last finds the last record ordered by primary key, matching given conditions conds
|
||||||
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
|
||||||
Desc: true,
|
Desc: true,
|
||||||
})
|
})
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
@ -174,9 +174,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
|
|
||||||
// FindInBatches finds all records in batches of batchSize
|
// FindInBatches finds all records in batches of batchSize
|
||||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||||
|
// Use PrimaryKeys to handle composite primary key situations
|
||||||
var (
|
var (
|
||||||
tx = db.Order(clause.OrderByColumn{
|
tx = db.Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
|
||||||
}).Session(&Session{})
|
}).Session(&Session{})
|
||||||
queryDB = tx
|
queryDB = tx
|
||||||
rowsAffected int64
|
rowsAffected int64
|
||||||
@ -200,6 +201,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
find:
|
||||||
for {
|
for {
|
||||||
result := queryDB.Limit(batchSize).Find(dest)
|
result := queryDB.Limit(batchSize).Find(dest)
|
||||||
rowsAffected += result.RowsAffected
|
rowsAffected += result.RowsAffected
|
||||||
@ -228,17 +230,76 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
|
|
||||||
// Optimize for-break
|
// Optimize for-break
|
||||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 {
|
||||||
tx.AddError(ErrPrimaryKeyRequired)
|
tx.AddError(ErrPrimaryKeyRequired)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
// The following will build a where clause like this:
|
||||||
if zero {
|
// struct {
|
||||||
tx.AddError(ErrPrimaryKeyRequired)
|
// col1 uint `gorm:"primaryKey;autoIncrement:false"`
|
||||||
break
|
// col2 uint `gorm:"primaryKey;autoIncrement:false"`
|
||||||
|
// col3 uint `gorm:"primaryKey;autoIncrement:false"`
|
||||||
|
// }
|
||||||
|
// last row returned was col1 = 2, col2 = 3, col3 = 5
|
||||||
|
// where clause will be generated as follows
|
||||||
|
// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
|
||||||
|
// Detect composite primary keys
|
||||||
|
if result.Statement.Schema.PrimaryFields != nil {
|
||||||
|
pkCount := len(result.Statement.Schema.PrimaryFields)
|
||||||
|
|
||||||
|
// Handle composite primary key Where clauses
|
||||||
|
if pkCount > 1 {
|
||||||
|
var f *schema.Field
|
||||||
|
var orClauses []clause.Expression
|
||||||
|
for i := 0; i < pkCount; i++ {
|
||||||
|
var andClauses []clause.Expression
|
||||||
|
// Build 1st column GT clause
|
||||||
|
if i == 0 {
|
||||||
|
f = result.Statement.Schema.PrimaryFields[i]
|
||||||
|
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
|
if zero {
|
||||||
|
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
|
||||||
|
break find
|
||||||
|
}
|
||||||
|
orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
|
||||||
|
} else {
|
||||||
|
// Build AND clause and append to OR clauses
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
f = result.Statement.Schema.PrimaryFields[j]
|
||||||
|
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
|
if zero {
|
||||||
|
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
|
||||||
|
break find
|
||||||
|
}
|
||||||
|
if j == i {
|
||||||
|
// Build current outer column GT clause
|
||||||
|
andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
|
||||||
|
} else {
|
||||||
|
// Build all other columns EQ clause
|
||||||
|
andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
orClauses = append(orClauses, clause.And(andClauses...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
queryDB = tx.Clauses(clause.Or(orClauses...))
|
||||||
|
} else {
|
||||||
|
primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
|
if zero {
|
||||||
|
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
|
||||||
|
break
|
||||||
|
}
|
||||||
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
|
if zero {
|
||||||
|
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
|
||||||
|
break
|
||||||
|
}
|
||||||
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||||
}
|
}
|
||||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tx.RowsAffected = rowsAffected
|
tx.RowsAffected = rowsAffected
|
||||||
@ -308,7 +369,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
|||||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
|
||||||
})
|
})
|
||||||
|
|
||||||
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
|
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
|
||||||
@ -348,7 +409,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
|
||||||
})
|
})
|
||||||
|
|
||||||
result := queryTx.Find(dest, conds...)
|
result := queryTx.Find(dest, conds...)
|
||||||
|
61
statement.go
61
statement.go
@ -108,27 +108,62 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||||||
write(v.Raw, v.Alias)
|
write(v.Raw, v.Alias)
|
||||||
}
|
}
|
||||||
case clause.Column:
|
case clause.Column:
|
||||||
if v.Table != "" {
|
// Handle composite primary keys explicitly
|
||||||
if v.Table == clause.CurrentTable {
|
if v.Name == clause.PrimaryKeys {
|
||||||
write(v.Raw, stmt.Table)
|
|
||||||
} else {
|
|
||||||
write(v.Raw, v.Table)
|
|
||||||
}
|
|
||||||
writer.WriteByte('.')
|
|
||||||
}
|
|
||||||
|
|
||||||
if v.Name == clause.PrimaryKey {
|
|
||||||
if stmt.Schema == nil {
|
if stmt.Schema == nil {
|
||||||
stmt.DB.AddError(ErrModelValueRequired)
|
stmt.DB.AddError(ErrModelValueRequired)
|
||||||
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
} else if stmt.Schema.PrimaryFields != nil {
|
||||||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
for idx, s := range stmt.Schema.PrimaryFieldDBNames {
|
||||||
|
if idx > 0 {
|
||||||
|
writer.WriteByte(',') //nolint:typecheck,errcheck,gosec
|
||||||
|
}
|
||||||
|
if v.Table != "" {
|
||||||
|
if v.Table == clause.CurrentTable {
|
||||||
|
write(v.Raw, stmt.Table)
|
||||||
|
} else {
|
||||||
|
write(v.Raw, v.Table)
|
||||||
|
}
|
||||||
|
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
|
||||||
|
}
|
||||||
|
write(v.Raw, s)
|
||||||
|
}
|
||||||
} else if len(stmt.Schema.DBNames) > 0 {
|
} else if len(stmt.Schema.DBNames) > 0 {
|
||||||
|
if v.Table != "" {
|
||||||
|
if v.Table == clause.CurrentTable {
|
||||||
|
write(v.Raw, stmt.Table)
|
||||||
|
} else {
|
||||||
|
write(v.Raw, v.Table)
|
||||||
|
}
|
||||||
|
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
|
||||||
|
}
|
||||||
write(v.Raw, stmt.Schema.DBNames[0])
|
write(v.Raw, stmt.Schema.DBNames[0])
|
||||||
} else {
|
} else {
|
||||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
write(v.Raw, v.Name)
|
if v.Table != "" {
|
||||||
|
if v.Table == clause.CurrentTable {
|
||||||
|
write(v.Raw, stmt.Table)
|
||||||
|
} else {
|
||||||
|
write(v.Raw, v.Table)
|
||||||
|
}
|
||||||
|
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.Name == clause.PrimaryKey {
|
||||||
|
switch {
|
||||||
|
case stmt.Schema == nil:
|
||||||
|
stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck
|
||||||
|
case stmt.Schema.PrioritizedPrimaryField != nil:
|
||||||
|
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||||
|
case len(stmt.Schema.DBNames) > 0:
|
||||||
|
write(v.Raw, stmt.Schema.DBNames[0])
|
||||||
|
default:
|
||||||
|
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
write(v.Raw, v.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.Alias != "" {
|
if v.Alias != "" {
|
||||||
|
@ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindInBatchesCompositeKey(t *testing.T) {
|
||||||
|
coupons := []Coupon{
|
||||||
|
{AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
|
||||||
|
{ProductId: "1", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "2", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "3", Desc: "find_in_batches"},
|
||||||
|
}},
|
||||||
|
{AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
|
||||||
|
{ProductId: "1", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "2", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "3", Desc: "find_in_batches"},
|
||||||
|
}},
|
||||||
|
{AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
|
||||||
|
{ProductId: "1", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "2", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "3", Desc: "find_in_batches"},
|
||||||
|
}},
|
||||||
|
{AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
|
||||||
|
{ProductId: "1", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "2", Desc: "find_in_batches"},
|
||||||
|
{ProductId: "3", Desc: "find_in_batches"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&coupons)
|
||||||
|
|
||||||
|
var (
|
||||||
|
results []CouponProduct
|
||||||
|
lastBatch int
|
||||||
|
)
|
||||||
|
|
||||||
|
if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
lastBatch = batch
|
||||||
|
|
||||||
|
if tx.RowsAffected != 2 {
|
||||||
|
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := range results {
|
||||||
|
results[idx].Desc = results[idx].Desc + "_new"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Save(results).Error; err != nil {
|
||||||
|
t.Fatalf("failed to save coupon_product, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 12 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastBatch != 6 {
|
||||||
|
t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count)
|
||||||
|
if count != 12 {
|
||||||
|
t.Errorf("incorrect count after update, expects: %v, got %v", 12, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{})
|
||||||
|
DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{})
|
||||||
|
}
|
||||||
|
|
||||||
func TestFillSmallerStruct(t *testing.T) {
|
func TestFillSmallerStruct(t *testing.T) {
|
||||||
user := User{Name: "SmallerUser", Age: 100}
|
user := User{Name: "SmallerUser", Age: 100}
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user