improve generics version Joins support
This commit is contained in:
parent
05925b2fc0
commit
d073805c86
@ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
tableAliasName := join.Alias
|
||||
|
||||
if tableAliasName == "" {
|
||||
tableAliasName = relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
}
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
if join.Expression != nil {
|
||||
return clause.Join{
|
||||
Type: join.JoinType,
|
||||
Expression: join.Expression,
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
|
@ -19,15 +19,15 @@ type JoinTarget struct {
|
||||
}
|
||||
|
||||
func Has(name string) JoinTarget {
|
||||
return JoinTarget{Type: LeftJoin, Association: name}
|
||||
return JoinTarget{Type: InnerJoin, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) Association(name string) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) Subquery(subquery Expression) JoinTarget {
|
||||
return JoinTarget{Type: jt, Subquery: subquery}
|
||||
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
|
||||
}
|
||||
|
||||
func (jt JoinTarget) As(name string) JoinTarget {
|
||||
|
126
generics.go
126
generics.go
@ -3,8 +3,11 @@ package gorm
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type Interface[T any] interface {
|
||||
@ -28,7 +31,7 @@ type ChainInterface[T any] interface {
|
||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Limit(offset int) ChainInterface[T]
|
||||
Offset(offset int) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
|
||||
Select(query string, args ...interface{}) ChainInterface[T]
|
||||
Omit(columns ...string) ChainInterface[T]
|
||||
MapColumns(m map[string]string) ChainInterface[T]
|
||||
@ -38,6 +41,8 @@ type ChainInterface[T any] interface {
|
||||
Order(value interface{}) ChainInterface[T]
|
||||
Preload(query string, args ...interface{}) ChainInterface[T]
|
||||
|
||||
Build(builder clause.Builder)
|
||||
|
||||
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||
@ -55,6 +60,14 @@ type ExecInterface[T any] interface {
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type QueryInterface interface {
|
||||
Select(...string) QueryInterface
|
||||
Omit(...string) QueryInterface
|
||||
Where(query interface{}, args ...interface{}) QueryInterface
|
||||
Not(query interface{}, args ...interface{}) QueryInterface
|
||||
Or(query interface{}, args ...interface{}) QueryInterface
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
|
||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] {
|
||||
// TODO
|
||||
return nil
|
||||
type query struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Select(columns ...string) QueryInterface {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q query) Omit(columns ...string) QueryInterface {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
if jt.Table == "" {
|
||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||
}
|
||||
|
||||
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
||||
if args != nil {
|
||||
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
|
||||
}
|
||||
|
||||
j := join{
|
||||
Name: jt.Association,
|
||||
Alias: jt.Table,
|
||||
Selects: q.db.Statement.Selects,
|
||||
Omits: q.db.Statement.Omits,
|
||||
JoinType: jt.Type,
|
||||
}
|
||||
|
||||
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
|
||||
if jt.Subquery != nil {
|
||||
joinType := j.JoinType
|
||||
if joinType == "" {
|
||||
joinType = clause.LeftJoin
|
||||
}
|
||||
|
||||
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||
|
||||
if j.On != nil {
|
||||
expr.SQL += " ON ?"
|
||||
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||
}
|
||||
|
||||
j.Expression = expr
|
||||
}
|
||||
|
||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err
|
||||
return
|
||||
}
|
||||
|
||||
func (c chainG[T]) Build(builder clause.Builder) {
|
||||
subdb := c.getInstance()
|
||||
subdb.Logger = logger.Discard
|
||||
subdb.DryRun = true
|
||||
|
||||
if stmt, ok := builder.(*Statement); ok {
|
||||
if subdb.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = subdb.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
builder.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
}
|
||||
}
|
||||
|
||||
type execG[T any] struct {
|
||||
g *g[T]
|
||||
}
|
||||
|
7
scan.go
7
scan.go
@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
for _, join := range db.Statement.Joins {
|
||||
if join.Alias == names[0] {
|
||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||
}
|
||||
}
|
||||
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
|
43
statement.go
43
statement.go
@ -50,12 +50,14 @@ type Statement struct {
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
JoinType clause.JoinType
|
||||
Name string
|
||||
Alias string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
Expression clause.Expression
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
curTable := stmt.Table
|
||||
if curTable == "" {
|
||||
curTable = clause.CurrentTable
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
keys := make([]string, 0, len(v))
|
||||
@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
|
||||
for _, key := range keys {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else {
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
}
|
||||
default:
|
||||
@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) {
|
||||
|
||||
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != "" || u.Age != u.Age {
|
||||
} else if u.Name != "" || u.Age != user.Age {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) {
|
||||
|
||||
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||
if len(found) != len(batch) {
|
||||
fmt.Println(found)
|
||||
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||
}
|
||||
|
||||
@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
||||
db.Create(ctx, &u)
|
||||
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
|
||||
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
// LEFT JOIN + WHERE
|
||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
||||
// Inner JOIN + WHERE
|
||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// JOIN
|
||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
||||
return nil
|
||||
// Inner JOIN + WHERE with map
|
||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
return db.Where(map[string]any{"name": u.Company.Name})
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Left JOIN
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
||||
return nil
|
||||
}).First(ctx)
|
||||
// Left JOIN w/o WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Left JOIN + Alias WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Raw Subquery JOIN + WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
},
|
||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Raw subquery join failed: %v", err)
|
||||
}
|
||||
if result.Name != u2.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Preload
|
||||
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user