Add with clause support
This commit is contained in:
		
							parent
							
								
									61b4c31236
								
							
						
					
					
						commit
						825bbf35ec
					
				@ -6,7 +6,7 @@ import (
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
 | 
			
		||||
	queryClauses  = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
 | 
			
		||||
	queryClauses  = []string{"WITH", "SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
 | 
			
		||||
	updateClauses = []string{"UPDATE", "SET", "WHERE"}
 | 
			
		||||
	deleteClauses = []string{"DELETE", "FROM", "WHERE"}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										104
									
								
								chainable_api.go
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								chainable_api.go
									
									
									
									
									
								
							@ -1,6 +1,7 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -61,6 +62,109 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Recursive specify recursive common table expressions
 | 
			
		||||
func (db *DB) Recursive(name string, args ...interface{}) (tx *DB) {
 | 
			
		||||
	return db.with(name, true, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// With specify common table expressions
 | 
			
		||||
//
 | 
			
		||||
//	db.With("cte", "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.With("cte", gorm.Expr("SELECT id, `name` FROM users"))
 | 
			
		||||
//	db.With("cte", db.Select("id", "name").Table("users"))
 | 
			
		||||
//	db.With("cte", "id", "name", "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.With("cte", "id, name", "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.With("cte", []string{"id", "name"}, "SELECT id, `name` FROM users")
 | 
			
		||||
func (db *DB) With(name string, args ...interface{}) (tx *DB) {
 | 
			
		||||
	return db.with(name, false, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// with clause
 | 
			
		||||
//
 | 
			
		||||
//	db.with("cte", false, "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.with("cte", false, gorm.Expr("SELECT id, `name` FROM users"))
 | 
			
		||||
//	db.with("cte", false, db.Select("id", "name").Table("users"))
 | 
			
		||||
//	db.with("cte", false, "id", "name", "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.with("cte", false, "id, name", "SELECT id, `name` FROM users")
 | 
			
		||||
//	db.with("cte", false, []string{"id", "name"}, "SELECT id, `name` FROM users")
 | 
			
		||||
func (db *DB) with(name string, recursive bool, args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
 | 
			
		||||
	if len(args) == 0 {
 | 
			
		||||
		tx.AddError(errors.New("with clause must have subquery"))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// parse optional columns
 | 
			
		||||
	columns := make([]string, 0, len(args)-1)
 | 
			
		||||
	switch arg := args[0].(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		switch len(args) {
 | 
			
		||||
		case 1:
 | 
			
		||||
			// args[0] is subquery
 | 
			
		||||
			break
 | 
			
		||||
		case 2:
 | 
			
		||||
			// args[0] is single column or comma-separated string
 | 
			
		||||
			columns = strings.FieldsFunc(arg, func(r rune) bool {
 | 
			
		||||
				return r == ',' || r == ' '
 | 
			
		||||
			})
 | 
			
		||||
		default:
 | 
			
		||||
			// args[:len(args)-1] is []string
 | 
			
		||||
			for idx := 0; idx < len(args)-1; idx++ {
 | 
			
		||||
				col, ok := args[idx].(string)
 | 
			
		||||
				if !ok {
 | 
			
		||||
					tx.AddError(fmt.Errorf("unsupported with column type %T", args[idx]))
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				if strings.Contains(col, ",") {
 | 
			
		||||
					tx.AddError(fmt.Errorf("unsupported mixed string type"))
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				columns = append(columns, col)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case []string:
 | 
			
		||||
		for _, col := range arg {
 | 
			
		||||
			if strings.Contains(col, ",") {
 | 
			
		||||
				tx.AddError(fmt.Errorf("unsupported mixed string type"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		columns = arg
 | 
			
		||||
	default:
 | 
			
		||||
		// specify optional columns field with a wrong type
 | 
			
		||||
		if len(args) > 1 {
 | 
			
		||||
			tx.AddError(fmt.Errorf("unsupported with column type %T", arg))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// parse subquery
 | 
			
		||||
	var subquery clause.Expression
 | 
			
		||||
	switch arg := args[len(args)-1].(type) {
 | 
			
		||||
	case clause.Expression:
 | 
			
		||||
		subquery = arg
 | 
			
		||||
	case string:
 | 
			
		||||
		subquery = clause.Expr{SQL: arg}
 | 
			
		||||
	case *DB:
 | 
			
		||||
		subquery = clause.Expr{SQL: "?", Vars: []interface{}{arg}, WithoutParentheses: true}
 | 
			
		||||
	default:
 | 
			
		||||
		tx.AddError(fmt.Errorf("unsupported with subquery type %T", arg))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// add with clause
 | 
			
		||||
	tx.Statement.AddClause(clause.With{
 | 
			
		||||
		Recursive: recursive,
 | 
			
		||||
		Exprs: []clause.Expression{clause.WithExpression{
 | 
			
		||||
			Name:    name,
 | 
			
		||||
			Columns: columns,
 | 
			
		||||
			Expr:    subquery,
 | 
			
		||||
		}},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Distinct specify distinct fields that you want querying
 | 
			
		||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
 | 
			
		||||
@ -17,15 +17,15 @@ var db, _ = gorm.Open(tests.DummyDialector{}, nil)
 | 
			
		||||
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
 | 
			
		||||
	var (
 | 
			
		||||
		buildNames    []string
 | 
			
		||||
		buildNamesMap = map[string]bool{}
 | 
			
		||||
		buildNamesSet = map[string]struct{}{}
 | 
			
		||||
		user, _       = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
 | 
			
		||||
		stmt          = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, c := range clauses {
 | 
			
		||||
		if _, ok := buildNamesMap[c.Name()]; !ok {
 | 
			
		||||
		if _, ok := buildNamesSet[c.Name()]; !ok {
 | 
			
		||||
			buildNames = append(buildNames, c.Name())
 | 
			
		||||
			buildNamesMap[c.Name()] = true
 | 
			
		||||
			buildNamesSet[c.Name()] = struct{}{}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		stmt.AddClause(c)
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,88 @@
 | 
			
		||||
package clause
 | 
			
		||||
 | 
			
		||||
type With struct{}
 | 
			
		||||
import "strings"
 | 
			
		||||
 | 
			
		||||
// With Common Table Expressions
 | 
			
		||||
type With struct {
 | 
			
		||||
	Recursive  bool
 | 
			
		||||
	Exprs      []Expression
 | 
			
		||||
	Expression Expression
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name with clause name
 | 
			
		||||
func (with With) Name() string {
 | 
			
		||||
	return "WITH"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Build build with clause
 | 
			
		||||
func (with With) Build(builder Builder) {
 | 
			
		||||
	if with.Expression != nil {
 | 
			
		||||
		with.Expression.Build(builder)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if len(with.Exprs) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if with.Recursive {
 | 
			
		||||
		builder.WriteString("RECURSIVE ")
 | 
			
		||||
	}
 | 
			
		||||
	for idx, expr := range with.Exprs {
 | 
			
		||||
		if idx != 0 {
 | 
			
		||||
			builder.WriteByte(',')
 | 
			
		||||
		}
 | 
			
		||||
		expr.Build(builder)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MergeClause merge with clauses
 | 
			
		||||
func (with With) MergeClause(clause *Clause) {
 | 
			
		||||
	if w, ok := clause.Expression.(With); ok {
 | 
			
		||||
		if !with.Recursive {
 | 
			
		||||
			with.Recursive = w.Recursive
 | 
			
		||||
		}
 | 
			
		||||
		if w.Expression != nil {
 | 
			
		||||
			with.Expression = w.Expression
 | 
			
		||||
			with.Exprs = nil
 | 
			
		||||
		} else if with.Expression == nil {
 | 
			
		||||
			exprs := make([]Expression, len(w.Exprs)+len(with.Exprs))
 | 
			
		||||
			copy(exprs, w.Exprs)
 | 
			
		||||
			copy(exprs[len(w.Exprs):], with.Exprs)
 | 
			
		||||
			with.Exprs = exprs
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	clause.Expression = with
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithExpression with expression
 | 
			
		||||
type WithExpression struct {
 | 
			
		||||
	Name    string
 | 
			
		||||
	Columns []string
 | 
			
		||||
	Expr    Expression
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (with WithExpression) Build(builder Builder) {
 | 
			
		||||
	if with.Name == "" || with.Expr == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder.WriteQuoted(with.Name)
 | 
			
		||||
 | 
			
		||||
	if len(with.Columns) > 0 {
 | 
			
		||||
		builder.WriteByte(' ')
 | 
			
		||||
		builder.WriteByte('(')
 | 
			
		||||
		for idx, column := range with.Columns {
 | 
			
		||||
			if idx != 0 {
 | 
			
		||||
				builder.WriteByte(',')
 | 
			
		||||
			}
 | 
			
		||||
			column = strings.TrimSpace(column)
 | 
			
		||||
			builder.WriteQuoted(column)
 | 
			
		||||
		}
 | 
			
		||||
		builder.WriteByte(')')
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder.WriteString(" AS ")
 | 
			
		||||
	builder.WriteByte('(')
 | 
			
		||||
	with.Expr.Build(builder)
 | 
			
		||||
	builder.WriteByte(')')
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										104
									
								
								clause/with_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								clause/with_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
			
		||||
package clause_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestWith(t *testing.T) {
 | 
			
		||||
	results := []struct {
 | 
			
		||||
		Clauses []clause.Interface
 | 
			
		||||
		Result  string
 | 
			
		||||
		Vars    []interface{}
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{clause.With{Exprs: []clause.Expression{
 | 
			
		||||
				clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
 | 
			
		||||
				clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 1"}},
 | 
			
		||||
			}}},
 | 
			
		||||
			"WITH `cte1` (`foo`,`bar`) AS (SELECT 1, 2),`cte2` AS (SELECT 1)",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Recursive: true,
 | 
			
		||||
					Exprs:     []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
 | 
			
		||||
				},
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Recursive: false,
 | 
			
		||||
					Exprs:     []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"WITH RECURSIVE `cte1` AS (SELECT 1),`cte2` AS (SELECT 2)",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Expression: clause.Expr{SQL: "`cte1` AS (SELECT 1)"},
 | 
			
		||||
				},
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Exprs: []clause.Expression{clause.WithExpression{Name: "cte2", Expr: clause.Expr{SQL: "SELECT 2"}}},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"WITH `cte1` AS (SELECT 1)",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Exprs: []clause.Expression{clause.WithExpression{Name: "cte1", Expr: clause.Expr{SQL: "SELECT 1"}}},
 | 
			
		||||
				},
 | 
			
		||||
				clause.With{
 | 
			
		||||
					Expression: clause.Expr{SQL: "`cte2` AS (SELECT 2)"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"WITH `cte2` AS (SELECT 2)",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{clause.With{
 | 
			
		||||
				Exprs: []clause.Expression{
 | 
			
		||||
					clause.WithExpression{Name: "cte1", Columns: []string{"foo", "bar"}, Expr: clause.Expr{SQL: "SELECT 1, 2"}},
 | 
			
		||||
				},
 | 
			
		||||
				Expression: clause.Expr{SQL: "cte2 AS (SELECT 1)"},
 | 
			
		||||
			}},
 | 
			
		||||
			"WITH cte2 AS (SELECT 1)",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{clause.With{
 | 
			
		||||
				Recursive:  false,
 | 
			
		||||
				Exprs:      nil,
 | 
			
		||||
				Expression: nil,
 | 
			
		||||
			}},
 | 
			
		||||
			"WITH",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
 | 
			
		||||
				Name: "",
 | 
			
		||||
				Expr: clause.Expr{SQL: "SELECT 1"},
 | 
			
		||||
			}}}},
 | 
			
		||||
			"WITH",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			[]clause.Interface{clause.With{Exprs: []clause.Expression{clause.WithExpression{
 | 
			
		||||
				Name: "cte1",
 | 
			
		||||
				Expr: nil,
 | 
			
		||||
			}}}},
 | 
			
		||||
			"WITH",
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for idx, result := range results {
 | 
			
		||||
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
 | 
			
		||||
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user