Complete the design and implementation of generic version Join

This commit is contained in:
Jinzhu 2025-05-09 17:16:10 +08:00
parent 7095605cd0
commit 05925b2fc0
5 changed files with 55 additions and 23 deletions

View File

@ -448,6 +448,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
// Unscoped allows queries to include records marked as deleted, // Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior. // overriding the soft deletion behavior.
// Example: // Example:
//
// var users []User // var users []User
// db.Unscoped().Find(&users) // db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones. // // Retrieves all users, including deleted ones.

View File

@ -11,6 +11,30 @@ const (
RightJoin JoinType = "RIGHT" RightJoin JoinType = "RIGHT"
) )
type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: LeftJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) Subquery(subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from // Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType

View File

@ -28,8 +28,7 @@ type ChainInterface[T any] interface {
Or(query interface{}, args ...interface{}) ChainInterface[T] Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T] Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T]
Joins(query string, args ...interface{}) ChainInterface[T] Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T]
InnerJoins(query string, args ...interface{}) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T] Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T]
@ -186,16 +185,9 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
}) })
} }
func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] {
return c.with(func(db *DB) *DB { // TODO
return db.Joins(query, args...) return nil
})
}
func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.InnerJoins(query, args...)
})
} }
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {

View File

@ -285,7 +285,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
db.Create(ctx, &u) db.Create(ctx, &u)
// LEFT JOIN + WHERE // LEFT JOIN + WHERE
result, err := db.Joins("Company").Where("?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
return db.Where("?.name = ?", joinTable, u.Company.Name)
}).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
} }
@ -293,13 +295,26 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result) t.Fatalf("Joins expected %s, got %+v", u.Name, result)
} }
// INNER JOIN + Inline WHERE // JOIN
result2, err := db.InnerJoins("Company", "?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
return nil
}).First(ctx)
if err != nil { if err != nil {
t.Fatalf("InnerJoins failed: %v", err) t.Fatalf("Joins failed: %v", err)
} }
if result2.Name != u.Name || result2.Company.Name != u.Company.Name { if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Errorf("InnerJoins expected , got %+v", result2) t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
return nil
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
} }
// Preload // Preload

View File

@ -29,8 +29,8 @@ require (
github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/microsoft/go-mssqldb v1.8.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.37.0 // indirect golang.org/x/crypto v0.38.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.25.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )