diff --git a/callbacks/query.go b/callbacks/query.go index bbf238a9..56a5944a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -146,9 +146,13 @@ func BuildQuerySQL(db *gorm.DB) { if isRelations { genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { - tableAliasName := relation.Name - if parentTableName != clause.CurrentTable { - tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + tableAliasName := join.Alias + + if tableAliasName == "" { + tableAliasName = relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } } columnStmt := gorm.Statement{ @@ -167,6 +171,13 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.Expression != nil { + return clause.Join{ + Type: join.JoinType, + Expression: join.Expression, + } + } + exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { diff --git a/clause/joins.go b/clause/joins.go index ddb2a5a9..a6f13e55 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -19,15 +19,15 @@ type JoinTarget struct { } func Has(name string) JoinTarget { - return JoinTarget{Type: LeftJoin, Association: name} + return JoinTarget{Type: InnerJoin, Association: name} } func (jt JoinType) Association(name string) JoinTarget { return JoinTarget{Type: jt, Association: name} } -func (jt JoinType) Subquery(subquery Expression) JoinTarget { - return JoinTarget{Type: jt, Subquery: subquery} +func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { + return JoinTarget{Type: jt, Association: name, Subquery: subquery} } func (jt JoinTarget) As(name string) JoinTarget { diff --git a/generics.go b/generics.go index 95d98100..43c7223a 100644 --- a/generics.go +++ b/generics.go @@ -3,8 +3,11 @@ package gorm import ( "context" "database/sql" + "fmt" + "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" ) type Interface[T any] interface { @@ -28,7 +31,7 @@ type ChainInterface[T any] interface { Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] - Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] + Joins(query clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] @@ -38,6 +41,8 @@ type ChainInterface[T any] interface { Order(value interface{}) ChainInterface[T] Preload(query string, args ...interface{}) ChainInterface[T] + Build(builder clause.Builder) + Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) @@ -55,6 +60,14 @@ type ExecInterface[T any] interface { Rows(ctx context.Context) (*sql.Rows, error) } +type QueryInterface interface { + Select(...string) QueryInterface + Omit(...string) QueryInterface + Where(query interface{}, args ...interface{}) QueryInterface + Not(query interface{}, args ...interface{}) QueryInterface + Or(query interface{}, args ...interface{}) QueryInterface +} + type op func(*DB) *DB func G[T any](db *DB, opts ...clause.Expression) Interface[T] { @@ -185,9 +198,77 @@ func (c chainG[T]) Offset(offset int) ChainInterface[T] { }) } -func (c chainG[T]) Joins(query clause.JoinTarget, args func(db ChainInterface[any], joinTable clause.Table, curTable clause.Table) ChainInterface[any]) ChainInterface[T] { - // TODO - return nil +type query struct { + db *DB +} + +func (q query) Where(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Or(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Not(query interface{}, args ...interface{}) QueryInterface { + q.db.Where(query, args...) + return q +} + +func (q query) Select(columns ...string) QueryInterface { + q.db.Select(columns) + return q +} + +func (q query) Omit(columns ...string) QueryInterface { + q.db.Omit(columns...) + return q +} + +func (c chainG[T]) Joins(jt clause.JoinTarget, args func(db QueryInterface, joinTable clause.Table, curTable clause.Table) QueryInterface) ChainInterface[T] { + return c.with(func(db *DB) *DB { + if jt.Table == "" { + jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name + } + + q := query{db: db.Session(&Session{NewDB: true}).getInstance().Table(jt.Table)} + if args != nil { + args(q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}) + } + + j := join{ + Name: jt.Association, + Alias: jt.Table, + Selects: q.db.Statement.Selects, + Omits: q.db.Statement.Omits, + JoinType: jt.Type, + } + + if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + + if jt.Subquery != nil { + joinType := j.JoinType + if joinType == "" { + joinType = clause.LeftJoin + } + + expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} + + if j.On != nil { + expr.SQL += " ON ?" + expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) + } + + j.Expression = expr + } + + db.Statement.Joins = append(db.Statement.Joins, j) + return db + }) } func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { @@ -261,6 +342,43 @@ func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err return } +func (c chainG[T]) Build(builder clause.Builder) { + subdb := c.getInstance() + subdb.Logger = logger.Discard + subdb.DryRun = true + + if stmt, ok := builder.(*Statement); ok { + if subdb.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = subdb.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + subdb.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + + builder.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + } +} + type execG[T any] struct { g *g[T] } diff --git a/scan.go b/scan.go index 6dc55f62..624f822f 100644 --- a/scan.go +++ b/scan.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "reflect" + "strings" "time" "gorm.io/gorm/schema" @@ -244,6 +245,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) { matchedFieldCount[column] = 1 } } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + for _, join := range db.Statement.Joins { + if join.Alias == names[0] { + names = append(strings.Split(join.Name, "."), names[len(names)-1]) + } + } + if rel, ok := sch.Relationships.Relations[names[0]]; ok { subNameCount := len(names) // nested relation fields diff --git a/statement.go b/statement.go index 88d76dc3..63f78006 100644 --- a/statement.go +++ b/statement.go @@ -50,12 +50,14 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where - Selects []string - Omits []string - JoinType clause.JoinType + Name string + Alias string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + Expression clause.Expression + JoinType clause.JoinType } // StatementModifier statement modifier interface @@ -322,6 +324,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] arg, _ = valuer.Value() } + curTable := stmt.Table + if curTable == "" { + curTable = clause.CurrentTable + } + switch v := arg.(type) { case clause.Expression: conds = append(conds, v) @@ -352,7 +359,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] sort.Strings(keys) for _, key := range keys { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + column := clause.Column{Name: key, Table: curTable} + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: keys := make([]string, 0, len(v)) @@ -363,12 +371,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + column := clause.Column{Name: key, Table: curTable} switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else if _, ok := v[key].(Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else { // optimize reflect value length valueLen := reflectValue.Len() @@ -377,10 +386,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: column, Values: values}) } default: - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } } default: @@ -407,9 +416,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -421,9 +430,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } @@ -448,14 +457,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) return []clause.Expression{clause.And(conds...)} } return nil } } - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) } } } diff --git a/tests/generics_test.go b/tests/generics_test.go index 2d69d22e..1e1bf711 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -44,7 +44,7 @@ func TestGenericsCreate(t *testing.T) { if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) - } else if u.Name != "" || u.Age != u.Age { + } else if u.Name != "" || u.Age != user.Age { t.Errorf("found invalid user, got %v, expect %v", u, user) } @@ -92,7 +92,6 @@ func TestGenericsCreateInBatches(t *testing.T) { found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) if len(found) != len(batch) { - fmt.Println(found) t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) } @@ -282,10 +281,12 @@ func TestGenericsJoinsAndPreload(t *testing.T) { db := gorm.G[User](DB) u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} - db.Create(ctx, &u) + u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}} + u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}} + db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) - // LEFT JOIN + WHERE - result, err := db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { + // Inner JOIN + WHERE + result, err := db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { return db.Where("?.name = ?", joinTable, u.Company.Name) }).First(ctx) if err != nil { @@ -295,9 +296,9 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // JOIN - result, err = db.Joins(clause.Has("Company"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { - return nil + // Inner JOIN + WHERE with map + result, err = db.Joins(clause.Has("Company"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + return db.Where(map[string]any{"name": u.Company.Name}) }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) @@ -306,10 +307,8 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } - // Left JOIN - result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.ChainInterface[any], joinTable clause.Table, curTable clause.Table) gorm.ChainInterface[any] { - return nil - }).First(ctx) + // Left JOIN w/o WHERE + result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } @@ -317,6 +316,36 @@ func TestGenericsJoinsAndPreload(t *testing.T) { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } + // Left JOIN + Alias WHERE + result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + if joinTable.Name != "t" { + t.Fatalf("Join table should be t, but got %v", joinTable.Name) + } + return db.Where("?.name = ?", joinTable, u.Company.Name) + }).Where(map[string]any{"name": u.Name}).First(ctx) + if err != nil { + t.Fatalf("Joins failed: %v", err) + } + if result.Name != u.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + + // Raw Subquery JOIN + WHERE + result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), + func(db gorm.QueryInterface, joinTable clause.Table, curTable clause.Table) gorm.QueryInterface { + if joinTable.Name != "t" { + t.Fatalf("Join table should be t, but got %v", joinTable.Name) + } + return db.Where("?.name = ?", joinTable, u.Company.Name) + }, + ).Where(map[string]any{"name": u2.Name}).First(ctx) + if err != nil { + t.Fatalf("Raw subquery join failed: %v", err) + } + if result.Name != u2.Name || result.Company.Name != u.Company.Name { + t.Fatalf("Joins expected %s, got %+v", u.Name, result) + } + // Preload result3, err := db.Preload("Company").Where("name = ?", u.Name).First(ctx) if err != nil {