Add with clause support
This commit is contained in:
		
							parent
							
								
									61b4c31236
								
							
						
					
					
						commit
						825bbf35ec
					
				@ -6,7 +6,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
 | 
						createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
 | 
				
			||||||
	queryClauses  = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
 | 
						queryClauses  = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
 | 
				
			||||||
	updateClauses = []string{"UPDATE", "SET", "WHERE"}
 | 
						updateClauses = []string{"UPDATE", "SET", "WHERE"}
 | 
				
			||||||
	deleteClauses = []string{"DELETE", "FROM", "WHERE"}
 | 
						deleteClauses = []string{"DELETE", "FROM", "WHERE"}
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										104
									
								
								chainable_api.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								chainable_api.go
									
									
									
									
									
								
							@ -1,6 +1,7 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Recursive specify recursive common table expressions
 | 
				
			||||||
 | 
					func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) {
 | 
				
			||||||
 | 
						return db.with(name, true, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With specify common table expressions
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//	db.With("cte", "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.With("cte", gorm.Expr("SELECT id, `name` FROM users"))
 | 
				
			||||||
 | 
					//	db.With("cte", db.Select("id", "name").Table("users"))
 | 
				
			||||||
 | 
					//	db.With("cte", "id", "name", "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.With("cte", "id, name", "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					func (db *DB) With(name string, args ...interface{}) (tx *DB) {
 | 
				
			||||||
 | 
						return db.with(name, false, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// with clause
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//	db.with("cte", false, "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users"))
 | 
				
			||||||
 | 
					//	db.with("cte", false, db.Select("id", "name").Table("users"))
 | 
				
			||||||
 | 
					//	db.with("cte", false, "id", "name", "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.with("cte", false, "id, name", "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					//	db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users")
 | 
				
			||||||
 | 
					func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) {
 | 
				
			||||||
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(args) == 0 {
 | 
				
			||||||
 | 
							tx.AddError(errors.New("with clause must have subquery"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// parse optional columns
 | 
				
			||||||
 | 
						columns := make([]string, 0, len(args)-1)
 | 
				
			||||||
 | 
						switch arg := args[0].(type) {
 | 
				
			||||||
 | 
						case string:
 | 
				
			||||||
 | 
							switch len(args) {
 | 
				
			||||||
 | 
							case 1:
 | 
				
			||||||
 | 
								// args[0] is subquery
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							case 2:
 | 
				
			||||||
 | 
								// args[0] is single column or comma-separated string
 | 
				
			||||||
 | 
								columns = strings.FieldsFunc(arg, func(r rune) bool {
 | 
				
			||||||
 | 
									return r == ',' || r == ' '
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								// args[:len(args)-1] is []string
 | 
				
			||||||
 | 
								for idx := 0; idx < len(args)-1; idx++ {
 | 
				
			||||||
 | 
									col, ok := args[idx].(string)
 | 
				
			||||||
 | 
									if !ok {
 | 
				
			||||||
 | 
										tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx]))
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if strings.Contains(col, ",") {
 | 
				
			||||||
 | 
										tx.AddError(fmt.Errorf("unsupported mixed string type"))
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									columns = append(columns, col)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case []string:
 | 
				
			||||||
 | 
							for _, col := range arg {
 | 
				
			||||||
 | 
								if strings.Contains(col, ",") {
 | 
				
			||||||
 | 
									tx.AddError(fmt.Errorf("unsupported mixed string type"))
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							columns = arg
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							// specify optional columns field with a wrong type
 | 
				
			||||||
 | 
							if len(args) > 1 {
 | 
				
			||||||
 | 
								tx.AddError(fmt.Errorf("unsupported with column type %T", arg))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// parse subquery
 | 
				
			||||||
 | 
						var subquery clause.Expression
 | 
				
			||||||
 | 
						switch arg := args[len(args)-1].(type) {
 | 
				
			||||||
 | 
						case clause.Expression:
 | 
				
			||||||
 | 
							subquery = arg
 | 
				
			||||||
 | 
						case string:
 | 
				
			||||||
 | 
							subquery = clause.Expr{SQL: arg}
 | 
				
			||||||
 | 
						case *DB:
 | 
				
			||||||
 | 
							subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true}
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// add with clause
 | 
				
			||||||
 | 
						tx.Statement.AddClause(clause.With{
 | 
				
			||||||
 | 
							Recursive: recursive,
 | 
				
			||||||
 | 
							Exprs: []clause.Expression{clause.WithExpression{
 | 
				
			||||||
 | 
								Name:    name,
 | 
				
			||||||
 | 
								Columns: columns,
 | 
				
			||||||
 | 
								Expr:    subquery,
 | 
				
			||||||
 | 
							}},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Distinct specify distinct fields that you want querying
 | 
					// Distinct specify distinct fields that you want querying
 | 
				
			||||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
 | 
					func (db *DB) Distinct(args ...interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
				
			|||||||
@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil)
 | 
				
			|||||||
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
 | 
					func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		buildNames    []string
 | 
							buildNames    []string
 | 
				
			||||||
		buildNamesMap = map[string]bool{}
 | 
							buildNamesSet = map[string]struct{}{}
 | 
				
			||||||
		user, _       = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 | 
							user, _       = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 | 
				
			||||||
		stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
 | 
							stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, c := range clauses {
 | 
						for _, c := range clauses {
 | 
				
			||||||
		if _, ok := buildNamesMap[c.Name()]; !ok {
 | 
							if _, ok := buildNamesSet[c.Name()]; !ok {
 | 
				
			||||||
			buildNames = append(buildNames, c.Name())
 | 
								buildNames = append(buildNames, c.Name())
 | 
				
			||||||
			buildNamesMap[c.Name()] = true
 | 
								buildNamesSet[c.Name()] = struct{}{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		stmt.AddClause(c)
 | 
							stmt.AddClause(c)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,3 +1,88 @@
 | 
				
			|||||||
package clause
 | 
					package clause
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type With struct{}
 | 
					import "strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// With Common Table Expressions
 | 
				
			||||||
 | 
					type With struct {
 | 
				
			||||||
 | 
						Recursive  bool
 | 
				
			||||||
 | 
						Exprs      []Expression
 | 
				
			||||||
 | 
						Expression Expression
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Name with clause name
 | 
				
			||||||
 | 
					func (with With) Name() string {
 | 
				
			||||||
 | 
						return "WITH"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Build build with clause
 | 
				
			||||||
 | 
					func (with With) Build(builder Builder) {
 | 
				
			||||||
 | 
						if with.Expression != nil {
 | 
				
			||||||
 | 
							with.Expression.Build(builder)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(with.Exprs) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if with.Recursive {
 | 
				
			||||||
 | 
							builder.WriteString("RECURSIVE ")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for idx, expr := range with.Exprs {
 | 
				
			||||||
 | 
							if idx != 0 {
 | 
				
			||||||
 | 
								builder.WriteByte(',')
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							expr.Build(builder)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MergeClause merge with clauses
 | 
				
			||||||
 | 
					func (with With) MergeClause(clause *Clause) {
 | 
				
			||||||
 | 
						if w, ok := clause.Expression.(With); ok {
 | 
				
			||||||
 | 
							if !with.Recursive {
 | 
				
			||||||
 | 
								with.Recursive = w.Recursive
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if w.Expression != nil {
 | 
				
			||||||
 | 
								with.Expression = w.Expression
 | 
				
			||||||
 | 
								with.Exprs = nil
 | 
				
			||||||
 | 
							} else if with.Expression == nil {
 | 
				
			||||||
 | 
								exprs := make([]Expression, len(w.Exprs)+len(with.Exprs))
 | 
				
			||||||
 | 
								copy(exprs, w.Exprs)
 | 
				
			||||||
 | 
								copy(exprs[len(w.Exprs):], with.Exprs)
 | 
				
			||||||
 | 
								with.Exprs = exprs
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						clause.Expression = with
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WithExpression with expression
 | 
				
			||||||
 | 
					type WithExpression struct {
 | 
				
			||||||
 | 
						Name    string
 | 
				
			||||||
 | 
						Columns []string
 | 
				
			||||||
 | 
						Expr    Expression
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (with WithExpression) Build(builder Builder) {
 | 
				
			||||||
 | 
						if with.Name == "" || with.Expr == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						builder.WriteQuoted(with.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(with.Columns) > 0 {
 | 
				
			||||||
 | 
							builder.WriteByte(' ')
 | 
				
			||||||
 | 
							builder.WriteByte('(')
 | 
				
			||||||
 | 
							for idx, column := range with.Columns {
 | 
				
			||||||
 | 
								if idx != 0 {
 | 
				
			||||||
 | 
									builder.WriteByte(',')
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								column = strings.TrimSpace(column)
 | 
				
			||||||
 | 
								builder.WriteQuoted(column)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							builder.WriteByte(')')
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						builder.WriteString(" AS ")
 | 
				
			||||||
 | 
						builder.WriteByte('(')
 | 
				
			||||||
 | 
						with.Expr.Build(builder)
 | 
				
			||||||
 | 
						builder.WriteByte(')')
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										104
									
								
								clause/with_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								clause/with_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
				
			|||||||
 | 
					package clause_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWith(t *testing.T) {
 | 
				
			||||||
 | 
						results := []struct {
 | 
				
			||||||
 | 
							Clauses []clause.Interface
 | 
				
			||||||
 | 
							Result  string
 | 
				
			||||||
 | 
							Vars    []interface{}
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{clause.With{Exprs: []clause.Expression{
 | 
				
			||||||
 | 
									clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
 | 
				
			||||||
 | 
									clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}},
 | 
				
			||||||
 | 
								}}},
 | 
				
			||||||
 | 
								"WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Recursive: true,
 | 
				
			||||||
 | 
										Exprs:     []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Recursive: false,
 | 
				
			||||||
 | 
										Exprs:     []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								"WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								"WITH `cte1` AS (SELECT 1)",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									clause.With{
 | 
				
			||||||
 | 
										Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								"WITH `cte2` AS (SELECT 2)",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{clause.With{
 | 
				
			||||||
 | 
									Exprs: []clause.Expression{
 | 
				
			||||||
 | 
										clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"},
 | 
				
			||||||
 | 
								}},
 | 
				
			||||||
 | 
								"WITH cte2 AS (SELECT 1)",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{clause.With{
 | 
				
			||||||
 | 
									Recursive:  false,
 | 
				
			||||||
 | 
									Exprs:      nil,
 | 
				
			||||||
 | 
									Expression: nil,
 | 
				
			||||||
 | 
								}},
 | 
				
			||||||
 | 
								"WITH",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
 | 
				
			||||||
 | 
									Name: "",
 | 
				
			||||||
 | 
									Expr: clause.Expr{SQL: "SELECT 1"},
 | 
				
			||||||
 | 
								}}}},
 | 
				
			||||||
 | 
								"WITH",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
 | 
				
			||||||
 | 
									Name: "cte1",
 | 
				
			||||||
 | 
									Expr: nil,
 | 
				
			||||||
 | 
								}}}},
 | 
				
			||||||
 | 
								"WITH",
 | 
				
			||||||
 | 
								nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for idx, result := range results {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
 | 
				
			||||||
 | 
								checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user