Add LimitPerRecord for generic version Preload
This commit is contained in:
		
							parent
							
								
									91eb9477f2
								
							
						
					
					
						commit
						6307f69f18
					
				| @ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) | ||||
| 
 | ||||
| 	if len(values) != 0 { | ||||
| 		tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) | ||||
| 
 | ||||
| 		for _, cond := range conds { | ||||
| 			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { | ||||
| 				tx = fc(tx) | ||||
| @ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { | ||||
| 		if len(inlineConds) > 0 { | ||||
| 			tx = tx.Where(inlineConds[0], inlineConds[1:]...) | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										103
									
								
								generics.go
									
									
									
									
									
								
							
							
						
						
									
										103
									
								
								generics.go
									
									
									
									
									
								
							| @ -77,7 +77,7 @@ type PreloadBuilder interface { | ||||
| 	Limit(offset int) PreloadBuilder | ||||
| 	Offset(offset int) PreloadBuilder | ||||
| 	Order(value interface{}) PreloadBuilder | ||||
| 	Scopes(scopes ...func(db *Statement)) PreloadBuilder | ||||
| 	LimitPerRecord(num int) PreloadBuilder | ||||
| } | ||||
| 
 | ||||
| type op func(*DB) *DB | ||||
| @ -214,76 +214,78 @@ type joinBuilder struct { | ||||
| 	db *DB | ||||
| } | ||||
| 
 | ||||
| func (q joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { | ||||
| func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { | ||||
| func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { | ||||
| func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q joinBuilder) Select(columns ...string) JoinBuilder { | ||||
| func (q *joinBuilder) Select(columns ...string) JoinBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q joinBuilder) Omit(columns ...string) JoinBuilder { | ||||
| func (q *joinBuilder) Omit(columns ...string) JoinBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| type preloadBuilder struct { | ||||
| 	db *DB | ||||
| 	limitPerRecord int | ||||
| 	db             *DB | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Select(columns ...string) PreloadBuilder { | ||||
| func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Omit(columns ...string) PreloadBuilder { | ||||
| func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q preloadBuilder) Limit(limit int) PreloadBuilder { | ||||
| func (q *preloadBuilder) Limit(limit int) PreloadBuilder { | ||||
| 	q.db.Limit(limit) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Offset(offset int) PreloadBuilder { | ||||
| 
 | ||||
| func (q *preloadBuilder) Offset(offset int) PreloadBuilder { | ||||
| 	q.db.Offset(offset) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Order(value interface{}) PreloadBuilder { | ||||
| 
 | ||||
| func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { | ||||
| 	q.db.Order(value) | ||||
| 	return q | ||||
| } | ||||
| func (q preloadBuilder) Scopes(scopes ...func(db *Statement)) PreloadBuilder { | ||||
| 	for _, fc := range scopes { | ||||
| 		fc(q.db.Statement) | ||||
| 	} | ||||
| 
 | ||||
| func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { | ||||
| 	q.limitPerRecord = num | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| @ -295,7 +297,7 @@ func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable | ||||
| 
 | ||||
| 		q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} | ||||
| 		if on != nil { | ||||
| 			if err := on(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { | ||||
| 			if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| @ -390,10 +392,69 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err | ||||
| 		return db.Preload(association, func(tx *DB) *DB { | ||||
| 			q := preloadBuilder{db: tx.getInstance()} | ||||
| 			if query != nil { | ||||
| 				if err := query(q); err != nil { | ||||
| 				if err := query(&q); err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			relation, ok := db.Statement.Schema.Relationships.Relations[association] | ||||
| 			if !ok { | ||||
| 				db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 			} | ||||
| 
 | ||||
| 			if q.limitPerRecord > 0 { | ||||
| 				if relation.JoinTable != nil { | ||||
| 					err := fmt.Errorf("many2many relation %s don't support LimitPerRecord", association) | ||||
| 					tx.AddError(err) | ||||
| 					return tx | ||||
| 				} | ||||
| 
 | ||||
| 				refColumns := []clause.Column{} | ||||
| 				for _, rel := range relation.References { | ||||
| 					if rel.OwnPrimaryKey { | ||||
| 						refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if len(refColumns) != 0 { | ||||
| 					selects := q.db.Statement.Selects | ||||
| 					selectExpr := clause.CommaExpression{} | ||||
| 					if len(selects) == 0 { | ||||
| 						selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} | ||||
| 					} else { | ||||
| 						for _, column := range selects { | ||||
| 							selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					partitionBy := clause.CommaExpression{} | ||||
| 					for _, column := range refColumns { | ||||
| 						partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) | ||||
| 					} | ||||
| 
 | ||||
| 					rnnColumn := clause.Column{Name: "gorm_preload_rnn"} | ||||
| 					sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" | ||||
| 					vars := []interface{}{partitionBy} | ||||
| 					if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { | ||||
| 						vars = append(vars, orderBy) | ||||
| 					} else { | ||||
| 						vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ | ||||
| 							Columns: []clause.OrderByColumn{ | ||||
| 								{Column: clause.PrimaryColumn, Desc: false}, | ||||
| 							}, | ||||
| 						}}) | ||||
| 					} | ||||
| 					vars = append(vars, rnnColumn) | ||||
| 
 | ||||
| 					selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) | ||||
| 
 | ||||
| 					q.db.Clauses(clause.Select{ | ||||
| 						Expression: selectExpr, | ||||
| 					}) | ||||
| 
 | ||||
| 					return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) | ||||
| 				} | ||||
| 			} | ||||
| 			return q.db | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| @ -209,6 +209,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 			} | ||||
| 		case interface{ getInstance() *DB }: | ||||
| 			cv := v.getInstance() | ||||
| 
 | ||||
| 			subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() | ||||
| 			if cv.Statement.SQL.Len() > 0 { | ||||
| 				var ( | ||||
|  | ||||
| @ -277,7 +277,7 @@ func TestGenericsScopes(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| func TestGenericsJoins(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| @ -374,20 +374,32 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("Joins should got error, but got nil") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| 	// Preload
 | ||||
| 	result3, err := db.Preload("Company", nil).Where("name = ?", u.Name).First(ctx) | ||||
| func TestGenericsPreloads(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7}) | ||||
| 	u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5}) | ||||
| 	u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3}) | ||||
| 	names := []string{u.Name, u2.Name, u3.Name} | ||||
| 
 | ||||
| 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||
| 
 | ||||
| 	result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| 	if result3.Name != u.Name || result3.Company.Name != u.Company.Name { | ||||
| 
 | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) { | ||||
| 		t.Fatalf("Preload expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Where("name = ?", u.Company.Name) | ||||
| 		return nil | ||||
| 	}).Find(ctx) | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| @ -403,10 +415,80 @@ func TestGenericsJoinsAndPreload(t *testing.T) { | ||||
| 
 | ||||
| 	_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||
| 		return errors.New("preload error") | ||||
| 	}).Find(ctx) | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("Preload should failed, but got nil") | ||||
| 	} | ||||
| 
 | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Dialector.Name() == "sqlserver" { | ||||
| 		// sqlserver doesn't support order by in subquery
 | ||||
| 		return | ||||
| 	} | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name desc").LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name < result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name").LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Preload("Friends", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name") | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 			if len(result.Friends) != len(u.Friends) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 || len(result.Friends) == 0 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsDistinct(t *testing.T) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu