Add with clause support

This commit is contained in:
kvii2202 2022-03-16 19:54:12 +08:00
parent 61b4c31236
commit 825bbf35ec
5 changed files with 298 additions and 5 deletions

View File

@ -6,7 +6,7 @@ import (
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
queryClauses = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)

View File

@ -1,6 +1,7 @@
package gorm
import (
"errors"
"fmt"
"regexp"
"strings"
@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
return
}
// Recursive specify recursive common table expressions
func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) {
return db.with(name, true, args...)
}
// With specify common table expressions
//
// db.With("cte", "SELECT id, `name` FROM users")
// db.With("cte", gorm.Expr("SELECT id, `name` FROM users"))
// db.With("cte", db.Select("id", "name").Table("users"))
// db.With("cte", "id", "name", "SELECT id, `name` FROM users")
// db.With("cte", "id, name", "SELECT id, `name` FROM users")
// db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users")
func (db *DB) With(name string, args ...interface{}) (tx *DB) {
return db.with(name, false, args...)
}
// with clause
//
// db.with("cte", false, "SELECT id, `name` FROM users")
// db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users"))
// db.with("cte", false, db.Select("id", "name").Table("users"))
// db.with("cte", false, "id", "name", "SELECT id, `name` FROM users")
// db.with("cte", false, "id, name", "SELECT id, `name` FROM users")
// db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users")
func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(args) == 0 {
tx.AddError(errors.New("with clause must have subquery"))
return
}
// parse optional columns
columns := make([]string, 0, len(args)-1)
switch arg := args[0].(type) {
case string:
switch len(args) {
case 1:
// args[0] is subquery
break
case 2:
// args[0] is single column or comma-separated string
columns = strings.FieldsFunc(arg, func(r rune) bool {
return r == ',' || r == ' '
})
default:
// args[:len(args)-1] is []string
for idx := 0; idx < len(args)-1; idx++ {
col, ok := args[idx].(string)
if !ok {
tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx]))
return
}
if strings.Contains(col, ",") {
tx.AddError(fmt.Errorf("unsupported mixed string type"))
return
}
columns = append(columns, col)
}
}
case []string:
for _, col := range arg {
if strings.Contains(col, ",") {
tx.AddError(fmt.Errorf("unsupported mixed string type"))
return
}
}
columns = arg
default:
// specify optional columns field with a wrong type
if len(args) > 1 {
tx.AddError(fmt.Errorf("unsupported with column type %T", arg))
return
}
}
// parse subquery
var subquery clause.Expression
switch arg := args[len(args)-1].(type) {
case clause.Expression:
subquery = arg
case string:
subquery = clause.Expr{SQL: arg}
case *DB:
subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true}
default:
tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg))
return
}
// add with clause
tx.Statement.AddClause(clause.With{
Recursive: recursive,
Exprs: []clause.Expression{clause.WithExpression{
Name: name,
Columns: columns,
Expr: subquery,
}},
})
return
}
// Distinct specify distinct fields that you want querying
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()

View File

@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil)
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
var (
buildNames []string
buildNamesMap = map[string]bool{}
buildNamesSet = map[string]struct{}{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
)
for _, c := range clauses {
if _, ok := buildNamesMap[c.Name()]; !ok {
if _, ok := buildNamesSet[c.Name()]; !ok {
buildNames = append(buildNames, c.Name())
buildNamesMap[c.Name()] = true
buildNamesSet[c.Name()] = struct{}{}
}
stmt.AddClause(c)

View File

@ -1,3 +1,88 @@
package clause
type With struct{}
import "strings"
// With Common Table Expressions
type With struct {
Recursive bool
Exprs []Expression
Expression Expression
}
// Name with clause name
func (with With) Name() string {
return "WITH"
}
// Build build with clause
func (with With) Build(builder Builder) {
if with.Expression != nil {
with.Expression.Build(builder)
return
}
if len(with.Exprs) == 0 {
return
}
if with.Recursive {
builder.WriteString("RECURSIVE ")
}
for idx, expr := range with.Exprs {
if idx != 0 {
builder.WriteByte(',')
}
expr.Build(builder)
}
}
// MergeClause merge with clauses
func (with With) MergeClause(clause *Clause) {
if w, ok := clause.Expression.(With); ok {
if !with.Recursive {
with.Recursive = w.Recursive
}
if w.Expression != nil {
with.Expression = w.Expression
with.Exprs = nil
} else if with.Expression == nil {
exprs := make([]Expression, len(w.Exprs)+len(with.Exprs))
copy(exprs, w.Exprs)
copy(exprs[len(w.Exprs):], with.Exprs)
with.Exprs = exprs
}
}
clause.Expression = with
}
// WithExpression with expression
type WithExpression struct {
Name string
Columns []string
Expr Expression
}
func (with WithExpression) Build(builder Builder) {
if with.Name == "" || with.Expr == nil {
return
}
builder.WriteQuoted(with.Name)
if len(with.Columns) > 0 {
builder.WriteByte(' ')
builder.WriteByte('(')
for idx, column := range with.Columns {
if idx != 0 {
builder.WriteByte(',')
}
column = strings.TrimSpace(column)
builder.WriteQuoted(column)
}
builder.WriteByte(')')
}
builder.WriteString(" AS ")
builder.WriteByte('(')
with.Expr.Build(builder)
builder.WriteByte(')')
}

104
clause/with_test.go Normal file
View File

@ -0,0 +1,104 @@
package clause_test
import (
"fmt"
"testing"
"gorm.io/gorm/clause"
)
func TestWith(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.With{Exprs: []clause.Expression{
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}},
}}},
"WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)",
nil,
},
{
[]clause.Interface{
clause.With{
Recursive: true,
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
},
clause.With{
Recursive: false,
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
},
},
"WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)",
nil,
},
{
[]clause.Interface{
clause.With{
Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"},
},
clause.With{
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
},
},
"WITH `cte1` AS (SELECT 1)",
nil,
},
{
[]clause.Interface{
clause.With{
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
},
clause.With{
Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"},
},
},
"WITH `cte2` AS (SELECT 2)",
nil,
},
{
[]clause.Interface{clause.With{
Exprs: []clause.Expression{
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
},
Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"},
}},
"WITH cte2 AS (SELECT 1)",
nil,
},
{
[]clause.Interface{clause.With{
Recursive: false,
Exprs: nil,
Expression: nil,
}},
"WITH",
nil,
},
{
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
Name: "",
Expr: clause.Expr{SQL: "SELECT 1"},
}}}},
"WITH",
nil,
},
{
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
Name: "cte1",
Expr: nil,
}}}},
"WITH",
nil,
},
}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
})
}
}