Merge 904b1289a4e883b6a43dc667c49ebd0178b57a7f into 4e34a6d21b63e9a9b701a70be9759e5539bf26e9
This commit is contained in:
		
						commit
						56bb02d628
					
				| @ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { | |||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	PrimaryKey   string = "~~~py~~~" // primary key
 | 	PrimaryKey   string = "~~~py~~~" // primary key
 | ||||||
|  | 	PrimaryKeys  string = "~~~ps~~~" // primary keys
 | ||||||
| 	CurrentTable string = "~~~ct~~~" // current table
 | 	CurrentTable string = "~~~ct~~~" // current table
 | ||||||
| 	Associations string = "~~~as~~~" // associations
 | 	Associations string = "~~~as~~~" // associations
 | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -119,7 +119,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| // First finds the first record ordered by primary key, matching given conditions conds
 | // First finds the first record ordered by primary key, matching given conditions conds
 | ||||||
| func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { | func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||||
| 	}) | 	}) | ||||||
| 	if len(conds) > 0 { | 	if len(conds) > 0 { | ||||||
| 		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { | 		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { | ||||||
| @ -147,7 +147,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { | |||||||
| // Last finds the last record ordered by primary key, matching given conditions conds
 | // Last finds the last record ordered by primary key, matching given conditions conds
 | ||||||
| func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { | func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||||
| 		Desc:   true, | 		Desc:   true, | ||||||
| 	}) | 	}) | ||||||
| 	if len(conds) > 0 { | 	if len(conds) > 0 { | ||||||
| @ -174,9 +174,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { | |||||||
| 
 | 
 | ||||||
| // FindInBatches finds all records in batches of batchSize
 | // FindInBatches finds all records in batches of batchSize
 | ||||||
| func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { | func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { | ||||||
|  | 	// Use PrimaryKeys to handle composite primary key situations
 | ||||||
| 	var ( | 	var ( | ||||||
| 		tx = db.Order(clause.OrderByColumn{ | 		tx = db.Order(clause.OrderByColumn{ | ||||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||||
| 		}).Session(&Session{}) | 		}).Session(&Session{}) | ||||||
| 		queryDB      = tx | 		queryDB      = tx | ||||||
| 		rowsAffected int64 | 		rowsAffected int64 | ||||||
| @ -200,6 +201,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | find: | ||||||
| 	for { | 	for { | ||||||
| 		result := queryDB.Limit(batchSize).Find(dest) | 		result := queryDB.Limit(batchSize).Find(dest) | ||||||
| 		rowsAffected += result.RowsAffected | 		rowsAffected += result.RowsAffected | ||||||
| @ -228,17 +230,76 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | |||||||
| 
 | 
 | ||||||
| 		// Optimize for-break
 | 		// Optimize for-break
 | ||||||
| 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | ||||||
| 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | 		if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 { | ||||||
| 			tx.AddError(ErrPrimaryKeyRequired) | 			tx.AddError(ErrPrimaryKeyRequired) | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | 		// The following will build a where clause like this:
 | ||||||
| 		if zero { | 		// struct {
 | ||||||
| 			tx.AddError(ErrPrimaryKeyRequired) | 		// col1    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||||
| 			break | 		// col2    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||||
|  | 		// col3    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||||
|  | 		// }
 | ||||||
|  | 		// last row returned was col1 = 2, col2 = 3, col3 = 5
 | ||||||
|  | 		// where clause will be generated as follows
 | ||||||
|  | 		// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
 | ||||||
|  | 		// Detect composite primary keys
 | ||||||
|  | 		if result.Statement.Schema.PrimaryFields != nil { | ||||||
|  | 			pkCount := len(result.Statement.Schema.PrimaryFields) | ||||||
|  | 
 | ||||||
|  | 			// Handle composite primary key Where clauses
 | ||||||
|  | 			if pkCount > 1 { | ||||||
|  | 				var f *schema.Field | ||||||
|  | 				var orClauses []clause.Expression | ||||||
|  | 				for i := 0; i < pkCount; i++ { | ||||||
|  | 					var andClauses []clause.Expression | ||||||
|  | 					// Build 1st column GT clause
 | ||||||
|  | 					if i == 0 { | ||||||
|  | 						f = result.Statement.Schema.PrimaryFields[i] | ||||||
|  | 						primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||||
|  | 						if zero { | ||||||
|  | 							tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 							break find | ||||||
|  | 						} | ||||||
|  | 						orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||||
|  | 					} else { | ||||||
|  | 						// Build AND clause and append to OR clauses
 | ||||||
|  | 						for j := 0; j <= i; j++ { | ||||||
|  | 							f = result.Statement.Schema.PrimaryFields[j] | ||||||
|  | 							primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||||
|  | 							if zero { | ||||||
|  | 								tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 								break find | ||||||
|  | 							} | ||||||
|  | 							if j == i { | ||||||
|  | 								// Build current outer column GT clause
 | ||||||
|  | 								andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||||
|  | 							} else { | ||||||
|  | 								// Build all other columns EQ clause
 | ||||||
|  | 								andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 						orClauses = append(orClauses, clause.And(andClauses...)) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				queryDB = tx.Clauses(clause.Or(orClauses...)) | ||||||
|  | 			} else { | ||||||
|  | 				primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||||
|  | 				if zero { | ||||||
|  | 					tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 					break | ||||||
|  | 				} | ||||||
|  | 				queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||||
|  | 			if zero { | ||||||
|  | 				tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 			queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||||
| 		} | 		} | ||||||
| 		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tx.RowsAffected = rowsAffected | 	tx.RowsAffected = rowsAffected | ||||||
| @ -308,7 +369,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { | |||||||
| //	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
 | //	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
 | ||||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { | 	if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { | ||||||
| @ -348,7 +409,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | |||||||
| func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { | func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ | 	queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ | ||||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	result := queryTx.Find(dest, conds...) | 	result := queryTx.Find(dest, conds...) | ||||||
|  | |||||||
							
								
								
									
										61
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								statement.go
									
									
									
									
									
								
							| @ -108,27 +108,62 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { | |||||||
| 			write(v.Raw, v.Alias) | 			write(v.Raw, v.Alias) | ||||||
| 		} | 		} | ||||||
| 	case clause.Column: | 	case clause.Column: | ||||||
| 		if v.Table != "" { | 		// Handle composite primary keys explicitly
 | ||||||
| 			if v.Table == clause.CurrentTable { | 		if v.Name == clause.PrimaryKeys { | ||||||
| 				write(v.Raw, stmt.Table) |  | ||||||
| 			} else { |  | ||||||
| 				write(v.Raw, v.Table) |  | ||||||
| 			} |  | ||||||
| 			writer.WriteByte('.') |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if v.Name == clause.PrimaryKey { |  | ||||||
| 			if stmt.Schema == nil { | 			if stmt.Schema == nil { | ||||||
| 				stmt.DB.AddError(ErrModelValueRequired) | 				stmt.DB.AddError(ErrModelValueRequired) | ||||||
| 			} else if stmt.Schema.PrioritizedPrimaryField != nil { | 			} else if stmt.Schema.PrimaryFields != nil { | ||||||
| 				write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | 				for idx, s := range stmt.Schema.PrimaryFieldDBNames { | ||||||
|  | 					if idx > 0 { | ||||||
|  | 						writer.WriteByte(',') //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 					} | ||||||
|  | 					if v.Table != "" { | ||||||
|  | 						if v.Table == clause.CurrentTable { | ||||||
|  | 							write(v.Raw, stmt.Table) | ||||||
|  | 						} else { | ||||||
|  | 							write(v.Raw, v.Table) | ||||||
|  | 						} | ||||||
|  | 						writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 					} | ||||||
|  | 					write(v.Raw, s) | ||||||
|  | 				} | ||||||
| 			} else if len(stmt.Schema.DBNames) > 0 { | 			} else if len(stmt.Schema.DBNames) > 0 { | ||||||
|  | 				if v.Table != "" { | ||||||
|  | 					if v.Table == clause.CurrentTable { | ||||||
|  | 						write(v.Raw, stmt.Table) | ||||||
|  | 					} else { | ||||||
|  | 						write(v.Raw, v.Table) | ||||||
|  | 					} | ||||||
|  | 					writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 				} | ||||||
| 				write(v.Raw, stmt.Schema.DBNames[0]) | 				write(v.Raw, stmt.Schema.DBNames[0]) | ||||||
| 			} else { | 			} else { | ||||||
| 				stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | 				stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			write(v.Raw, v.Name) | 			if v.Table != "" { | ||||||
|  | 				if v.Table == clause.CurrentTable { | ||||||
|  | 					write(v.Raw, stmt.Table) | ||||||
|  | 				} else { | ||||||
|  | 					write(v.Raw, v.Table) | ||||||
|  | 				} | ||||||
|  | 				writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if v.Name == clause.PrimaryKey { | ||||||
|  | 				switch { | ||||||
|  | 				case stmt.Schema == nil: | ||||||
|  | 					stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck
 | ||||||
|  | 				case stmt.Schema.PrioritizedPrimaryField != nil: | ||||||
|  | 					write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||||
|  | 				case len(stmt.Schema.DBNames) > 0: | ||||||
|  | 					write(v.Raw, stmt.Schema.DBNames[0]) | ||||||
|  | 				default: | ||||||
|  | 					stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck
 | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				write(v.Raw, v.Name) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if v.Alias != "" { | 		if v.Alias != "" { | ||||||
|  | |||||||
| @ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestFindInBatchesCompositeKey(t *testing.T) { | ||||||
|  | 	coupons := []Coupon{ | ||||||
|  | 		{AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||||
|  | 			{ProductId: "1", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "2", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "3", Desc: "find_in_batches"}, | ||||||
|  | 		}}, | ||||||
|  | 		{AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||||
|  | 			{ProductId: "1", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "2", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "3", Desc: "find_in_batches"}, | ||||||
|  | 		}}, | ||||||
|  | 		{AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||||
|  | 			{ProductId: "1", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "2", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "3", Desc: "find_in_batches"}, | ||||||
|  | 		}}, | ||||||
|  | 		{AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||||
|  | 			{ProductId: "1", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "2", Desc: "find_in_batches"}, | ||||||
|  | 			{ProductId: "3", Desc: "find_in_batches"}, | ||||||
|  | 		}}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Create(&coupons) | ||||||
|  | 
 | ||||||
|  | 	var ( | ||||||
|  | 		results   []CouponProduct | ||||||
|  | 		lastBatch int | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { | ||||||
|  | 		lastBatch = batch | ||||||
|  | 
 | ||||||
|  | 		if tx.RowsAffected != 2 { | ||||||
|  | 			t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(results) != 2 { | ||||||
|  | 			t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for idx := range results { | ||||||
|  | 			results[idx].Desc = results[idx].Desc + "_new" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := tx.Save(results).Error; err != nil { | ||||||
|  | 			t.Fatalf("failed to save coupon_product, got error %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	}); result.Error != nil || result.RowsAffected != 12 { | ||||||
|  | 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if lastBatch != 6 { | ||||||
|  | 		t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var count int64 | ||||||
|  | 	DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count) | ||||||
|  | 	if count != 12 { | ||||||
|  | 		t.Errorf("incorrect count after update, expects: %v, got %v", 12, count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{}) | ||||||
|  | 	DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestFillSmallerStruct(t *testing.T) { | func TestFillSmallerStruct(t *testing.T) { | ||||||
| 	user := User{Name: "SmallerUser", Age: 100} | 	user := User{Name: "SmallerUser", Age: 100} | ||||||
| 	DB.Save(&user) | 	DB.Save(&user) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Keith Martin
						Keith Martin