From 4eb10914ab4dc0cfff2a2fdce61bafe36d469deb Mon Sep 17 00:00:00 2001 From: Aitimate <549384715@qq.com> Date: Thu, 27 Apr 2023 23:36:59 +0800 Subject: [PATCH] feat: conditional.go query, production friendly --- conditional/conditional.go | 319 +++++++++++++++++++++++++ conditional/conditional_gen_test.go | 200 ++++++++++++++++ conditional/conditional_struct_test.go | 178 ++++++++++++++ conditional/example.sql | 25 ++ 4 files changed, 722 insertions(+) create mode 100644 conditional/conditional.go create mode 100644 conditional/conditional_gen_test.go create mode 100644 conditional/conditional_struct_test.go create mode 100644 conditional/example.sql diff --git a/conditional/conditional.go b/conditional/conditional.go new file mode 100644 index 00000000..d9ce6bd6 --- /dev/null +++ b/conditional/conditional.go @@ -0,0 +1,319 @@ +package conditional + +import ( + "fmt" + "gorm.io/gorm" + "reflect" + "regexp" + "strings" + "unicode" + "unsafe" +) + +const GeneralSumKey = "#sum" + +var isInjectionReg = regexp.MustCompile("[^a-zA-Z_]+") + +type BaseCondition struct { + Page *int `json:"page" form:"page"` // zh: 页码 + Pagesize *int `json:"pagesize" form:"pagesize"` // zh: 每页大小 + OrderKey *string `json:"orderKey" form:"orderKey"` // zh: 排序字段名 + Desc *bool `json:"desc" form:"desc"` // zh: 是否降序 +} + +type GeneralResult struct { + Total int64 `json:"total"` + List []map[string]interface{} `json:"list"` + Sum map[string]interface{} `json:"sum"` +} + +type wrapDB struct { + Db *gorm.DB + maxPagesize int + includeEmptyString int + Page *int `json:"page" form:"page"` // zh: 页码 + Pagesize *int `json:"pagesize" form:"pagesize"` // zh: 每页大小 + OrderKey *string `json:"orderKey" form:"orderKey"` // zh: 排序字段名 + Desc *bool `json:"desc" form:"desc"` // zh: 是否降序 +} + +func camelCaseToUnderscore(s string) string { + var output []rune + output = append(output, unicode.ToLower(rune(s[0]))) + for i := 1; i < len(s); i++ { + if unicode.IsUpper(rune(s[i])) { + output = append(output, '_') + } + output = append(output, unicode.ToLower(rune(s[i]))) + } + return string(output) +} + +func underscoreToUpperCamelCase(s string) string { + var output []rune + for i, f := 0, false; i < len(s); i++ { + if s[i] == '_' { + f = true + continue + } + if f { + f = false + output = append(output, unicode.ToUpper(rune(s[i]))) + } else { + output = append(output, rune(s[i])) + } + } + return string(output) +} + +func (w *wrapDB) doOrder() *wrapDB { + if w.OrderKey != nil { + if orderKey := *w.OrderKey; orderKey != "" && !isInjectionReg.MatchString(orderKey) { + if w.Desc != nil && *w.Desc { + orderKey += " desc" + } + w.Db.Order(orderKey) + } + } + return w +} + +func (w *wrapDB) doPage() *wrapDB { + if w.Pagesize == nil || *w.Pagesize > w.maxPagesize { + w.Pagesize = &w.maxPagesize + } + if *w.Pagesize > 0 { + w.Db.Limit(*w.Pagesize) + } + if w.Page != nil { + w.Db.Offset(*w.Pagesize * (*w.Page - 1)) + } + return w +} + +func (w *wrapDB) doWhere(key string, val interface{}) *wrapDB { + if len(key) == 0 || strings.HasPrefix(key, "#") { + return w + } + if w.includeEmptyString < 1 { + if ref := reflect.ValueOf(val); ref.Kind() == reflect.String && ref.String() == "" { + return w + } + } + db := w.Db + if key = camelCaseToUnderscore(key); len(key) > 3 { + pre := key[:3] + switch pre { + case "neq": + if key[3] == '_' { + db.Where(fmt.Sprintf("`%s` <> ?", key[4:]), val) + } + case "gt_": + db.Where(fmt.Sprintf("`%s` >= ?", key[3:]), val) + case "lt_": + db.Where(fmt.Sprintf("`%s` <= ?", key[3:]), val) + case "in_": + db.Where(fmt.Sprintf("`%s` in ?", key[3:]), val) + case "nin": + if key[3] == '_' { + db.Where(fmt.Sprintf("`%s` not in ?", key[4:]), val) + } + case "lik": + if strings.HasPrefix(key, "like_") { + db.Where(fmt.Sprintf("`%s` like ?", key[5:]), val) + } + case "nli": + if strings.HasPrefix(key, "nlike_") { + db.Where(fmt.Sprintf("`%s` not like ?", key[6:]), val) + } + case "pag": + if key == "page" { + var page int + ref := reflect.ValueOf(val) + if ref.CanFloat() { + page = int(ref.Float()) + } else if ref.CanInt() { + page = int(ref.Int()) + } else { + page = int(ref.Uint()) + } + w.Page = &page + } else if key == "pagesize" || key == "page_size" { + var pagesize int + ref := reflect.ValueOf(val) + if ref.CanFloat() { + pagesize = int(ref.Float()) + } else if ref.CanInt() { + pagesize = int(ref.Int()) + } else { + pagesize = int(ref.Uint()) + } + w.Pagesize = &pagesize + } + case "ord": + if key == "order_key" { + v := camelCaseToUnderscore(val.(string)) + if strings.HasPrefix(v, "desc_") { + n, b := v[5:], true + w.OrderKey, w.Desc = &n, &b + } else { + var n string + if strings.HasPrefix(v, "asc_") { + n = v[4:] + } else { + n = v + } + w.OrderKey = &n + } + } + case "eq_": + db.Where(fmt.Sprintf("`%s` = ?", key[3:]), val) + default: + db.Where(fmt.Sprintf("`%s` = ?", key), val) + } + } else { + db.Where(fmt.Sprintf("`%s` = ?", key), val) + } + return w +} + +func (w *wrapDB) doDeepWhere(k string, v reflect.Value) *wrapDB { + kind := v.Kind() + switch kind { + case reflect.Pointer, reflect.UnsafePointer: + if !v.IsNil() { + w.doDeepWhere(k, v.Elem()) + } + case reflect.Struct: + t, n := v.Type(), v.NumField() + for i := 0; i < n; i++ { + ki, vi := t.Field(i).Name, v.Field(i) + kind = vi.Kind() + switch kind { + case reflect.Pointer, reflect.UnsafePointer: + w.doDeepWhere(ki, vi.Elem()) + case reflect.Struct, reflect.Map: + w.doDeepWhere("", vi) + default: + w.doWhere(ki, vi.Interface()) + } + } + case reflect.Map: + if keys := v.MapKeys(); len(keys) > 0 && keys[0].Kind() == reflect.String { + for _, key := range keys { + w.doWhere(key.String(), v.MapIndex(key).Interface()) + } + } + default: + w.doWhere(k, v.Interface()) + } + return w +} + +func QueryGeneralConditional(db *gorm.DB, search map[string]interface{}, maxPagesize, includeEmptyString int) (gr GeneralResult, err error) { + gr = GeneralResult{List: make([]map[string]interface{}, 0), Sum: make(map[string]interface{}), Total: 0} + wdb := &wrapDB{Db: db, maxPagesize: maxPagesize, includeEmptyString: includeEmptyString} + for key, val := range search { + wdb.doWhere(key, val) + } + if err = wdb.Db.Count(&gr.Total).Error; err != nil { + return + } + if gr.Total > 0 { + wdb.doOrder().doPage() + if err = wdb.Db.Scan(&gr.List).Error; err != nil { + return + } + // [Optional] Underscore field name to UpperCamelCase + if len(gr.List) > 0 { + list := make([]map[string]interface{}, 0) + for _, m := range gr.List { + nm := make(map[string]interface{}) + for k, v := range m { + nm[underscoreToUpperCamelCase(k)] = v + } + list = append(list, nm) + } + gr.List = list + } + // page-1 can do sum only + if wdb.Page != nil && *wdb.Page == 1 { + var sumFields []string + if sumKeys := search[GeneralSumKey]; sumKeys != nil { + for _, sumKey := range sumKeys.([]string) { + if isInjectionReg.MatchString(sumKey) { + err = fmt.Errorf("Ilegal sumKey: %s ", sumKey) + return + } + sumFields = append(sumFields, camelCaseToUnderscore(sumKey)) + } + } + if len(sumFields) > 0 { + var sb strings.Builder + for _, field := range sumFields { + sb.Write([]byte(fmt.Sprintf("sum(`%s`) as `%s`, ", field, field))) + } + if sb.Len() > 16 { + if err = db.Select(sb.String()[:sb.Len()-2]).Scan(gr.Sum).Error; err != nil { + return + } + } + } + } + } + return +} + +func QueryStructConditional(db *gorm.DB, search interface{}, list interface{}, sum interface{}, total *int64, maxPagesize, includeEmptyString int) (err error) { + wdb := &wrapDB{Db: db, maxPagesize: maxPagesize, includeEmptyString: includeEmptyString} + if search != nil { + wdb.doDeepWhere("", reflect.ValueOf(search)) + } + if total != nil { + if err = wdb.Db.Count(total).Error; err != nil { + return + } + } + if *total > 0 { + wdb.doOrder().doPage() + if list != nil { + if err = wdb.Db.Scan(list).Error; err != nil { + return + } + } + // page-1 can do sum only + if wdb.Page != nil && *wdb.Page == 1 && sum != nil { + sv := reflect.ValueOf(sum) + sk := sv.Kind() + switch sk { + case reflect.Pointer, reflect.UnsafePointer: + sv = reflect.ValueOf(sv.Elem().Interface()) + sk = sv.Kind() + if sk == reflect.Struct { + if t, n := sv.Type(), sv.NumField(); n > 0 { + var sb strings.Builder + for i := 0; i < n; i++ { + ns := camelCaseToUnderscore(t.Field(i).Name) + sb.Write([]byte(fmt.Sprintf("sum(`%s`) as `%s`, ", ns, ns))) + } + if sb.Len() > 0 { + db.Select(sb.String()[:sb.Len()-2]).Scan(sum) + } + } + } + default: + return fmt.Errorf("[SUM] Unknow: %s , only pointer allowed", sum) + } + } + } else { + if list != nil { + // set empty array + s := (*reflect.SliceHeader)(reflect.ValueOf(list).UnsafePointer()) + if s.Data == 0 { + e := make([]interface{}, 0) + s.Data = (uintptr)(unsafe.Pointer(&e)) + } + } + } + return +} diff --git a/conditional/conditional_gen_test.go b/conditional/conditional_gen_test.go new file mode 100644 index 00000000..61d3568a --- /dev/null +++ b/conditional/conditional_gen_test.go @@ -0,0 +1,200 @@ +package conditional + +import ( + "encoding/json" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "log" + "testing" +) + +func initDB() *gorm.DB { + db, err := gorm.Open(mysql.Open("root:root@tcp(127.0.0.1:3306)/gorm_test?charset=utf8mb4&parseTime=True&loc=Local"), &gorm.Config{Logger: logger.Default.LogMode(logger.Info)}) + if err != nil { + panic("failed to connect database") + } + // `id` INT unsigned NOT NULL AUTO_INCREMENT COMMENT 'user ID', + // `name` VARCHAR(64) NOT NULL COMMENT '钱包地址', + // `level` INT unsigned NOT NULL COMMENT '用户等级', + // `status` int unsigned NOT NULL DEFAULT '0' COMMENT '结算状态 0: 正常 20禁用', + // `created_at` bigint NOT NULL COMMENT '创建时间 毫秒', + // `updated_at` bigint NOT NULL COMMENT '更新时间 毫秒', + db = db.Table("user") + return db +} + +func TestQueryGeneralConditionalNeq(t *testing.T) { + search := make(map[string]interface{}) + search["neqId"] = uint(1) + search["neq_id"] = uint(1) + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalEq(t *testing.T) { + search := make(map[string]interface{}) + search["id"] = uint(1) + search["eq_id"] = uint(1) + search["eqId"] = uint(1) + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalLt(t *testing.T) { + search := make(map[string]interface{}) + search["ltId"] = uint(2) + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalGt(t *testing.T) { + search := make(map[string]interface{}) + search["gtId"] = uint(2) + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalIn(t *testing.T) { + search := make(map[string]interface{}) + in := make([]uint, 0) + in = append(in, 1) + in = append(in, 2) + in = append(in, 3) + search["inId"] = in + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalNin(t *testing.T) { + search := make(map[string]interface{}) + in := make([]uint, 0) + in = append(in, 1) + in = append(in, 2) + in = append(in, 3) + search["ninId"] = in + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalLike(t *testing.T) { + search := make(map[string]interface{}) + search["likeName"] = "_oo" + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalNlike(t *testing.T) { + search := make(map[string]interface{}) + search["likeName"] = "f%" + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalPage(t *testing.T) { + search := make(map[string]interface{}) + search["page"] = 2 + search["pagesize"] = 2 + gr, err := QueryGeneralConditional(initDB(), search, 10, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalOrder(t *testing.T) { + search := make(map[string]interface{}) + //search["orderKey"] = "descId" + search["orderKey"] = "Id" + search["orderKey"] = "ascId" + //search["orderKey"] = "descId or 1" + gr, err := QueryGeneralConditional(initDB(), search, 2, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalPage1Sum(t *testing.T) { + search := make(map[string]interface{}) + search["page"] = 1 + search[GeneralSumKey] = []string{"level"} + gr, err := QueryGeneralConditional(initDB(), search, 2, 1) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalNotAllowEmptyString(t *testing.T) { + search := make(map[string]interface{}) + search["likeName"] = "" + gr, err := QueryGeneralConditional(initDB(), search, 2, 0) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) + search["likeName"] = "f%" + gr, err = QueryGeneralConditional(initDB(), search, 2, 0) + if err != nil { + log.Println(err) + } + marshal, _ = json.Marshal(gr) + log.Println(string(marshal)) +} + +func TestQueryGeneralConditionalMaxCount(t *testing.T) { + // [unsafe] Unrestricted mode : <=0 + search := make(map[string]interface{}) + gr, err := QueryGeneralConditional(initDB(), search, -1, 0) + if err != nil { + log.Println(err) + } + marshal, _ := json.Marshal(gr) + log.Println(string(marshal)) + + // [safe] > 0 + search = make(map[string]interface{}) + gr, err = QueryGeneralConditional(initDB(), search, 1, 0) + if err != nil { + log.Println(err) + } + marshal, _ = json.Marshal(gr) + log.Println(string(marshal)) +} diff --git a/conditional/conditional_struct_test.go b/conditional/conditional_struct_test.go new file mode 100644 index 00000000..0a6d1ebc --- /dev/null +++ b/conditional/conditional_struct_test.go @@ -0,0 +1,178 @@ +package conditional + +import ( + "log" + "testing" +) + +type User struct { + Id uint `json:"id"` + Name string `json:"name"` + Level uint `json:"level"` + Status uint `json:"status"` + CreatedAt uint `json:"createdAt"` + UpdatedAt uint `json:"updatedAt"` +} + +func TestQueryStructConditionalNeq(t *testing.T) { + search := struct { + NeqId uint + }{NeqId: 1} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalEq(t *testing.T) { + //search := struct { + // Id uint + //}{Id: 1} + search := struct { + EqId uint + }{EqId: 1} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalLt(t *testing.T) { + search := struct { + LtId uint + }{LtId: 1} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalGt(t *testing.T) { + search := struct { + GtId uint + }{GtId: 1} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalIn(t *testing.T) { + search := struct { + InId []uint + }{InId: []uint{1, 2}} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalNin(t *testing.T) { + search := struct { + NinId []uint + }{NinId: []uint{1, 2}} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalLike(t *testing.T) { + search := struct { + LikeName string + }{LikeName: "f%"} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalNlike(t *testing.T) { + search := struct { + NlikeName string + }{NlikeName: "f%"} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalPage(t *testing.T) { + search := struct { + Page int + Pagesize int + }{Page: 2, Pagesize: 2} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalOrder(t *testing.T) { + search := struct { + OrderKey string + }{OrderKey: "descId"} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalPage1Sum(t *testing.T) { + search := struct { + Page int + }{Page: 1} + sum := new(struct { + Level uint + }) + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, sum, total, 10, 1) + if err != nil { + log.Println(err) + } + log.Println(sum, list, *total) +} + +func TestQueryStructConditionalNotAllowEmptyString(t *testing.T) { + search := struct { + Name string + }{Name: ""} + //search := struct { + // Name string + //}{Name: "foo"} + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), search, list, nil, total, 10, 0) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} + +func TestQueryStructConditionalMaxCount(t *testing.T) { + list, total := new([]User), new(int64) + err := QueryStructConditional(initDB(), nil, list, nil, total, 2, 0) + if err != nil { + log.Println(err) + } + log.Println(list, *total) +} diff --git a/conditional/example.sql b/conditional/example.sql new file mode 100644 index 00000000..81eafcba --- /dev/null +++ b/conditional/example.sql @@ -0,0 +1,25 @@ +CREATE SCHEMA IF NOT EXISTS gorm_test; +USE gorm_test; + +DROP TABLE IF EXISTS `user`; +CREATE TABLE IF NOT EXISTS `user` +( + `id` INT unsigned NOT NULL AUTO_INCREMENT COMMENT 'user ID', + `name` VARCHAR(64) NOT NULL COMMENT '钱包地址', + `level` INT unsigned NOT NULL COMMENT '用户等级', + `status` int unsigned NOT NULL DEFAULT '0' COMMENT '结算状态 0: 正常 20禁用', + `created_at` bigint NOT NULL COMMENT '创建时间 毫秒', + `updated_at` bigint NOT NULL COMMENT '更新时间 毫秒', + PRIMARY KEY (`id`), + key key_created_at (created_at) +) ENGINE = InnoDB + DEFAULT CHARSET = utf8mb3 COMMENT ='user'; + +INSERT INTO gorm_test.user (name, level, status, created_at, updated_at) +VALUES + ('Boo', 1, 1, 1682597017126, 1682597017126), + ('Foo', 2, 2, 1682697017126, 1682697017126), + ('Hoo', 3, 3, 1682797017126, 1682797017126), + ('Ioo', 4, 4, 1682897017126, 1682897017126), + ('Joo', 5, 5, 1682997017126, 1682997017126), + ('Koo', 6, 6, 1683097017126, 1683097017126);