feat: conditional.go query, production friendly

This commit is contained in:
Aitimate 2023-04-27 23:36:59 +08:00
parent aeb298635b
commit 4eb10914ab
4 changed files with 722 additions and 0 deletions

319
conditional/conditional.go Normal file
View File

@ -0,0 +1,319 @@
package conditional
import (
"fmt"
"gorm.io/gorm"
"reflect"
"regexp"
"strings"
"unicode"
"unsafe"
)
const GeneralSumKey = "#sum"
var isInjectionReg = regexp.MustCompile("[^a-zA-Z_]+")
type BaseCondition struct {
Page *int `json:"page" form:"page"` // zh: 页码
Pagesize *int `json:"pagesize" form:"pagesize"` // zh: 每页大小
OrderKey *string `json:"orderKey" form:"orderKey"` // zh: 排序字段名
Desc *bool `json:"desc" form:"desc"` // zh: 是否降序
}
type GeneralResult struct {
Total int64 `json:"total"`
List []map[string]interface{} `json:"list"`
Sum map[string]interface{} `json:"sum"`
}
type wrapDB struct {
Db *gorm.DB
maxPagesize int
includeEmptyString int
Page *int `json:"page" form:"page"` // zh: 页码
Pagesize *int `json:"pagesize" form:"pagesize"` // zh: 每页大小
OrderKey *string `json:"orderKey" form:"orderKey"` // zh: 排序字段名
Desc *bool `json:"desc" form:"desc"` // zh: 是否降序
}
func camelCaseToUnderscore(s string) string {
var output []rune
output = append(output, unicode.ToLower(rune(s[0])))
for i := 1; i < len(s); i++ {
if unicode.IsUpper(rune(s[i])) {
output = append(output, '_')
}
output = append(output, unicode.ToLower(rune(s[i])))
}
return string(output)
}
func underscoreToUpperCamelCase(s string) string {
var output []rune
for i, f := 0, false; i < len(s); i++ {
if s[i] == '_' {
f = true
continue
}
if f {
f = false
output = append(output, unicode.ToUpper(rune(s[i])))
} else {
output = append(output, rune(s[i]))
}
}
return string(output)
}
func (w *wrapDB) doOrder() *wrapDB {
if w.OrderKey != nil {
if orderKey := *w.OrderKey; orderKey != "" && !isInjectionReg.MatchString(orderKey) {
if w.Desc != nil && *w.Desc {
orderKey += " desc"
}
w.Db.Order(orderKey)
}
}
return w
}
func (w *wrapDB) doPage() *wrapDB {
if w.Pagesize == nil || *w.Pagesize > w.maxPagesize {
w.Pagesize = &w.maxPagesize
}
if *w.Pagesize > 0 {
w.Db.Limit(*w.Pagesize)
}
if w.Page != nil {
w.Db.Offset(*w.Pagesize * (*w.Page - 1))
}
return w
}
func (w *wrapDB) doWhere(key string, val interface{}) *wrapDB {
if len(key) == 0 || strings.HasPrefix(key, "#") {
return w
}
if w.includeEmptyString < 1 {
if ref := reflect.ValueOf(val); ref.Kind() == reflect.String && ref.String() == "" {
return w
}
}
db := w.Db
if key = camelCaseToUnderscore(key); len(key) > 3 {
pre := key[:3]
switch pre {
case "neq":
if key[3] == '_' {
db.Where(fmt.Sprintf("`%s` <> ?", key[4:]), val)
}
case "gt_":
db.Where(fmt.Sprintf("`%s` >= ?", key[3:]), val)
case "lt_":
db.Where(fmt.Sprintf("`%s` <= ?", key[3:]), val)
case "in_":
db.Where(fmt.Sprintf("`%s` in ?", key[3:]), val)
case "nin":
if key[3] == '_' {
db.Where(fmt.Sprintf("`%s` not in ?", key[4:]), val)
}
case "lik":
if strings.HasPrefix(key, "like_") {
db.Where(fmt.Sprintf("`%s` like ?", key[5:]), val)
}
case "nli":
if strings.HasPrefix(key, "nlike_") {
db.Where(fmt.Sprintf("`%s` not like ?", key[6:]), val)
}
case "pag":
if key == "page" {
var page int
ref := reflect.ValueOf(val)
if ref.CanFloat() {
page = int(ref.Float())
} else if ref.CanInt() {
page = int(ref.Int())
} else {
page = int(ref.Uint())
}
w.Page = &page
} else if key == "pagesize" || key == "page_size" {
var pagesize int
ref := reflect.ValueOf(val)
if ref.CanFloat() {
pagesize = int(ref.Float())
} else if ref.CanInt() {
pagesize = int(ref.Int())
} else {
pagesize = int(ref.Uint())
}
w.Pagesize = &pagesize
}
case "ord":
if key == "order_key" {
v := camelCaseToUnderscore(val.(string))
if strings.HasPrefix(v, "desc_") {
n, b := v[5:], true
w.OrderKey, w.Desc = &n, &b
} else {
var n string
if strings.HasPrefix(v, "asc_") {
n = v[4:]
} else {
n = v
}
w.OrderKey = &n
}
}
case "eq_":
db.Where(fmt.Sprintf("`%s` = ?", key[3:]), val)
default:
db.Where(fmt.Sprintf("`%s` = ?", key), val)
}
} else {
db.Where(fmt.Sprintf("`%s` = ?", key), val)
}
return w
}
func (w *wrapDB) doDeepWhere(k string, v reflect.Value) *wrapDB {
kind := v.Kind()
switch kind {
case reflect.Pointer, reflect.UnsafePointer:
if !v.IsNil() {
w.doDeepWhere(k, v.Elem())
}
case reflect.Struct:
t, n := v.Type(), v.NumField()
for i := 0; i < n; i++ {
ki, vi := t.Field(i).Name, v.Field(i)
kind = vi.Kind()
switch kind {
case reflect.Pointer, reflect.UnsafePointer:
w.doDeepWhere(ki, vi.Elem())
case reflect.Struct, reflect.Map:
w.doDeepWhere("", vi)
default:
w.doWhere(ki, vi.Interface())
}
}
case reflect.Map:
if keys := v.MapKeys(); len(keys) > 0 && keys[0].Kind() == reflect.String {
for _, key := range keys {
w.doWhere(key.String(), v.MapIndex(key).Interface())
}
}
default:
w.doWhere(k, v.Interface())
}
return w
}
func QueryGeneralConditional(db *gorm.DB, search map[string]interface{}, maxPagesize, includeEmptyString int) (gr GeneralResult, err error) {
gr = GeneralResult{List: make([]map[string]interface{}, 0), Sum: make(map[string]interface{}), Total: 0}
wdb := &wrapDB{Db: db, maxPagesize: maxPagesize, includeEmptyString: includeEmptyString}
for key, val := range search {
wdb.doWhere(key, val)
}
if err = wdb.Db.Count(&gr.Total).Error; err != nil {
return
}
if gr.Total > 0 {
wdb.doOrder().doPage()
if err = wdb.Db.Scan(&gr.List).Error; err != nil {
return
}
// [Optional] Underscore field name to UpperCamelCase
if len(gr.List) > 0 {
list := make([]map[string]interface{}, 0)
for _, m := range gr.List {
nm := make(map[string]interface{})
for k, v := range m {
nm[underscoreToUpperCamelCase(k)] = v
}
list = append(list, nm)
}
gr.List = list
}
// page-1 can do sum only
if wdb.Page != nil && *wdb.Page == 1 {
var sumFields []string
if sumKeys := search[GeneralSumKey]; sumKeys != nil {
for _, sumKey := range sumKeys.([]string) {
if isInjectionReg.MatchString(sumKey) {
err = fmt.Errorf("Ilegal sumKey: %s ", sumKey)
return
}
sumFields = append(sumFields, camelCaseToUnderscore(sumKey))
}
}
if len(sumFields) > 0 {
var sb strings.Builder
for _, field := range sumFields {
sb.Write([]byte(fmt.Sprintf("sum(`%s`) as `%s`, ", field, field)))
}
if sb.Len() > 16 {
if err = db.Select(sb.String()[:sb.Len()-2]).Scan(gr.Sum).Error; err != nil {
return
}
}
}
}
}
return
}
func QueryStructConditional(db *gorm.DB, search interface{}, list interface{}, sum interface{}, total *int64, maxPagesize, includeEmptyString int) (err error) {
wdb := &wrapDB{Db: db, maxPagesize: maxPagesize, includeEmptyString: includeEmptyString}
if search != nil {
wdb.doDeepWhere("", reflect.ValueOf(search))
}
if total != nil {
if err = wdb.Db.Count(total).Error; err != nil {
return
}
}
if *total > 0 {
wdb.doOrder().doPage()
if list != nil {
if err = wdb.Db.Scan(list).Error; err != nil {
return
}
}
// page-1 can do sum only
if wdb.Page != nil && *wdb.Page == 1 && sum != nil {
sv := reflect.ValueOf(sum)
sk := sv.Kind()
switch sk {
case reflect.Pointer, reflect.UnsafePointer:
sv = reflect.ValueOf(sv.Elem().Interface())
sk = sv.Kind()
if sk == reflect.Struct {
if t, n := sv.Type(), sv.NumField(); n > 0 {
var sb strings.Builder
for i := 0; i < n; i++ {
ns := camelCaseToUnderscore(t.Field(i).Name)
sb.Write([]byte(fmt.Sprintf("sum(`%s`) as `%s`, ", ns, ns)))
}
if sb.Len() > 0 {
db.Select(sb.String()[:sb.Len()-2]).Scan(sum)
}
}
}
default:
return fmt.Errorf("[SUM] Unknow: %s , only pointer allowed", sum)
}
}
} else {
if list != nil {
// set empty array
s := (*reflect.SliceHeader)(reflect.ValueOf(list).UnsafePointer())
if s.Data == 0 {
e := make([]interface{}, 0)
s.Data = (uintptr)(unsafe.Pointer(&e))
}
}
}
return
}

