improve generics version Joins support
This commit is contained in:
parent
05925b2fc0
commit
d073805c86
@ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
if isRelations {
|
if isRelations {
|
||||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||||
tableAliasName := relation.Name
|
tableAliasName := join.Alias
|
||||||
if parentTableName != clause.CurrentTable {
|
|
||||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
if tableAliasName == "" {
|
||||||
|
tableAliasName = relation.Name
|
||||||
|
if parentTableName != clause.CurrentTable {
|
||||||
|
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
columnStmt := gorm.Statement{
|
columnStmt := gorm.Statement{
|
||||||
@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if join.Expression != nil {
|
||||||
|
return clause.Join{
|
||||||
|
Type: join.JoinType,
|
||||||
|
Expression: join.Expression,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
exprs := make([]clause.Expression, len(relation.References))
|
exprs := make([]clause.Expression, len(relation.References))
|
||||||
for idx, ref := range relation.References {
|
for idx, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
|
@ -19,15 +19,15 @@ type JoinTarget struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Has(name string) JoinTarget {
|
func Has(name string) JoinTarget {
|
||||||
return JoinTarget{Type: LeftJoin, Association: name}
|
return JoinTarget{Type: InnerJoin, Association: name}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinType) Association(name string) JoinTarget {
|
func (jt JoinType) Association(name string) JoinTarget {
|
||||||
return JoinTarget{Type: jt, Association: name}
|
return JoinTarget{Type: jt, Association: name}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinType) Subquery(subquery Expression) JoinTarget {
|
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
|
||||||
return JoinTarget{Type: jt, Subquery: subquery}
|
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (jt JoinTarget) As(name string) JoinTarget {
|
func (jt JoinTarget) As(name string) JoinTarget {
|
||||||
|
126
generics.go
126
generics.go
@ -3,8 +3,11 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Interface[T any] interface {
|
type Interface[T any] interface {
|
||||||
@ -28,7 +31,7 @@ type ChainInterface[T any] interface {
|
|||||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
Limit(offset int) ChainInterface[T]
|
Limit(offset int) ChainInterface[T]
|
||||||
Offset(offset int) ChainInterface[T]
|
Offset(offset int) ChainInterface[T]
|
||||||
Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T]
|
Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T]
|
||||||
Select(query string, args ...interface{}) ChainInterface[T]
|
Select(query string, args ...interface{}) ChainInterface[T]
|
||||||
Omit(columns ...string) ChainInterface[T]
|
Omit(columns ...string) ChainInterface[T]
|
||||||
MapColumns(m map[string]string) ChainInterface[T]
|
MapColumns(m map[string]string) ChainInterface[T]
|
||||||
@ -38,6 +41,8 @@ type ChainInterface[T any] interface {
|
|||||||
Order(value interface{}) ChainInterface[T]
|
Order(value interface{}) ChainInterface[T]
|
||||||
Preload(query string, args ...interface{}) ChainInterface[T]
|
Preload(query string, args ...interface{}) ChainInterface[T]
|
||||||
|
|
||||||
|
Build(builder clause.Builder)
|
||||||
|
|
||||||
Delete(ctx context.Context) (rowsAffected int, err error)
|
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||||
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||||
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||||
@ -55,6 +60,14 @@ type ExecInterface[T any] interface {
|
|||||||
Rows(ctx context.Context) (*sql.Rows, error)
|
Rows(ctx context.Context) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryInterface interface {
|
||||||
|
Select(...string) QueryInterface
|
||||||
|
Omit(...string) QueryInterface
|
||||||
|
Where(query interface{}, args ...interface{}) QueryInterface
|
||||||
|
Not(query interface{}, args ...interface{}) QueryInterface
|
||||||
|
Or(query interface{}, args ...interface{}) QueryInterface
|
||||||
|
}
|
||||||
|
|
||||||
type op func(*DB) *DB
|
type op func(*DB) *DB
|
||||||
|
|
||||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||||
@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] {
|
type query struct {
|
||||||
// TODO
|
db *DB
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
func (q query) Where(query interface{}, args ...interface{}) QueryInterface {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q query) Or(query interface{}, args ...interface{}) QueryInterface {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q query) Not(query interface{}, args ...interface{}) QueryInterface {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q query) Select(columns ...string) QueryInterface {
|
||||||
|
q.db.Select(columns)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q query) Omit(columns ...string) QueryInterface {
|
||||||
|
q.db.Omit(columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
if jt.Table == "" {
|
||||||
|
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||||
|
}
|
||||||
|
|
||||||
|
q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)}
|
||||||
|
if args != nil {
|
||||||
|
args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable})
|
||||||
|
}
|
||||||
|
|
||||||
|
j := join{
|
||||||
|
Name: jt.Association,
|
||||||
|
Alias: jt.Table,
|
||||||
|
Selects: q.db.Statement.Selects,
|
||||||
|
Omits: q.db.Statement.Omits,
|
||||||
|
JoinType: jt.Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||||
|
j.On = &where
|
||||||
|
}
|
||||||
|
|
||||||
|
if jt.Subquery != nil {
|
||||||
|
joinType := j.JoinType
|
||||||
|
if joinType == "" {
|
||||||
|
joinType = clause.LeftJoin
|
||||||
|
}
|
||||||
|
|
||||||
|
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||||
|
|
||||||
|
if j.On != nil {
|
||||||
|
expr.SQL += " ON ?"
|
||||||
|
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||||
|
}
|
||||||
|
|
||||||
|
j.Expression = expr
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||||
|
return db
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||||
@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Build(builder clause.Builder) {
|
||||||
|
subdb := c.getInstance()
|
||||||
|
subdb.Logger = logger.Discard
|
||||||
|
subdb.DryRun = true
|
||||||
|
|
||||||
|
if stmt, ok := builder.(*Statement); ok {
|
||||||
|
if subdb.Statement.SQL.Len() > 0 {
|
||||||
|
var (
|
||||||
|
vars = subdb.Statement.Vars
|
||||||
|
sql = subdb.Statement.SQL.String()
|
||||||
|
)
|
||||||
|
|
||||||
|
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||||
|
for _, vv := range vars {
|
||||||
|
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||||
|
bindvar := strings.Builder{}
|
||||||
|
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||||
|
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
subdb.Statement.SQL.Reset()
|
||||||
|
subdb.Statement.Vars = stmt.Vars
|
||||||
|
if strings.Contains(sql, "@") {
|
||||||
|
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
} else {
|
||||||
|
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||||
|
subdb.callbacks.Query().Execute(subdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(subdb.Statement.SQL.String())
|
||||||
|
stmt.Vars = subdb.Statement.Vars
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type execG[T any] struct {
|
type execG[T any] struct {
|
||||||
g *g[T]
|
g *g[T]
|
||||||
}
|
}
|
||||||
|
7
scan.go
7
scan.go
@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||||||
matchedFieldCount[column] = 1
|
matchedFieldCount[column] = 1
|
||||||
}
|
}
|
||||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||||
|
for _, join := range db.Statement.Joins {
|
||||||
|
if join.Alias == names[0] {
|
||||||
|
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
subNameCount := len(names)
|
subNameCount := len(names)
|
||||||
// nested relation fields
|
// nested relation fields
|
||||||
|
43
statement.go
43
statement.go
@ -50,12 +50,14 @@ type Statement struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type join struct {
|
type join struct {
|
||||||
Name string
|
Name string
|
||||||
Conds []interface{}
|
Alias string
|
||||||
On *clause.Where
|
Conds []interface{}
|
||||||
Selects []string
|
On *clause.Where
|
||||||
Omits []string
|
Selects []string
|
||||||
JoinType clause.JoinType
|
Omits []string
|
||||||
|
Expression clause.Expression
|
||||||
|
JoinType clause.JoinType
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatementModifier statement modifier interface
|
// StatementModifier statement modifier interface
|
||||||
@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
curTable := stmt.Table
|
||||||
|
if curTable == "" {
|
||||||
|
curTable = clause.CurrentTable
|
||||||
|
}
|
||||||
|
|
||||||
switch v := arg.(type) {
|
switch v := arg.(type) {
|
||||||
case clause.Expression:
|
case clause.Expression:
|
||||||
conds = append(conds, v)
|
conds = append(conds, v)
|
||||||
@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
column := clause.Column{Name: key, Table: curTable}
|
||||||
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
}
|
}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
keys := make([]string, 0, len(v))
|
keys := make([]string, 0, len(v))
|
||||||
@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||||
|
column := clause.Column{Name: key, Table: curTable}
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if _, ok := v[key].(driver.Valuer); ok {
|
if _, ok := v[key].(driver.Valuer); ok {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
} else if _, ok := v[key].(Valuer); ok {
|
} else if _, ok := v[key].(Valuer); ok {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
} else {
|
} else {
|
||||||
// optimize reflect value length
|
// optimize reflect value length
|
||||||
valueLen := reflectValue.Len()
|
valueLen := reflectValue.Len()
|
||||||
@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
values[i] = reflectValue.Index(i).Interface()
|
values[i] = reflectValue.Index(i).Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||||
return []clause.Expression{clause.And(conds...)}
|
return []clause.Expression{clause.And(conds...)}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) {
|
|||||||
|
|
||||||
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||||
t.Fatalf("failed to find user, got error: %v", err)
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
} else if u.Name != "" || u.Age != u.Age {
|
} else if u.Name != "" || u.Age != user.Age {
|
||||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) {
|
|||||||
|
|
||||||
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||||
if len(found) != len(batch) {
|
if len(found) != len(batch) {
|
||||||
fmt.Println(found)
|
|
||||||
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
db := gorm.G[User](DB)
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
||||||
db.Create(ctx, &u)
|
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
|
||||||
|
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
|
||||||
|
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||||
|
|
||||||
// LEFT JOIN + WHERE
|
// Inner JOIN + WHERE
|
||||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||||
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
}).First(ctx)
|
}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JOIN
|
// Inner JOIN + WHERE with map
|
||||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||||
return nil
|
return db.Where(map[string]any{"name": u.Company.Name})
|
||||||
}).First(ctx)
|
}).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Joins failed: %v", err)
|
||||||
@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Left JOIN
|
// Left JOIN w/o WHERE
|
||||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] {
|
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||||
return nil
|
|
||||||
}).First(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Joins failed: %v", err)
|
t.Fatalf("Joins failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) {
|
|||||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Left JOIN + Alias WHERE
|
||||||
|
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||||
|
if joinTable.Name != "t" {
|
||||||
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
|
}
|
||||||
|
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Joins failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw Subquery JOIN + WHERE
|
||||||
|
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||||
|
func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface {
|
||||||
|
if joinTable.Name != "t" {
|
||||||
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
|
}
|
||||||
|
return db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
},
|
||||||
|
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Raw subquery join failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u2.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
// Preload
|
// Preload
|
||||||
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user