improve generics version Joins support

This commit is contained in:
Jinzhu 2025-05-20 19:44:57 +08:00
parent 05925b2fc0
commit d073805c86
6 changed files with 213 additions and 39 deletions

View File

@ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) {
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
tableAliasName := join.Alias
if tableAliasName == "" {
tableAliasName = relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
}
columnStmt := gorm.Statement{
@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {

View File

@ -19,15 +19,15 @@ type JoinTarget struct {
}
func Has(name string) JoinTarget {
return JoinTarget{Type: LeftJoin, Association: name}
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) Subquery(subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Subquery: subquery}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {

View File

@ -3,8 +3,11 @@ package gorm
import (
"context"
"database/sql"
"fmt"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type Interface[T any] interface {
@ -28,7 +31,7 @@ type ChainInterface[T any] interface {
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T]
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
@ -38,6 +41,8 @@ type ChainInterface[T any] interface {
Order(value interface{}) ChainInterface[T]
Preload(query string, args ...interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
@ -55,6 +60,14 @@ type ExecInterface[T any] interface {
Rows(ctx context.Context) (*sql.Rows, error)
}
type QueryInterface interface {
Select(...string) QueryInterface
Omit(...string) QueryInterface
Where(query interface{}, args ...interface{}) QueryInterface
Not(query interface{}, args ...interface{}) QueryInterface
Or(query interface{}, args ...interface{}) QueryInterface
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
})
}
func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] {
// TODO
return nil
type query struct {
db *DB
}
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
q.db.Where(query, args...)
return q
}
func (q query) Select(columns ...string) QueryInterface {
q.db.Select(columns)
return q
}
func (q query) Omit(columns ...string) QueryInterface {
q.db.Omit(columns...)
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
if args != nil {
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"time"
"gorm.io/gorm/schema"
@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
for _, join := range db.Statement.Joins {
if join.Alias == names[0] {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields

View File

@ -50,12 +50,14 @@ type Statement struct {
}
type join struct {
Name string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
JoinType clause.JoinType
Name string
Alias string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
}
// StatementModifier statement modifier interface
@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
arg, _ = valuer.Value()
}
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
column := clause.Column{Name: key, Table: curTable}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
case map[string]interface{}:
keys := make([]string, 0, len(v))
@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else {
// optimize reflect value length
valueLen := reflectValue.Len()
@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
conds = append(conds, clause.IN{Column: column, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
}
default:
@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
}
}
}

View File

@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) {
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != "" || u.Age != u.Age {
} else if u.Name != "" || u.Age != user.Age {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) {
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
if len(found) != len(batch) {
fmt.Println(found)
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
}
@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
db := gorm.G[User](DB)
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
db.Create(ctx, &u)
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
// LEFT JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
// Inner JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
return db.Where("?.name = ?", joinTable, u.Company.Name)
}).First(ctx)
if err != nil {
@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// JOIN
result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
return nil
// Inner JOIN + WHERE with map
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
return db.Where(map[string]any{"name": u.Company.Name})
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
return nil
}).First(ctx)
// Left JOIN w/o WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN + Alias WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
return db.Where("?.name = ?", joinTable, u.Company.Name)
}).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
return db.Where("?.name = ?", joinTable, u.Company.Name)
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Preload
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
if err != nil {