From 05925b2fc08f9be5a3ca6c384454c7a43febe920 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 9 May 2025 17:16:10 +0800 Subject: [PATCH] Complete the design and implementation of generic version Join --- chainable_api.go | 7 ++++--- clause/joins.go | 24 ++++++++++++++++++++++++ generics.go | 16 ++++------------ tests/generics_test.go | 27 +++++++++++++++++++++------ tests/go.mod | 4 ++-- 5 files changed, 55 insertions(+), 23 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 8953413d..8a6aea34 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: -// var users []User -// db.Unscoped().Find(&users) -// // Retrieves all users, including deleted ones. +// +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/clause/joins.go b/clause/joins.go index b0f0359c..ddb2a5a9 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -11,6 +11,30 @@ const ( RightJoin JoinType = "RIGHT" ) +type JoinTarget struct { + Type JoinType + Association string + Subquery Expression + Table string +} + +func Has(name string) JoinTarget { + return JoinTarget{Type: LeftJoin, Association: name} +} + +func (jt JoinType) Association(name string) JoinTarget { + return JoinTarget{Type: jt, Association: name} +} + +func (jt JoinType) Subquery(subquery Expression) JoinTarget { + return JoinTarget{Type: jt, Subquery: subquery} +} + +func (jt JoinTarget) As(name string) JoinTarget { + jt.Table = name + return jt +} + // Join clause for from type Join struct { Type JoinType diff --git a/generics.go b/generics.go index 9dd1af7d..95d98100 100644 --- a/generics.go +++ b/generics.go @@ -28,8 +28,7 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query string, args ...interface{}) ChainInterface[T] - InnerJoins(query string, args ...interface{}) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -186,16 +185,9 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -func (c chainG[T]) Joins(query string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Joins(query, args...) - }) -} - -func (c chainG[T]) InnerJoins(query string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.InnerJoins(query, args...) - }) +func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { + // TODO + return nil } func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { diff --git a/tests/generics_test.go b/tests/generics_test.go index c467f9ff..2d69d22e 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -285,7 +285,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db.Create(ctx, &u) // LEFT JOIN + WHERE - result, err := db.Joins("Company").Where("?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) + result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return db.Where("?.name = ?", joinTable, u.Company.Name) + }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -293,13 +295,26 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // INNER JOIN + Inline WHERE - result2, err := db.InnerJoins("Company", "?.name = ?", clause.JoinTable("Company"), u.Company.Name).First(ctx) + // JOIN + result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return nil + }).First(ctx) if err != nil { - t.Fatalf("InnerJoins failed: %v", err) + t.Fatalf("Joins failed: %v", err) } - if result2.Name != u.Name || result2.Company.Name != u.Company.Name { - t.Errorf("InnerJoins expected , got %+v", result2) + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // Left JOIN + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + return nil + }).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Preload diff --git a/tests/go.mod b/tests/go.mod index 2d647b08..7f4d84f7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -29,8 +29,8 @@ require ( github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.37.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/text v0.25.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect )