Add handling for Composite Primary Keys to First, Last, FindInBatches, FirstOrInit and FirstOrCreate.
This commit is contained in:
		
							parent
							
								
									e5b867e785
								
							
						
					
					
						commit
						881bd7747b
					
				| @ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { | ||||
| 
 | ||||
| const ( | ||||
| 	PrimaryKey   string = "~~~py~~~" // primary key
 | ||||
| 	PrimaryKeys  string = "~~~ps~~~" // primary keys
 | ||||
| 	CurrentTable string = "~~~ct~~~" // current table
 | ||||
| 	Associations string = "~~~as~~~" // associations
 | ||||
| ) | ||||
|  | ||||
| @ -118,7 +118,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| // First finds the first record ordered by primary key, matching given conditions conds
 | ||||
| func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| 		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { | ||||
| @ -146,7 +146,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| // Last finds the last record ordered by primary key, matching given conditions conds
 | ||||
| func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 		Desc:   true, | ||||
| 	}) | ||||
| 	if len(conds) > 0 { | ||||
| @ -173,9 +173,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 
 | ||||
| // FindInBatches finds all records in batches of batchSize
 | ||||
| func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { | ||||
| 	// Use PrimaryKeys to handle composite primary key situations
 | ||||
| 	var ( | ||||
| 		tx = db.Order(clause.OrderByColumn{ | ||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 		}).Session(&Session{}) | ||||
| 		queryDB      = tx | ||||
| 		rowsAffected int64 | ||||
| @ -232,12 +233,71 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat | ||||
| 			break | ||||
| 		} | ||||
| 
 | ||||
| 		primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 		if zero { | ||||
| 			tx.AddError(ErrPrimaryKeyRequired) | ||||
| 			break | ||||
| 		// The following will build a where clause like this:
 | ||||
| 		// struct {
 | ||||
| 		// col1    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// col2    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// col3    uint `gorm:"primaryKey;autoIncrement:false"`
 | ||||
| 		// }
 | ||||
| 		// last row returned was col1 = 2, col2 = 3, col3 = 5
 | ||||
| 		// where clause will be generated as follows
 | ||||
| 		// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
 | ||||
| 		// Detect composite primary keys
 | ||||
| 		if result.Statement.Schema.PrimaryFields != nil { | ||||
| 			pkCount := len(result.Statement.Schema.PrimaryFields) | ||||
| 
 | ||||
| 			// Handle composite primary key Where clauses
 | ||||
| 			if pkCount > 1 { | ||||
| 				var f *schema.Field | ||||
| 				var orClauses []clause.Expression | ||||
| 				for i := 0; i < pkCount; i++ { | ||||
| 					var andClauses []clause.Expression | ||||
| 					// Build 1st column GT clause
 | ||||
| 					if i == 0 { | ||||
| 						f = result.Statement.Schema.PrimaryFields[i] | ||||
| 						primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 						if zero { | ||||
| 							tx.AddError(ErrPrimaryKeyRequired) | ||||
| 							break | ||||
| 						} | ||||
| 						orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 					} else { | ||||
| 						// Build AND clause and append to OR clauses
 | ||||
| 						for j := 0; j <= i; j++ { | ||||
| 							f = result.Statement.Schema.PrimaryFields[j] | ||||
| 							primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 							if zero { | ||||
| 								tx.AddError(ErrPrimaryKeyRequired) | ||||
| 								break | ||||
| 							} | ||||
| 							if j == i { | ||||
| 								// Build current outer column GT clause
 | ||||
| 								andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 							} else { | ||||
| 								// Build all other columns EQ clause
 | ||||
| 								andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) | ||||
| 							} | ||||
| 						} | ||||
| 						orClauses = append(orClauses, clause.And(andClauses...)) | ||||
| 					} | ||||
| 				} | ||||
| 				queryDB = tx.Clauses(clause.Or(orClauses...)) | ||||
| 			} else { | ||||
| 				primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 				if zero { | ||||
| 					tx.AddError(ErrPrimaryKeyRequired) | ||||
| 					break | ||||
| 				} | ||||
| 				queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) | ||||
| 			} | ||||
| 		} else { | ||||
| 			primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) | ||||
| 			if zero { | ||||
| 				tx.AddError(ErrPrimaryKeyRequired) | ||||
| 				break | ||||
| 			} | ||||
| 			queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||
| 		} | ||||
| 		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) | ||||
| 	} | ||||
| 
 | ||||
