From 825bbf35ecdbafabe57b100e0480e3ba1b6f0fd8 Mon Sep 17 00:00:00 2001 From: kvii2202 <13721952202@163.com> Date: Wed, 16 Mar 2022 19:54:12 +0800 Subject: [PATCH] Add with clause support --- callbacks/callbacks.go | 2 +- chainable_api.go | 104 +++++++++++++++++++++++++++++++++++++++++ clause/clause_test.go | 6 +-- clause/with.go | 87 +++++++++++++++++++++++++++++++++- clause/with_test.go | 104 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 298 insertions(+), 5 deletions(-) create mode 100644 clause/with_test.go diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..2185a0e0 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -6,7 +6,7 @@ import ( var ( createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} - queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + queryClauses = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} updateClauses = []string{"UPDATE", "SET", "WHERE"} deleteClauses = []string{"DELETE", "FROM", "WHERE"} ) diff --git a/chainable_api.go b/chainable_api.go index 173479d3..aff735c6 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -1,6 +1,7 @@ package gorm import ( + "errors" "fmt" "regexp" "strings" @@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { return } +// Recursive specify recursive common table expressions +func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) { + return db.with(name, true, args...) +} + +// With specify common table expressions +// +// db.With("cte", "SELECT id, `name` FROM users") +// db.With("cte", gorm.Expr("SELECT id, `name` FROM users")) +// db.With("cte", db.Select("id", "name").Table("users")) +// db.With("cte", "id", "name", "SELECT id, `name` FROM users") +// db.With("cte", "id, name", "SELECT id, `name` FROM users") +// db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users") +func (db *DB) With(name string, args ...interface{}) (tx *DB) { + return db.with(name, false, args...) +} + +// with clause +// +// db.with("cte", false, "SELECT id, `name` FROM users") +// db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users")) +// db.with("cte", false, db.Select("id", "name").Table("users")) +// db.with("cte", false, "id", "name", "SELECT id, `name` FROM users") +// db.with("cte", false, "id, name", "SELECT id, `name` FROM users") +// db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users") +func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + if len(args) == 0 { + tx.AddError(errors.New("with clause must have subquery")) + return + } + + // parse optional columns + columns := make([]string, 0, len(args)-1) + switch arg := args[0].(type) { + case string: + switch len(args) { + case 1: + // args[0] is subquery + break + case 2: + // args[0] is single column or comma-separated string + columns = strings.FieldsFunc(arg, func(r rune) bool { + return r == ',' || r == ' ' + }) + default: + // args[:len(args)-1] is []string + for idx := 0; idx < len(args)-1; idx++ { + col, ok := args[idx].(string) + if !ok { + tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx])) + return + } + if strings.Contains(col, ",") { + tx.AddError(fmt.Errorf("unsupported mixed string type")) + return + } + columns = append(columns, col) + } + } + case []string: + for _, col := range arg { + if strings.Contains(col, ",") { + tx.AddError(fmt.Errorf("unsupported mixed string type")) + return + } + } + columns = arg + default: + // specify optional columns field with a wrong type + if len(args) > 1 { + tx.AddError(fmt.Errorf("unsupported with column type %T", arg)) + return + } + } + + // parse subquery + var subquery clause.Expression + switch arg := args[len(args)-1].(type) { + case clause.Expression: + subquery = arg + case string: + subquery = clause.Expr{SQL: arg} + case *DB: + subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true} + default: + tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg)) + return + } + + // add with clause + tx.Statement.AddClause(clause.With{ + Recursive: recursive, + Exprs: []clause.Expression{clause.WithExpression{ + Name: name, + Columns: columns, + Expr: subquery, + }}, + }) + return +} + // Distinct specify distinct fields that you want querying func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/clause/clause_test.go b/clause/clause_test.go index 6239ff39..f516f2fa 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil) func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { var ( buildNames []string - buildNamesMap = map[string]bool{} + buildNamesSet = map[string]struct{}{} user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) for _, c := range clauses { - if _, ok := buildNamesMap[c.Name()]; !ok { + if _, ok := buildNamesSet[c.Name()]; !ok { buildNames = append(buildNames, c.Name()) - buildNamesMap[c.Name()] = true + buildNamesSet[c.Name()] = struct{}{} } stmt.AddClause(c) diff --git a/clause/with.go b/clause/with.go index 0768488e..3ddd77f0 100644 --- a/clause/with.go +++ b/clause/with.go @@ -1,3 +1,88 @@ package clause -type With struct{} +import "strings" + +// With Common Table Expressions +type With struct { + Recursive bool + Exprs []Expression + Expression Expression +} + +// Name with clause name +func (with With) Name() string { + return "WITH" +} + +// Build build with clause +func (with With) Build(builder Builder) { + if with.Expression != nil { + with.Expression.Build(builder) + return + } + if len(with.Exprs) == 0 { + return + } + + if with.Recursive { + builder.WriteString("RECURSIVE ") + } + for idx, expr := range with.Exprs { + if idx != 0 { + builder.WriteByte(',') + } + expr.Build(builder) + } +} + +// MergeClause merge with clauses +func (with With) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(With); ok { + if !with.Recursive { + with.Recursive = w.Recursive + } + if w.Expression != nil { + with.Expression = w.Expression + with.Exprs = nil + } else if with.Expression == nil { + exprs := make([]Expression, len(w.Exprs)+len(with.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], with.Exprs) + with.Exprs = exprs + } + } + clause.Expression = with +} + +// WithExpression with expression +type WithExpression struct { + Name string + Columns []string + Expr Expression +} + +func (with WithExpression) Build(builder Builder) { + if with.Name == "" || with.Expr == nil { + return + } + + builder.WriteQuoted(with.Name) + + if len(with.Columns) > 0 { + builder.WriteByte(' ') + builder.WriteByte('(') + for idx, column := range with.Columns { + if idx != 0 { + builder.WriteByte(',') + } + column = strings.TrimSpace(column) + builder.WriteQuoted(column) + } + builder.WriteByte(')') + } + + builder.WriteString(" AS ") + builder.WriteByte('(') + with.Expr.Build(builder) + builder.WriteByte(')') +} diff --git a/clause/with_test.go b/clause/with_test.go new file mode 100644 index 00000000..199eb66f --- /dev/null +++ b/clause/with_test.go @@ -0,0 +1,104 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWith(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.With{Exprs: []clause.Expression{ + clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}}, + clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}}, + }}}, + "WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)", + nil, + }, + { + []clause.Interface{ + clause.With{ + Recursive: true, + Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}}, + }, + clause.With{ + Recursive: false, + Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}}, + }, + }, + "WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)", + nil, + }, + { + []clause.Interface{ + clause.With{ + Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"}, + }, + clause.With{ + Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}}, + }, + }, + "WITH `cte1` AS (SELECT 1)", + nil, + }, + { + []clause.Interface{ + clause.With{ + Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}}, + }, + clause.With{ + Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"}, + }, + }, + "WITH `cte2` AS (SELECT 2)", + nil, + }, + { + []clause.Interface{clause.With{ + Exprs: []clause.Expression{ + clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}}, + }, + Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"}, + }}, + "WITH cte2 AS (SELECT 1)", + nil, + }, + { + []clause.Interface{clause.With{ + Recursive: false, + Exprs: nil, + Expression: nil, + }}, + "WITH", + nil, + }, + { + []clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{ + Name: "", + Expr: clause.Expr{SQL: "SELECT 1"}, + }}}}, + "WITH", + nil, + }, + { + []clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{ + Name: "cte1", + Expr: nil, + }}}}, + "WITH", + nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +}