Merge 904b1289a4e883b6a43dc667c49ebd0178b57a7f into 4e34a6d21b63e9a9b701a70be9759e5539bf26e9
This commit is contained in:
		
						commit
						56bb02d628
					
				| @ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { | ||||
| 
 | ||||
| const ( | ||||
| 	PrimaryKey   string = "~~~py~~~" // primary key
 | ||||
| 	PrimaryKeys  string = "~~~ps~~~" // primary keys
 | ||||
| 	CurrentTable string = "~~~ct~~~" // current table
 | ||||
| 	Associations string = "~~~as~~~" // associations
 | ||||
| ) | ||||
|  | ||||
| @ -119,7 +119,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| // First finds the first record ordered by primary key, matching given conditions conds
 | ||||
| func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| 		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { | ||||
| @ -147,7 +147,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| // Last finds the last record ordered by primary key, matching given conditions conds
 | ||||
| func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 		Desc:   true, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| @ -174,9 +174,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 
 | ||||
| // FindInBatches finds all records in batches of batchSize
 | ||||
| func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { | ||||
| 	// Use PrimaryKeys to handle composite primary key situations
 | ||||
| 	var ( | ||||
| 		tx = db.Order(clause.OrderByColumn{ | ||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 		}).Session(&Session{}) | ||||
| 		queryDB      = tx | ||||
| 		rowsAffected int64 | ||||
| @ -200,6 +201,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| find: | ||||
| 	for { | ||||
| 		result := queryDB.Limit(batchSize).Find(dest) | ||||
| 		rowsAffected += result.RowsAffected | ||||
| @ -228,17 +230,76 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | ||||
| 
 | ||||
| 		// Optimize for-break
 | ||||
| 		resultsValue := reflect.Indirect(reflect.ValueOf(dest)) | ||||
| 		if result.Statement.Schema.PrioritizedPrimaryField == nil { | ||||
| 		if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 { | ||||
| 			tx.AddError(ErrPrimaryKeyRequired) | ||||
| 			break | ||||
| 		} | ||||
| 
 | ||||
| 		primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 		if zero { | ||||
| 			tx.AddError(ErrPrimaryKeyRequired) | ||||
| 			break | ||||
| 		// The following will build a where clause like this:
 | ||||
| 		// struct {
 | ||||
| 		// col1    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// col2    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// col3    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// }
 | ||||
| 		// last row returned was col1 = 2, col2 = 3, col3 = 5
 | ||||
| 		// where clause will be generated as follows
 | ||||
| 		// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
 | ||||
| 		// Detect composite primary keys
 | ||||
| 		if result.Statement.Schema.PrimaryFields != nil { | ||||
| 			pkCount := len(result.Statement.Schema.PrimaryFields) | ||||
| 
 | ||||
| 			// Handle composite primary key Where clauses
 | ||||
| 			if pkCount > 1 { | ||||
| 				var f *schema.Field | ||||
| 				var orClauses []clause.Expression | ||||
| 				for i := 0; i < pkCount; i++ { | ||||
| 					var andClauses []clause.Expression | ||||
| 					// Build 1st column GT clause
 | ||||
| 					if i == 0 { | ||||
| 						f = result.Statement.Schema.PrimaryFields[i] | ||||
| 						primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 						if zero { | ||||
| 							tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||
| 							break find | ||||
| 						} | ||||
| 						orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 					} else { | ||||
| 						// Build AND clause and append to OR clauses
 | ||||
| 						for j := 0; j <= i; j++ { | ||||
| 							f = result.Statement.Schema.PrimaryFields[j] | ||||
| 							primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 							if zero { | ||||
| 								tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||
| 								break find | ||||
| 							} | ||||
| 							if j == i { | ||||
| 								// Build current outer column GT clause
 | ||||
| 								andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 							} else { | ||||
| 								// Build all other columns EQ clause
 | ||||
| 								andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 							} | ||||
| 						} | ||||
| 						orClauses = append(orClauses, clause.And(andClauses...)) | ||||
| 					} | ||||
| 				} | ||||
| 				queryDB = tx.Clauses(clause.Or(orClauses...)) | ||||
| 			} else { | ||||
| 				primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 				if zero { | ||||
| 					tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||
| 					break | ||||
| 				} | ||||
| 				queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) | ||||
| 			} | ||||
| 		} else { | ||||
| 			primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 			if zero { | ||||
| 				tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
 | ||||
| 				break | ||||
| 			} | ||||
| 			queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||
| 		} | ||||
| 		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||
| 	} | ||||
| 
 | ||||
| 	tx.RowsAffected = rowsAffected | ||||
| @ -308,7 +369,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { | ||||
| //	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
 | ||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 
 | ||||
| 	if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { | ||||
| @ -348,7 +409,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 
 | ||||
| 	result := queryTx.Find(dest, conds...) | ||||
|  | ||||
							
								
								
									
										61
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								statement.go
									
									
									
									
									
								
							| @ -108,27 +108,62 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { | ||||
| 			write(v.Raw, v.Alias) | ||||
| 		} | ||||
| 	case clause.Column: | ||||
| 		if v.Table != "" { | ||||
| 			if v.Table == clause.CurrentTable { | ||||
| 				write(v.Raw, stmt.Table) | ||||
| 			} else { | ||||
| 				write(v.Raw, v.Table) | ||||
| 			} | ||||
| 			writer.WriteByte('.') | ||||
| 		} | ||||
| 
 | ||||
| 		if v.Name == clause.PrimaryKey { | ||||
| 		// Handle composite primary keys explicitly
 | ||||
| 		if v.Name == clause.PrimaryKeys { | ||||
| 			if stmt.Schema == nil { | ||||
| 				stmt.DB.AddError(ErrModelValueRequired) | ||||
| 			} else if stmt.Schema.PrioritizedPrimaryField != nil { | ||||
| 				write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||
| 			} else if stmt.Schema.PrimaryFields != nil { | ||||
| 				for idx, s := range stmt.Schema.PrimaryFieldDBNames { | ||||
| 					if idx > 0 { | ||||
| 						writer.WriteByte(',') //nolint:typecheck,errcheck,gosec
 | ||||
| 					} | ||||
| 					if v.Table != "" { | ||||
| 						if v.Table == clause.CurrentTable { | ||||
| 							write(v.Raw, stmt.Table) | ||||
| 						} else { | ||||
| 							write(v.Raw, v.Table) | ||||
| 						} | ||||
| 						writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||
| 					} | ||||
| 					write(v.Raw, s) | ||||
| 				} | ||||
| 			} else if len(stmt.Schema.DBNames) > 0 { | ||||
| 				if v.Table != "" { | ||||
| 					if v.Table == clause.CurrentTable { | ||||
| 						write(v.Raw, stmt.Table) | ||||
| 					} else { | ||||
| 						write(v.Raw, v.Table) | ||||
| 					} | ||||
| 					writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||
| 				} | ||||
| 				write(v.Raw, stmt.Schema.DBNames[0]) | ||||
| 			} else { | ||||
| 				stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | ||||
| 			} | ||||
| 		} else { | ||||
| 			write(v.Raw, v.Name) | ||||
| 			if v.Table != "" { | ||||
| 				if v.Table == clause.CurrentTable { | ||||
| 					write(v.Raw, stmt.Table) | ||||
| 				} else { | ||||
| 					write(v.Raw, v.Table) | ||||
| 				} | ||||
| 				writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
 | ||||
| 			} | ||||
| 
 | ||||
| 			if v.Name == clause.PrimaryKey { | ||||
| 				switch { | ||||
| 				case stmt.Schema == nil: | ||||
| 					stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck
 | ||||
| 				case stmt.Schema.PrioritizedPrimaryField != nil: | ||||
| 					write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||
| 				case len(stmt.Schema.DBNames) > 0: | ||||
| 					write(v.Raw, stmt.Schema.DBNames[0]) | ||||
| 				default: | ||||
| 					stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck
 | ||||
| 				} | ||||
| 			} else { | ||||
| 				write(v.Raw, v.Name) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if v.Alias != "" { | ||||
|  | ||||
| @ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFindInBatchesCompositeKey(t *testing.T) { | ||||
| 	coupons := []Coupon{ | ||||
| 		{AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||
| 			{ProductId: "1", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "2", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "3", Desc: "find_in_batches"}, | ||||
| 		}}, | ||||
| 		{AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||
| 			{ProductId: "1", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "2", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "3", Desc: "find_in_batches"}, | ||||
| 		}}, | ||||
| 		{AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||
| 			{ProductId: "1", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "2", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "3", Desc: "find_in_batches"}, | ||||
| 		}}, | ||||
| 		{AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ | ||||
| 			{ProductId: "1", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "2", Desc: "find_in_batches"}, | ||||
| 			{ProductId: "3", Desc: "find_in_batches"}, | ||||
| 		}}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&coupons) | ||||
| 
 | ||||
| 	var ( | ||||
| 		results   []CouponProduct | ||||
| 		lastBatch int | ||||
| 	) | ||||
| 
 | ||||
| 	if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { | ||||
| 		lastBatch = batch | ||||
| 
 | ||||
| 		if tx.RowsAffected != 2 { | ||||
| 			t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) | ||||
| 		} | ||||
| 
 | ||||
| 		if len(results) != 2 { | ||||
| 			t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results)) | ||||
| 		} | ||||
| 
 | ||||
| 		for idx := range results { | ||||
| 			results[idx].Desc = results[idx].Desc + "_new" | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Save(results).Error; err != nil { | ||||
| 			t.Fatalf("failed to save coupon_product, got error %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	}); result.Error != nil || result.RowsAffected != 12 { | ||||
| 		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) | ||||
| 	} | ||||
| 
 | ||||
| 	if lastBatch != 6 { | ||||
| 		t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch) | ||||
| 	} | ||||
| 
 | ||||
| 	var count int64 | ||||
| 	DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count) | ||||
| 	if count != 12 { | ||||
| 		t.Errorf("incorrect count after update, expects: %v, got %v", 12, count) | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{}) | ||||
| 	DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{}) | ||||
| } | ||||
| 
 | ||||
| func TestFillSmallerStruct(t *testing.T) { | ||||
| 	user := User{Name: "SmallerUser", Age: 100} | ||||
| 	DB.Save(&user) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Keith Martin
						Keith Martin