| 	tx.RowsAffected = rowsAffected | ||||
| @ -307,7 +367,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { | ||||
| //	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
 | ||||
| func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	queryTx := db.Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 
 | ||||
| 	if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { | ||||
| @ -347,7 +407,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, | ||||
| 		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, | ||||
| 	}) | ||||
| 
 | ||||
| 	result := queryTx.Find(dest, conds...) | ||||
|  | ||||
							
								
								
									
										60
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								statement.go
									
									
									
									
									
								
							| @ -105,27 +105,61 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { | ||||
| 			write(v.Raw, v.Alias) | ||||
| 		} | ||||
| 	case clause.Column: | ||||
| 		if v.Table != "" { | ||||
| 			if v.Table == clause.CurrentTable { | ||||
| 				write(v.Raw, stmt.Table) | ||||
| 			} else { | ||||
| 				write(v.Raw, v.Table) | ||||
| 			} | ||||
| 			writer.WriteByte('.') | ||||
| 		} | ||||
| 
 | ||||
| 		if v.Name == clause.PrimaryKey { | ||||
| 		// Handle composite primary keys explicitly
 | ||||
| 		if v.Name == clause.PrimaryKeys { | ||||
| 			if stmt.Schema == nil { | ||||
| 				stmt.DB.AddError(ErrModelValueRequired) | ||||
| 			} else if stmt.Schema.PrioritizedPrimaryField != nil { | ||||
| 				write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||
| 			} else if stmt.Schema.PrimaryFields != nil { | ||||
| 				for idx, s := range stmt.Schema.PrimaryFieldDBNames { | ||||
| 					if idx > 0 { | ||||
| 						writer.WriteByte(',') | ||||
| 					} | ||||
| 					if v.Table != "" { | ||||
| 						if v.Table == clause.CurrentTable { | ||||
| 							write(v.Raw, stmt.Table) | ||||
| 						} else { | ||||
| 							write(v.Raw, v.Table) | ||||
| 						} | ||||
| 						writer.WriteByte('.') | ||||
| 					} | ||||
| 					write(v.Raw, s) | ||||
| 				} | ||||
| 			} else if len(stmt.Schema.DBNames) > 0 { | ||||
| 				if v.Table != "" { | ||||
| 					if v.Table == clause.CurrentTable { | ||||
| 						write(v.Raw, stmt.Table) | ||||
| 					} else { | ||||
| 						write(v.Raw, v.Table) | ||||
| 					} | ||||
| 					writer.WriteByte('.') | ||||
| 				} | ||||
| 				write(v.Raw, stmt.Schema.DBNames[0]) | ||||
| 			} else { | ||||
| 				stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | ||||
| 			} | ||||
| 		} else { | ||||
| 			write(v.Raw, v.Name) | ||||
| 			if v.Table != "" { | ||||
| 				if v.Table == clause.CurrentTable { | ||||
| 					write(v.Raw, stmt.Table) | ||||
| 				} else { | ||||
| 					write(v.Raw, v.Table) | ||||
| 				} | ||||
| 				writer.WriteByte('.') | ||||
| 			} | ||||
| 
 | ||||
| 			if v.Name == clause.PrimaryKey { | ||||
| 				if stmt.Schema == nil { | ||||
| 					stmt.DB.AddError(ErrModelValueRequired) | ||||
| 				} else if stmt.Schema.PrioritizedPrimaryField != nil { | ||||
| 					write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) | ||||
| 				} else if len(stmt.Schema.DBNames) > 0 { | ||||
| 					write(v.Raw, stmt.Schema.DBNames[0]) | ||||
| 				} else { | ||||
| 					stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | ||||
| 				} | ||||
| 			} else { | ||||
| 				write(v.Raw, v.Name) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if v.Alias != "" { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Keith Martin
						Keith Martin