improve generics version Joins support

This commit is contained in:
Jinzhu 2025-05-20 19:44:57 +08:00
parent 05925b2fc0
commit d073805c86
6 changed files with 213 additions and 39 deletions

View File

@ -146,10 +146,14 @@ func BuildQuerySQL(db *gorm.DB) {
if isRelations { if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name tableAliasName := join.Alias
if tableAliasName == "" {
tableAliasName = relation.Name
if parentTableName != clause.CurrentTable { if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
} }
}
columnStmt := gorm.Statement{ columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema, Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References)) exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References { for idx, ref := range relation.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {

View File

@ -19,15 +19,15 @@ type JoinTarget struct {
} }
func Has(name string) JoinTarget { func Has(name string) JoinTarget {
return JoinTarget{Type: LeftJoin, Association: name} return JoinTarget{Type: InnerJoin, Association: name}
} }
func (jt JoinType) Association(name string) JoinTarget { func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name} return JoinTarget{Type: jt, Association: name}
} }
func (jt JoinType) Subquery(subquery Expression) JoinTarget { func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Subquery: subquery} return JoinTarget{Type: jt, Association: name, Subquery: subquery}
} }
func (jt JoinTarget) As(name string) JoinTarget { func (jt JoinTarget) As(name string) JoinTarget {

View File

@ -3,8 +3,11 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
) )
type Interface[T any] interface { type Interface[T any] interface {
@ -28,7 +31,7 @@ type ChainInterface[T any] interface {
Or(query interface{}, args ...interface{}) ChainInterface[T] Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T] Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T] Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T]
@ -38,6 +41,8 @@ type ChainInterface[T any] interface {
Order(value interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T]
Preload(query string, args ...interface{}) ChainInterface[T] Preload(query string, args ...interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error) Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error)
@ -55,6 +60,14 @@ type ExecInterface[T any] interface {
Rows(ctx context.Context) (*sql.Rows, error) Rows(ctx context.Context) (*sql.Rows, error)
} }
type QueryInterface interface {
Select(...string) QueryInterface
Omit(...string) QueryInterface
Where(query interface{}, args ...interface{}) QueryInterface
Not(query interface{}, args ...interface{}) QueryInterface
Or(query interface{}, args ...interface{}) QueryInterface
}
type op func(*DB) *DB type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] { func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
}) })
} }
func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { type query struct {
// TODO db *DB
return nil }
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Select(columns ...string) QueryInterface {
q.db.Select(columns)
return q
}
func (q query) Omit(columns ...string) QueryInterface {
q.db.Omit(columns...)
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
if args != nil {
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
return db
})
} }
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err
return return
} }
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct { type execG[T any] struct {
g *g[T] g *g[T]
} }

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"strings"
"time" "time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
matchedFieldCount[column] = 1 matchedFieldCount[column] = 1
} }
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
for _, join := range db.Statement.Joins {
if join.Alias == names[0] {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok { if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names) subNameCount := len(names)
// nested relation fields // nested relation fields

View File

@ -51,10 +51,12 @@ type Statement struct {
type join struct { type join struct {
Name string Name string
Alias string
Conds []interface{} Conds []interface{}
On *clause.Where On *clause.Where
Selects []string Selects []string
Omits []string Omits []string
Expression clause.Expression
JoinType clause.JoinType JoinType clause.JoinType
} }
@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) { switch v := arg.(type) {
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
sort.Strings(keys) sort.Strings(keys)
for _, key := range keys { for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) column := clause.Column{Name: key, Table: curTable}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} }
case map[string]interface{}: case map[string]interface{}:
keys := make([]string, 0, len(v)) keys := make([]string, 0, len(v))
@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys { for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok { if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok { } else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else { } else {
// optimize reflect value length // optimize reflect value length
valueLen := reflectValue.Len() valueLen := reflectValue.Len()
@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface() values[i] = reflectValue.Index(i).Interface()
} }
conds = append(conds, clause.IN{Column: key, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})
} }
default: default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]}) conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} }
} }
default: default:
@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
} }
if len(values) > 0 { if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)} return []clause.Expression{clause.And(conds...)}
} }
return nil return nil
} }
} }
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
} }
} }
} }

View File

@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) {
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err) t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != "" || u.Age != u.Age { } else if u.Name != "" || u.Age != user.Age {
t.Errorf("found invalid user, got %v, expect %v", u, user) t.Errorf("found invalid user, got %v, expect %v", u, user)
} }
@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) {
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
if len(found) != len(batch) { if len(found) != len(batch) {
fmt.Println(found)
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
} }
@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
db := gorm.G[User](DB) db := gorm.G[User](DB)
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
db.Create(ctx, &u) u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
// LEFT JOIN + WHERE // Inner JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
return db.Where("?.name = ?", joinTable, u.Company.Name) return db.Where("?.name = ?", joinTable, u.Company.Name)
}).First(ctx) }).First(ctx)
if err != nil { if err != nil {
@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result) t.Fatalf("Joins expected %s, got %+v", u.Name, result)
} }
// JOIN // Inner JOIN + WHERE with map
result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
return nil return db.Where(map[string]any{"name": u.Company.Name})
}).First(ctx) }).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result) t.Fatalf("Joins expected %s, got %+v", u.Name, result)
} }
// Left JOIN // Left JOIN w/o WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
return nil
}).First(ctx)
if err != nil { if err != nil {
t.Fatalf("Joins failed: %v", err) t.Fatalf("Joins failed: %v", err)
} }
@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result) t.Fatalf("Joins expected %s, got %+v", u.Name, result)
} }
// Left JOIN + Alias WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
return db.Where("?.name = ?", joinTable, u.Company.Name)
}).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
return db.Where("?.name = ?", joinTable, u.Company.Name)
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Preload // Preload
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
if err != nil { if err != nil {