View File

@ -0,0 +1,200 @@
package conditional
import (
"encoding/json"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"log"
"testing"
)
func initDB() *gorm.DB {
db, err := gorm.Open(mysql.Open("root:root@tcp(127.0.0.1:3306)/gorm_test?charset=utf8mb4&parseTime=True&loc=Local"), &gorm.Config{Logger: logger.Default.LogMode(logger.Info)})
if err != nil {
panic("failed to connect database")
}
// `id` INT unsigned NOT NULL AUTO_INCREMENT COMMENT 'user ID',
// `name` VARCHAR(64) NOT NULL COMMENT '钱包地址',
// `level` INT unsigned NOT NULL COMMENT '用户等级',
// `status` int unsigned NOT NULL DEFAULT '0' COMMENT '结算状态 0: 正常 20禁用',
// `created_at` bigint NOT NULL COMMENT '创建时间 毫秒',
// `updated_at` bigint NOT NULL COMMENT '更新时间 毫秒',
db = db.Table("user")
return db
}
func TestQueryGeneralConditionalNeq(t *testing.T) {
search := make(map[string]interface{})
search["neqId"] = uint(1)
search["neq_id"] = uint(1)
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalEq(t *testing.T) {
search := make(map[string]interface{})
search["id"] = uint(1)
search["eq_id"] = uint(1)
search["eqId"] = uint(1)
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalLt(t *testing.T) {
search := make(map[string]interface{})
search["ltId"] = uint(2)
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalGt(t *testing.T) {
search := make(map[string]interface{})
search["gtId"] = uint(2)
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalIn(t *testing.T) {
search := make(map[string]interface{})
in := make([]uint, 0)
in = append(in, 1)
in = append(in, 2)
in = append(in, 3)
search["inId"] = in
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalNin(t *testing.T) {
search := make(map[string]interface{})
in := make([]uint, 0)
in = append(in, 1)
in = append(in, 2)
in = append(in, 3)
search["ninId"] = in
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalLike(t *testing.T) {
search := make(map[string]interface{})
search["likeName"] = "_oo"
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalNlike(t *testing.T) {
search := make(map[string]interface{})
search["likeName"] = "f%"
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalPage(t *testing.T) {
search := make(map[string]interface{})
search["page"] = 2
search["pagesize"] = 2
gr, err := QueryGeneralConditional(initDB(), search, 10, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalOrder(t *testing.T) {
search := make(map[string]interface{})
//search["orderKey"] = "descId"
search["orderKey"] = "Id"
search["orderKey"] = "ascId"
//search["orderKey"] = "descId or 1"
gr, err := QueryGeneralConditional(initDB(), search, 2, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalPage1Sum(t *testing.T) {
search := make(map[string]interface{})
search["page"] = 1
search[GeneralSumKey] = []string{"level"}
gr, err := QueryGeneralConditional(initDB(), search, 2, 1)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalNotAllowEmptyString(t *testing.T) {
search := make(map[string]interface{})
search["likeName"] = ""
gr, err := QueryGeneralConditional(initDB(), search, 2, 0)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
search["likeName"] = "f%"
gr, err = QueryGeneralConditional(initDB(), search, 2, 0)
if err != nil {
log.Println(err)
}
marshal, _ = json.Marshal(gr)
log.Println(string(marshal))
}
func TestQueryGeneralConditionalMaxCount(t *testing.T) {
// [unsafe] Unrestricted mode : <=0
search := make(map[string]interface{})
gr, err := QueryGeneralConditional(initDB(), search, -1, 0)
if err != nil {
log.Println(err)
}
marshal, _ := json.Marshal(gr)
log.Println(string(marshal))
// [safe] > 0
search = make(map[string]interface{})
gr, err = QueryGeneralConditional(initDB(), search, 1, 0)
if err != nil {
log.Println(err)
}
marshal, _ = json.Marshal(gr)
log.Println(string(marshal))
}

View File

@ -0,0 +1,178 @@
package conditional
import (
"log"
"testing"
)
type User struct {
Id uint `json:"id"`
Name string `json:"name"`
Level uint `json:"level"`
Status uint `json:"status"`
CreatedAt uint `json:"createdAt"`
UpdatedAt uint `json:"updatedAt"`
}
func TestQueryStructConditionalNeq(t *testing.T) {
search := struct {
NeqId uint
}{NeqId: 1}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalEq(t *testing.T) {
//search := struct {
// Id uint
//}{Id: 1}
search := struct {
EqId uint
}{EqId: 1}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalLt(t *testing.T) {
search := struct {
LtId uint
}{LtId: 1}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalGt(t *testing.T) {
search := struct {
GtId uint
}{GtId: 1}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalIn(t *testing.T) {
search := struct {
InId []uint
}{InId: []uint{1, 2}}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalNin(t *testing.T) {
search := struct {
NinId []uint
}{NinId: []uint{1, 2}}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalLike(t *testing.T) {
search := struct {
LikeName string
}{LikeName: "f%"}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalNlike(t *testing.T) {
search := struct {
NlikeName string
}{NlikeName: "f%"}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalPage(t *testing.T) {
search := struct {
Page int
Pagesize int
}{Page: 2, Pagesize: 2}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalOrder(t *testing.T) {
search := struct {
OrderKey string
}{OrderKey: "descId"}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalPage1Sum(t *testing.T) {
search := struct {
Page int
}{Page: 1}
sum := new(struct {
Level uint
})
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, sum, total, 10, 1)
if err != nil {
log.Println(err)
}
log.Println(sum, list, *total)
}
func TestQueryStructConditionalNotAllowEmptyString(t *testing.T) {
search := struct {
Name string
}{Name: ""}
//search := struct {
// Name string
//}{Name: "foo"}
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), search, list, nil, total, 10, 0)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}
func TestQueryStructConditionalMaxCount(t *testing.T) {
list, total := new([]User), new(int64)
err := QueryStructConditional(initDB(), nil, list, nil, total, 2, 0)
if err != nil {
log.Println(err)
}
log.Println(list, *total)
}

25
conditional/example.sql Normal file
View File

@ -0,0 +1,25 @@
CREATE SCHEMA IF NOT EXISTS gorm_test;
USE gorm_test;
DROP TABLE IF EXISTS `user`;
CREATE TABLE IF NOT EXISTS `user`
(
`id` INT unsigned NOT NULL AUTO_INCREMENT COMMENT 'user ID',
`name` VARCHAR(64) NOT NULL COMMENT '钱包地址',
`level` INT unsigned NOT NULL COMMENT '用户等级',
`status` int unsigned NOT NULL DEFAULT '0' COMMENT '结算状态 0: 正常 20禁用',
`created_at` bigint NOT NULL COMMENT '创建时间 毫秒',
`updated_at` bigint NOT NULL COMMENT '更新时间 毫秒',
PRIMARY KEY (`id`),
key key_created_at (created_at)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb3 COMMENT ='user';
INSERT INTO gorm_test.user (name, level, status, created_at, updated_at)
VALUES
('Boo', 1, 1, 1682597017126, 1682597017126),
('Foo', 2, 2, 1682697017126, 1682697017126),
('Hoo', 3, 3, 1682797017126, 1682797017126),
('Ioo', 4, 4, 1682897017126, 1682897017126),
('Joo', 5, 5, 1682997017126, 1682997017126),
('Koo', 6, 6, 1683097017126, 1683097017126);