Add with clause support
This commit is contained in:
parent
61b4c31236
commit
825bbf35ec
@ -6,7 +6,7 @@ import (
|
||||
|
||||
var (
|
||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
queryClauses = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||
)
|
||||
|
104
chainable_api.go
104
chainable_api.go
@ -1,6 +1,7 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Recursive specify recursive common table expressions
|
||||
func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) {
|
||||
return db.with(name, true, args...)
|
||||
}
|
||||
|
||||
// With specify common table expressions
|
||||
//
|
||||
// db.With("cte", "SELECT id, `name` FROM users")
|
||||
// db.With("cte", gorm.Expr("SELECT id, `name` FROM users"))
|
||||
// db.With("cte", db.Select("id", "name").Table("users"))
|
||||
// db.With("cte", "id", "name", "SELECT id, `name` FROM users")
|
||||
// db.With("cte", "id, name", "SELECT id, `name` FROM users")
|
||||
// db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users")
|
||||
func (db *DB) With(name string, args ...interface{}) (tx *DB) {
|
||||
return db.with(name, false, args...)
|
||||
}
|
||||
|
||||
// with clause
|
||||
//
|
||||
// db.with("cte", false, "SELECT id, `name` FROM users")
|
||||
// db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users"))
|
||||
// db.with("cte", false, db.Select("id", "name").Table("users"))
|
||||
// db.with("cte", false, "id", "name", "SELECT id, `name` FROM users")
|
||||
// db.with("cte", false, "id, name", "SELECT id, `name` FROM users")
|
||||
// db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users")
|
||||
func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 0 {
|
||||
tx.AddError(errors.New("with clause must have subquery"))
|
||||
return
|
||||
}
|
||||
|
||||
// parse optional columns
|
||||
columns := make([]string, 0, len(args)-1)
|
||||
switch arg := args[0].(type) {
|
||||
case string:
|
||||
switch len(args) {
|
||||
case 1:
|
||||
// args[0] is subquery
|
||||
break
|
||||
case 2:
|
||||
// args[0] is single column or comma-separated string
|
||||
columns = strings.FieldsFunc(arg, func(r rune) bool {
|
||||
return r == ',' || r == ' '
|
||||
})
|
||||
default:
|
||||
// args[:len(args)-1] is []string
|
||||
for idx := 0; idx < len(args)-1; idx++ {
|
||||
col, ok := args[idx].(string)
|
||||
if !ok {
|
||||
tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx]))
|
||||
return
|
||||
}
|
||||
if strings.Contains(col, ",") {
|
||||
tx.AddError(fmt.Errorf("unsupported mixed string type"))
|
||||
return
|
||||
}
|
||||
columns = append(columns, col)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, col := range arg {
|
||||
if strings.Contains(col, ",") {
|
||||
tx.AddError(fmt.Errorf("unsupported mixed string type"))
|
||||
return
|
||||
}
|
||||
}
|
||||
columns = arg
|
||||
default:
|
||||
// specify optional columns field with a wrong type
|
||||
if len(args) > 1 {
|
||||
tx.AddError(fmt.Errorf("unsupported with column type %T", arg))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// parse subquery
|
||||
var subquery clause.Expression
|
||||
switch arg := args[len(args)-1].(type) {
|
||||
case clause.Expression:
|
||||
subquery = arg
|
||||
case string:
|
||||
subquery = clause.Expr{SQL: arg}
|
||||
case *DB:
|
||||
subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true}
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg))
|
||||
return
|
||||
}
|
||||
|
||||
// add with clause
|
||||
tx.Statement.AddClause(clause.With{
|
||||
Recursive: recursive,
|
||||
Exprs: []clause.Expression{clause.WithExpression{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
Expr: subquery,
|
||||
}},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil)
|
||||
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
|
||||
var (
|
||||
buildNames []string
|
||||
buildNamesMap = map[string]bool{}
|
||||
buildNamesSet = map[string]struct{}{}
|
||||
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
)
|
||||
|
||||
for _, c := range clauses {
|
||||
if _, ok := buildNamesMap[c.Name()]; !ok {
|
||||
if _, ok := buildNamesSet[c.Name()]; !ok {
|
||||
buildNames = append(buildNames, c.Name())
|
||||
buildNamesMap[c.Name()] = true
|
||||
buildNamesSet[c.Name()] = struct{}{}
|
||||
}
|
||||
|
||||
stmt.AddClause(c)
|
||||
|
@ -1,3 +1,88 @@
|
||||
package clause
|
||||
|
||||
type With struct{}
|
||||
import "strings"
|
||||
|
||||
// With Common Table Expressions
|
||||
type With struct {
|
||||
Recursive bool
|
||||
Exprs []Expression
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
// Name with clause name
|
||||
func (with With) Name() string {
|
||||
return "WITH"
|
||||
}
|
||||
|
||||
// Build build with clause
|
||||
func (with With) Build(builder Builder) {
|
||||
if with.Expression != nil {
|
||||
with.Expression.Build(builder)
|
||||
return
|
||||
}
|
||||
if len(with.Exprs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if with.Recursive {
|
||||
builder.WriteString("RECURSIVE ")
|
||||
}
|
||||
for idx, expr := range with.Exprs {
|
||||
if idx != 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge with clauses
|
||||
func (with With) MergeClause(clause *Clause) {
|
||||
if w, ok := clause.Expression.(With); ok {
|
||||
if !with.Recursive {
|
||||
with.Recursive = w.Recursive
|
||||
}
|
||||
if w.Expression != nil {
|
||||
with.Expression = w.Expression
|
||||
with.Exprs = nil
|
||||
} else if with.Expression == nil {
|
||||
exprs := make([]Expression, len(w.Exprs)+len(with.Exprs))
|
||||
copy(exprs, w.Exprs)
|
||||
copy(exprs[len(w.Exprs):], with.Exprs)
|
||||
with.Exprs = exprs
|
||||
}
|
||||
}
|
||||
clause.Expression = with
|
||||
}
|
||||
|
||||
// WithExpression with expression
|
||||
type WithExpression struct {
|
||||
Name string
|
||||
Columns []string
|
||||
Expr Expression
|
||||
}
|
||||
|
||||
func (with WithExpression) Build(builder Builder) {
|
||||
if with.Name == "" || with.Expr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
builder.WriteQuoted(with.Name)
|
||||
|
||||
if len(with.Columns) > 0 {
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range with.Columns {
|
||||
if idx != 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
column = strings.TrimSpace(column)
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
|
||||
builder.WriteString(" AS ")
|
||||
builder.WriteByte('(')
|
||||
with.Expr.Build(builder)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
|
104
clause/with_test.go
Normal file
104
clause/with_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestWith(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.With{Exprs: []clause.Expression{
|
||||
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
|
||||
clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}},
|
||||
}}},
|
||||
"WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.With{
|
||||
Recursive: true,
|
||||
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
|
||||
},
|
||||
clause.With{
|
||||
Recursive: false,
|
||||
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
|
||||
},
|
||||
},
|
||||
"WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.With{
|
||||
Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"},
|
||||
},
|
||||
clause.With{
|
||||
Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
|
||||
},
|
||||
},
|
||||
"WITH `cte1` AS (SELECT 1)",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.With{
|
||||
Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
|
||||
},
|
||||
clause.With{
|
||||
Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"},
|
||||
},
|
||||
},
|
||||
"WITH `cte2` AS (SELECT 2)",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.With{
|
||||
Exprs: []clause.Expression{
|
||||
clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
|
||||
},
|
||||
Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"},
|
||||
}},
|
||||
"WITH cte2 AS (SELECT 1)",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.With{
|
||||
Recursive: false,
|
||||
Exprs: nil,
|
||||
Expression: nil,
|
||||
}},
|
||||
"WITH",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
|
||||
Name: "",
|
||||
Expr: clause.Expr{SQL: "SELECT 1"},
|
||||
}}}},
|
||||
"WITH",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
|
||||
Name: "cte1",
|
||||
Expr: nil,
|
||||
}}}},
|
||||
"WITH",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user