Add with clause support
This commit is contained in:
parent
61b4c31236
commit
825bbf35ec
@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
queryClauses = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||||
)
|
)
|
||||||
|
104
chainable_api.go
104
chainable_api.go
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Recursive specify recursive common table expressions
|
||||||
|
func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) {
|
||||||
|
return db.with(name, true, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With specify common table expressions
|
||||||
|
//
|
||||||
|
// db.With("cte", "SELECT id, `name` FROM users")
|
||||||
|
// db.With("cte", gorm.Expr("SELECT id, `name` FROM users"))
|
||||||
|
// db.With("cte", db.Select("id", "name").Table("users"))
|
||||||
|
// db.With("cte", "id", "name", "SELECT id, `name` FROM users")
|
||||||
|
// db.With("cte", "id, name", "SELECT id, `name` FROM users")
|
||||||
|
// db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users")
|
||||||
|
func (db *DB) With(name string, args ...interface{}) (tx *DB) {
|
||||||
|
return db.with(name, false, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// with clause
|
||||||
|
//
|
||||||
|
// db.with("cte", false, "SELECT id, `name` FROM users")
|
||||||
|
// db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users"))
|
||||||
|
// db.with("cte", false, db.Select("id", "name").Table("users"))
|
||||||
|
// db.with("cte", false, "id", "name", "SELECT id, `name` FROM users")
|
||||||
|
// db.with("cte", false, "id, name", "SELECT id, `name` FROM users")
|
||||||
|
// db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users")
|
||||||
|
func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) {
|
||||||
|
tx = db.getInstance()
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
tx.AddError(errors.New("with clause must have subquery"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse optional columns
|
||||||
|
columns := make([]string, 0, len(args)-1)
|
||||||
|
switch arg := args[0].(type) {
|
||||||
|
case string:
|
||||||
|
switch len(args) {
|
||||||
|
case 1:
|
||||||
|
// args[0] is subquery
|
||||||
|
break
|
||||||
|
case 2:
|
||||||
|
// args[0] is single column or comma-separated string
|
||||||
|
columns = strings.FieldsFunc(arg, func(r rune) bool {
|
||||||
|
return r == ',' || r == ' '
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
// args[:len(args)-1] is []string
|
||||||
|
for idx := 0; idx < len(args)-1; idx++ {
|
||||||
|
col, ok := args[idx].(string)
|
||||||
|
if !ok {
|
||||||
|
tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx]))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.Contains(col, ",") {
|
||||||
|
tx.AddError(fmt.Errorf("unsupported mixed string type"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, col := range arg {
|
||||||
|
if strings.Contains(col, ",") {
|
||||||
|
tx.AddError(fmt.Errorf("unsupported mixed string type"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
columns = arg
|
||||||
|
default:
|
||||||
|
// specify optional columns field with a wrong type
|
||||||
|
if len(args) > 1 {
|
||||||
|
tx.AddError(fmt.Errorf("unsupported with column type %T", arg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse subquery
|
||||||
|
var subquery clause.Expression
|
||||||
|
switch arg := args[len(args)-1].(type) {
|
||||||
|
case clause.Expression:
|
||||||
|
subquery = arg
|
||||||
|
case string:
|
||||||
|
subquery = clause.Expr{SQL: arg}
|
||||||
|
case *DB:
|
||||||
|
subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true}
|
||||||
|
default:
|
||||||
|
tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// add with clause
|
||||||
|
tx.Statement.AddClause(clause.With{
|
||||||
|
Recursive: recursive,
|
||||||
|
Exprs: []clause.Expression{clause.WithExpression{
|
||||||
|
Name: name,
|
||||||
|
Columns: columns,
|
||||||
|
Expr: subquery,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Distinct specify distinct fields that you want querying
|
// Distinct specify distinct fields that you want querying
|
||||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil)
|
|||||||
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
|
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
|
||||||
var (
|
var (
|
||||||
buildNames []string
|
buildNames []string
|
||||||
buildNamesMap = map[string]bool{}
|
buildNamesSet = map[string]struct{}{}
|
||||||
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||||
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, c := range clauses {
|
for _, c := range clauses {
|
||||||
if _, ok := buildNamesMap[c.Name()]; !ok {
|
if _, ok := buildNamesSet[c.Name()]; !ok {
|
||||||
buildNames = append(buildNames, c.Name())
|
buildNames = append(buildNames, c.Name())
|
||||||
buildNamesMap[c.Name()] = true
|
buildNamesSet[c.Name()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt.AddClause(c)
|
stmt.AddClause(c)
|
||||||
|
@ -1,3 +1,88 @@
|
|||||||
package clause
|
package clause
|
||||||
|
|
||||||
type With struct{}
|
import "strings"
|
||||||
|
|
||||||
|
// With Common Table Expressions
|
||||||
|
type With struct {
|
||||||
|
Recursive bool
|
||||||
|
Exprs []Expression
|
||||||
|
Expression Expression
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name with clause name
|
||||||
|
func (with With) Name() string {
|
||||||
|
return "WITH"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build build with clause
|
||||||
|
func (with With) Build(builder Builder) {
|
||||||
|
if with.Expression != nil {
|
||||||
|
with.Expression.Build(builder)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(with.Exprs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if with.Recursive {
|
||||||
|
builder.WriteString("RECURSIVE ")
|
||||||
|
}
|
||||||
|
for idx, expr := range with.Exprs {
|
||||||
|
if idx != 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
expr.Build(builder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeClause merge with clauses
|
||||||
|
func (with With) MergeClause(clause *Clause) {
|
||||||
|
if w, ok := clause.Expression.(With); ok {
|
||||||
|
if !with.Recursive {
|
||||||
|
with.Recursive = w.Recursive
|
||||||
|
}
|
||||||
|
if w.Expression != nil {
|
||||||
|
with.Expression = w.Expression
|
||||||
|
with.Exprs = nil
|
||||||
|
} else if with.Expression == nil {
|
||||||
|
exprs := make([]Expression, len(w.Exprs)+len(with.Exprs))
|
||||||
|
copy(exprs, w.Exprs)
|
||||||
|
copy(exprs[len(w.Exprs):], with.Exprs)
|
||||||
|
with.Exprs = exprs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clause.Expression = with
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExpression with expression
|
||||||
|
type WithExpression struct {
|
||||||
|
Name string
|
||||||
|
Columns []string
|
||||||
|
Expr Expression
|
||||||
|
}
|
||||||
|
|
||||||
|
func (with WithExpression) Build(builder Builder) {
|
||||||
|
if with.Name == "" || with.Expr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteQuoted(with.Name)
|
||||||
|
|
||||||
|
if len(with.Columns) > 0 {
|
||||||
|
builder.WriteByte(' ')
|
||||||
|
builder.WriteByte('(')
|
||||||
|
for idx, column := range with.Columns {
|
||||||
|
if idx != 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
column = strings.TrimSpace(column)
|
||||||
|
builder.WriteQuoted(column)
|
||||||
|
}
|
||||||
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(" AS ")
|
||||||
|
builder.WriteByte('(')
|
||||||
|
with.Expr.Build(builder)
|
||||||
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
|
104
clause/with_test.go
Normal file
104
clause/with_test.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package clause_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWith(t *testing.T) {
|
||||||
|
results := []struct {
|
||||||
|
Clauses []clause.Interface
|
||||||
|
Result string
|
||||||
|
Vars []interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.With{Exprs: []clause.Expression{
|
||||||
|
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
|
||||||
|
clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}},
|
||||||
|
}}},
|
||||||
|
"WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{
|
||||||
|
clause.With{
|
||||||
|
Recursive: true,
|
||||||
|
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
|
||||||
|
},
|
||||||
|
clause.With{
|
||||||
|
Recursive: false,
|
||||||
|
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{
|
||||||
|
clause.With{
|
||||||
|
Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"},
|
||||||
|
},
|
||||||
|
clause.With{
|
||||||
|
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"WITH `cte1` AS (SELECT 1)",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{
|
||||||
|
clause.With{
|
||||||
|
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
|
||||||
|
},
|
||||||
|
clause.With{
|
||||||
|
Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"WITH `cte2` AS (SELECT 2)",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.With{
|
||||||
|
Exprs: []clause.Expression{
|
||||||
|
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
|
||||||
|
},
|
||||||
|
Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"},
|
||||||
|
}},
|
||||||
|
"WITH cte2 AS (SELECT 1)",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.With{
|
||||||
|
Recursive: false,
|
||||||
|
Exprs: nil,
|
||||||
|
Expression: nil,
|
||||||
|
}},
|
||||||
|
"WITH",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
|
||||||
|
Name: "",
|
||||||
|
Expr: clause.Expr{SQL: "SELECT 1"},
|
||||||
|
}}}},
|
||||||
|
"WITH",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
|
||||||
|
Name: "cte1",
|
||||||
|
Expr: nil,
|
||||||
|
}}}},
|
||||||
|
"WITH",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, result := range results {
|
||||||
|
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||||
|
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user