From 820dc0aa89a377c9bb86022e0d22993d66d29ee3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E6=98=8E?= Date: Wed, 28 May 2025 14:18:25 +0800 Subject: [PATCH] add apaas code --- apaas/checker.go | 92 ++++++++++++++ apaas/collection.go | 53 ++++++++ apaas/dailer.go | 1 + apaas/error.go | 9 ++ apaas/fetcher.go | 94 ++++++++++++++ apaas/idl.go | 29 +++++ apaas/reflect.go | 150 ++++++++++++++++++++++ apaas/types.go | 235 +++++++++++++++++++++++++++++++++++ apaas/view.go | 1 + apaas_mode.go | 18 +++ callbacks.go | 16 ++- callbacks/apaas_callbacks.go | 79 ++++++++++++ callbacks/callbacks.go | 10 ++ finisher_api.go | 1 + go.mod | 4 +- go.sum | 4 +- gorm.go | 30 +++++ logger/logger.go | 2 +- schema/field.go | 6 +- schema/schema.go | 5 + statement.go | 5 + 21 files changed, 838 insertions(+), 6 deletions(-) create mode 100644 apaas/checker.go create mode 100644 apaas/collection.go create mode 100644 apaas/dailer.go create mode 100644 apaas/error.go create mode 100644 apaas/fetcher.go create mode 100644 apaas/idl.go create mode 100644 apaas/reflect.go create mode 100644 apaas/types.go create mode 100644 apaas/view.go create mode 100644 apaas_mode.go create mode 100644 callbacks/apaas_callbacks.go diff --git a/apaas/checker.go b/apaas/checker.go new file mode 100644 index 00000000..33c7e09a --- /dev/null +++ b/apaas/checker.go @@ -0,0 +1,92 @@ +package apaas + +import ( + "encoding/json" + "fmt" +) + +type Checker interface { + Check(string) error +} + +func ExtraCheck(rule map[string]*ExtraFieldMeta, extra string) error { + var data map[string]any + err := json.Unmarshal([]byte(extra), &data) + if err != nil { + return err + } + for k, v := range data { + r, ok := rule[k] + if !ok || r == nil { + continue + } + err = FieldCheck(r, v) + if err != nil { + return GenError(err.Error()) + } + } + return nil +} + +func FieldCheck(rule *ExtraFieldMeta, value any) error { + if rule == nil || value == nil { + return nil + } + var err error + switch val := value.(type) { + case map[string]any: + if len(rule.ObjectMeta) == 0 { + return nil + } + for k, v := range val { + r, ok := rule.ObjectMeta[k] + if !ok || r == nil { + continue + } + err = FieldCheck(r, v) + if err != nil { + return fmt.Errorf("rule(key=%s, type=%s), value(type=map[%s], suberror=%s)", + rule.Key, FieldMapString[rule.Type], k, err.Error()) + } + } + case []any: + if rule.Type != FieldArray { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type=array, value=%#v)", + rule.Key, FieldMapString[rule.Type], val) + } + if len(rule.ArrayMeta) != 0 { + for i, v := range val { + err = FieldCheck(rule.ArrayMeta[i], v) + if err != nil { + return fmt.Errorf("rule(key=%s, type=%s), value(type array, [%d] element suberror=%s)", + rule.Key, FieldMapString[rule.Type], i, err.Error()) + } + } + } + case string: + if rule.Type != FieldString { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type=string, value=%#v)", + rule.Key, FieldMapString[rule.Type], val) + } + case float64: + if rule.Type != FieldInt && rule.Type != FieldFloat64 { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type=float/int, value=%#v)", + rule.Key, FieldMapString[rule.Type], val) + } + if rule.Type == FieldInt && float64(int(val)) != val { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type=float, value=%#v)", + rule.Key, FieldMapString[rule.Type], val) + } + case bool: + if rule.Type != FieldBool { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type=bool, value=%#v)", + rule.Key, FieldMapString[rule.Type]) + } + case nil: + if rule.Type != FieldArray || rule.Type != FieldObject { + return fmt.Errorf("rule(key=%s, type=%s) dismatch value(type nil)", + rule.Key, FieldMapString[rule.Type]) + } + } + return nil +} diff --git a/apaas/collection.go b/apaas/collection.go new file mode 100644 index 00000000..92196ab2 --- /dev/null +++ b/apaas/collection.go @@ -0,0 +1,53 @@ +package apaas + +import ( + "sync" + "sync/atomic" +) + +var gDBCol atomic.Value + +func init() { + dbCol := &DBCollection{ + dbs: make(map[string]*DBMeta, 128), + } + SetDBCol(dbCol) +} + +func SetDBCol(dbCol *DBCollection) { + if dbCol != nil { + gDBCol.Store(dbCol) + } +} + +func GetDBCol() *DBCollection { + dbCol, ok := gDBCol.Load().(*DBCollection) + if ok { + return dbCol + } + return nil +} + +type DBCollection struct { + dbs map[string]*DBMeta + lock sync.RWMutex +} + +func (p *DBCollection) GetDB(dbName string) (*DBMeta, bool) { + p.lock.RLock() + v, ok := p.dbs[dbName] + p.lock.RUnlock() + return v, ok +} + +func (p *DBCollection) SetDB(dbName string, dbMeta *DBMeta) { + p.lock.Lock() + p.dbs[dbName] = dbMeta + p.lock.Unlock() +} + +func (p *DBCollection) DeleteDB(dbName string) { + p.lock.Lock() + delete(p.dbs, dbName) + p.lock.Unlock() +} diff --git a/apaas/dailer.go b/apaas/dailer.go new file mode 100644 index 00000000..3b13e103 --- /dev/null +++ b/apaas/dailer.go @@ -0,0 +1 @@ +package apaas diff --git a/apaas/error.go b/apaas/error.go new file mode 100644 index 00000000..452dc1e6 --- /dev/null +++ b/apaas/error.go @@ -0,0 +1,9 @@ +package apaas + +import "fmt" + +const MSG_PREFIX = "[apaas_engine]" + +func GenError(msg string) error { + return fmt.Errorf("%s %s", MSG_PREFIX, msg) +} diff --git a/apaas/fetcher.go b/apaas/fetcher.go new file mode 100644 index 00000000..a3b22f9b --- /dev/null +++ b/apaas/fetcher.go @@ -0,0 +1,94 @@ +package apaas + +import ( + "context" + + "gorm.io/gorm/logger" +) + +type DBFetcher interface { + Fetch() ([]*ApaasTable, error) + // each fetcher has only uniq name, Name is must DBName + // if DBName is nil, default fetch all + DBName() string +} + +var gAllFetcher = map[string]DBFetcher{} +var gDeltaFetcher = map[string]DBFetcher{} + +func AddDBFetcher(f DBFetcher) { + gAllFetcher[f.DBName()] = f +} + +// Update All DB metas, called by apaas_db_engine SDK +func UpdateAllDBCol() { + for name, f := range gAllFetcher { + if f.DBName() == "" { + tables, err := f.Fetch() + if err != nil { + logger.Default.Error(context.Background(), "%s fetcher(name=%s, DBName=%s) Fetch data error=%s", MSG_PREFIX, name, f.DBName(), err.Error()) + continue + } + logger.Default.Info(context.Background(), "%s fetcher(name=%s, DBName=%s) Fetch data len=%d", MSG_PREFIX, name, f.DBName(), len(tables)) + + dbCol := &DBCollection{ + dbs: make(map[string]*DBMeta, 128), + } + for _, table := range tables { + v, ok := dbCol.dbs[table.DBName] + if !ok { + v := &DBMeta{ + tableList: make([]*ApaasTable, 0, len(tables)>>2), + tableView: make(map[string]*ApaasTable, len(tables)>>2), + lookupIDView: make(map[string]*ApaasTable, len(tables)>>2), + } + dbCol.dbs[table.DBName] = v + } + v.tableList = append(v.tableList, table) + v.tableView[table.TableName] = table + if table.LookupIDField != nil { + v.lookupIDView[table.LookupIDField.Name] = table + } + } + SetDBCol(dbCol) + } + } + for name, f := range gAllFetcher { + if f.DBName() == "" { + continue + } + tables, err := f.Fetch() + if err != nil { + logger.Default.Error(context.Background(), "%s fetcher(name=%s, DBName=%s) Fetch data error=%s", MSG_PREFIX, name, f.DBName(), err.Error()) + continue + } + logger.Default.Info(context.Background(), "%s fetcher(name=%s, DBName=%s) Fetch data len=%d", MSG_PREFIX, name, f.DBName(), len(tables)) + + if len(tables) == 0 { + continue + } + if len(f.DBName()) != 0 { + dbMeta := &DBMeta{ + tableList: tables, + } + dbMeta.tableView = make(map[string]*ApaasTable, len(tables)) + dbMeta.lookupIDView = make(map[string]*ApaasTable, len(tables)) + for _, table := range tables { + dbMeta.tableView[table.TableName] = table + if table.LookupIDField != nil { + dbMeta.lookupIDView[table.LookupIDField.Name] = table + } + } + GetDBCol().SetDB(f.DBName(), dbMeta) + continue + } + + } +} + +/* +// Update DB metas, called by apaas_db_engine SDK +func UpdateDeltaDBCol() { + +} +*/ diff --git a/apaas/idl.go b/apaas/idl.go new file mode 100644 index 00000000..e60761eb --- /dev/null +++ b/apaas/idl.go @@ -0,0 +1,29 @@ +package apaas + +type ApaasQueryType int8 + +const ( + _skipQueryType ApaasQueryType = iota + SelectType + CreateType + InsertType + UpdateType + DeleteType + RawType +) + +type ApaasQueryArgs struct { + Select any + From any + Where any + Join any + Group any + Order any + Limit any + Offset any + Update any + Delete any + create any + RawSql string + QueryType ApaasQueryType +} diff --git a/apaas/reflect.go b/apaas/reflect.go new file mode 100644 index 00000000..b1a2c61e --- /dev/null +++ b/apaas/reflect.go @@ -0,0 +1,150 @@ +package apaas + +import ( + "fmt" + "reflect" + "strings" +) + +var _unknown_field = reflect.StructField{} + +func GetFieldTypeByColumnNameV2(value reflect.Value, fieldColumnName string) (reflect.StructField, bool) { + // check obj is pointer or not + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + typ := value.Type() + for i := 0; i < value.NumField(); i++ { + field := typ.Field(i) + fieldValue := value.Field(i) + if field.Anonymous { + return GetFieldTypeByColumnNameV2(fieldValue, fieldColumnName) + } + tag := field.Tag.Get("gorm") + if tag == "" { + continue + } + if cname, ok := getColumnNameByColumnTag(tag); ok && cname == fieldColumnName { + return field, ok + } + } + return _unknown_field, false +} + +/* +case: + + type Faction struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` + LiveID int64 `gorm:"column:live_id" json:"live_id"` + BizID int64 `gorm:"column:biz_id" json:"biz_id"` + FactionID int64 `gorm:"column:faction_id;not null" json:"faction_id"` + FactionName string `gorm:"column:faction_name" json:"faction_name"` + OrgID int64 `gorm:"column:org_id;comment:union外键" json:"org_id" apass_engine_lookup_id:"webcast.union.org_id"` // union外键 + } + +fieldColumnName: faction_id +*/ +func GetFieldTypeByColumnName(obj any, fieldColumnName string) (reflect.StructField, bool) { + return GetFieldTypeByColumnNameV2(reflect.ValueOf(obj), fieldColumnName) +} + +/* +case: `gorm:"column:id;primaryKey;autoIncrement:true" +*/ +func getColumnNameByColumnTag(tag string) (string, bool) { + fs := strings.Split(tag, ";") + if len(fs) >= 1 { + sfs := strings.Split(fs[0], ":") + if sfs[0] == "column" { + return sfs[1], true + } + } + return "", false +} + +func GetFieldTypeByColumnNameV3(typ reflect.Type, fieldColumnName string) (reflect.StructField, bool) { + // check obj is pointer or not + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + if field.Anonymous { + return GetFieldTypeByColumnNameV3(field.Type, fieldColumnName) + } + tag := field.Tag.Get("gorm") + if tag == "" { + continue + } + if cname, ok := getColumnNameByColumnTag(tag); ok && cname == fieldColumnName { + return field, ok + } + } + return _unknown_field, false +} + +func ParseLookupTagMeta(tag, columnName, dbName, tableName string) (*ApaasLookupMeta, error) { + fmt.Printf("[apaas_engine] tag: %s, columnName: %s, dbName: %s, tableName: %s\n", tag, columnName, dbName, tableName) + if tag == "" { + return nil, nil + } + fs := strings.Split(tag, ".") + if len(fs) > MaxTagDeep { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, lookup deep=%d>%d", tag, len(fs), MaxTagDeep)) + } + meta := &ApaasLookupMeta{ + CName: columnName, + LookupMeta: make([]*LookupMeta, len(fs)-1), + LastField: fs[len(fs)-1], + OrgTag: fs, + } + dbCol := GetDBCol() + if dbCol == nil { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, cann't get db collection", tag)) + } + dbMeta, ok := dbCol.GetDB(dbName) + if !ok || dbMeta == nil { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, cann't get db(name=%s) meta", tag, dbName)) + } + + tableMeta, ok := dbMeta.tableView[tableName] + if !ok { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, cann't get table(name=%s) meta", tag, tableName)) + } + var idx int = 0 + for idx < len(fs)-1 { + lp := &LookupMeta{ + FieldName: fs[idx], + ForeignMeta: ForeignMeta{ + DBName: dbName, + FName: fs[idx+1], + }, + } + found := false + for _, ff := range tableMeta.ForeignFields { + if ff.Name == fs[idx] { + if fs[idx+1] != ff.foreignMeta.FName { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, db=%s, table=%s, field=%s, foreign(table=%s), foreign field=%s not equal to lookup field=%s)", tag, dbName, tableName, ff.Name, ff.foreignMeta.TName, fs[idx+1])) + } + if dbName != ff.foreignMeta.DBName { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, db=%s, table=%s, field=%s, foreign(table=%s), foreign db=%s not equal to lookup db=%s)", tag, tag, dbName, tableName, ff.Name, ff.foreignMeta.DBName, dbName)) + } + lp.ForeignMeta.TName = ff.foreignMeta.TName + lp.ForeignMeta.FTMeta = ff.foreignMeta.FTMeta + found = true + break + } + } + if !found { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, db=%s, table=%s, field=%s is not foreign key", tag, tag, dbName, tableName, fs[idx])) + } + meta.LookupMeta[idx] = lp + tableName = lp.ForeignMeta.TName + tableMeta, ok = dbMeta.tableView[tableName] + if !ok { + return nil, GenError(fmt.Sprintf("apass_engine_lookup_value=%s, cann't get table(name=%s) meta", tag, tableName)) + } + } + return meta, nil +} diff --git a/apaas/types.go b/apaas/types.go new file mode 100644 index 00000000..02c05ffc --- /dev/null +++ b/apaas/types.go @@ -0,0 +1,235 @@ +package apaas + +import ( + "strings" +) + +const ( + TagID = "apass_engine_lookup_id" + TagValue = "apass_engine_lookup_value" + MaxTagDeep = 4 +) + +type ApaasFieldType uint8 +type FieldType uint8 + +const ( + _skipFieldType FieldType = iota + FieldBool + FieldInt + FieldInt64 + FieldFloat + FieldFloat64 + FieldString + FieldTime + FieldBit + FieldBin + FieldArray + FieldObject +) + +const ( + _skipApaasFieldType ApaasFieldType = iota + ApaasLookupID + ApaasLookupValue + ApaasExtraType + ApaasFormulaType +) + +var FieldMapString = []string{ + "_skipApaasFieldTyle", + "FieldBool", + "FieldInt", + "FieldInt64", + "FieldFloat", + "FieldFloat64", + "FieldString", + "FieldTime", + "FieldBit", + "FieldBin", + "FieldArray", + "FieldObject", +} + +func (p FieldType) String() string { + return FieldMapString[p] +} + +/* +description: store dynamic field info when is an instance like tag + + `apass_engine_lookup_value` + +case: + + type RoomAnchorView struct { + Room + Gender string `gorm:"column:gender" json:"gender" apass_engine_lookup_value:"anchor_id.uid.gender"` + Career string `gorm:"column:career" json:"career" apass_engine_lookup_value:"anchor_id.uid.career"` + Name string `gorm:"column:name" json:"name" apass_engine_lookup_value:"anchor_id.uid.name"` + } + + type RoomAnchorFactionView struct { + Room + Gender string `gorm:"column:gender" json:"gender" apass_engine_lookup_value:"anchor_id.uid.gender"` + Career string `gorm:"column:career" json:"career" apass_engine_lookup_value:"anchor_id.uid.career"` + Name string `gorm:"column:name" json:"name" apass_engine_lookup_value:"anchor_id.uid.name"` + OrgName string `gorm:"column:org_name" json:"org_name" apass_engine_lookup_value:"anchor_id.faction_id.faction_name"` + UnionInfo string `gorm:"column:union_info" json:"union_info" apass_engine_lookup_value:"anchor_id.faction_id.org_id.union_info"` + } +*/ +type ApaasLookupMeta struct { + // field gorm column tag name/table field's name. example: union_info + CName string + // lookup orgin meta. example: [anchor_id.faction_id.org_id.union_info] + /* + 1. anchor_id; 2. faction_id; 3: org_id; + 1. anchor.anchor_id;2. Vction.faction_id;3: union.org_id; + */ + LookupMeta []*LookupMeta + LastField string // org_name/union_name + + /* + OrgTag only set value when used in SDK mode + */ + // lookup org tag [anchor_id, faction_id, org_id, org_name] + OrgTag []string +} + +// lookup orgin meta. example: [anchor_id.faction_id.org_id.union_info] +type LookupMeta struct { + FieldName string + ForeignMeta ForeignMeta +} + +type ApaasTable struct { + TableName string + DBName string + Fields []*ApaasField + FieldsByName map[string]*ApaasField + LookupIDField *ApaasField // example: room.room_id, user.uid, faction.faction_id + ForeignFields []*ApaasField // foreign key. example: room.anchor_id, named of relookupid + FormulaFields []*ApaasField // example: union.title=update_time + org_name + LookupValueFields []*ApaasField // example: org_id.org_name, anchor_id.org_id.union_id.union_name + ExtraFields []*ApaasField // example: anchor.extra, room.extra +} + +type ApaasField struct { + Name string + Type string + FType FieldType + IsUniq bool + IsForeign bool + foreignMeta *ForeignMeta + IsApaasType bool + ApaasMeta *ApaasMeta +} + +func (p *ApaasField) parseFieldType() { + tp := strings.ToUpper(p.Type) + ftp := _skipFieldType + switch tp { + case "BOOL": + ftp = FieldBool + case "INT", "TINYINT", "SMALLINT", "MEDIUMINT": + ftp = FieldInt + case "BIGINT": + ftp = FieldInt64 + case "FLOAT": + ftp = FieldFloat + case "DOUBLE", "DECIMAL", "REAL": + ftp = FieldFloat64 + case "DATE", "DATETIME", "TIMESTAMP", "TIME", "YEAR": + ftp = FieldTime + case "VARCHAR", "CHAR", "ENUM", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "JSON": + ftp = FieldString + case "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BINARY", "VARBINARY": + ftp = FieldBin + case "BIT": + ftp = FieldBit + default: + ftp = FieldString + } + p.FType = ftp +} + +func (p *ApaasField) GetApaasMeta() *ApaasMeta { + return p.ApaasMeta +} + +type ForeignMeta struct { + DBName string + TName string + FName string + FTMeta *ApaasTable +} + +type ApaasMeta struct { + ApaasFType ApaasFieldType + ExtraMeta map[string]*ExtraFieldMeta + FormulaMeta *FormulaMeta + LookupMeta *ApaasLookupMeta + Checker +} + +func (p *ApaasMeta) IsApaasFieldType() bool { + return p.ApaasFType == _skipApaasFieldType +} +func (p *ApaasMeta) IsExtraField() bool { + return p.ApaasFType == ApaasExtraType +} +func (p *ApaasMeta) IsFormulaField() bool { + return p.ApaasFType == ApaasFormulaType +} +func (p *ApaasMeta) IsLookupID() bool { + return p.ApaasFType == ApaasLookupID +} +func (p *ApaasMeta) IsLookupValue() bool { + return p.ApaasFType == ApaasLookupValue +} + +func (p *ApaasMeta) GetApaasFieldType() ApaasFieldType { + return p.ApaasFType +} +func (p *ApaasMeta) GetExtraMeta() map[string]*ExtraFieldMeta { + return p.ExtraMeta +} +func (p *ApaasMeta) GetFormulaMeta() *FormulaMeta { + return p.FormulaMeta +} +func (p *ApaasMeta) Check(extra string) error { + // step1: extra rule check + err := ExtraCheck(p.ExtraMeta, extra) + return err +} + +type ExtraFieldMeta struct { + Key string + Type FieldType // Bool/Int/Float/String/Object/Array + ObjectMeta map[string]*ExtraFieldMeta // if Type is Object, use ObjectMeta + ArrayMeta []*ExtraFieldMeta // if Type is Array, need ArrayMeta +} + +type FormulaMeta struct { + InputFields map[string]*ApaasField + FormulaRule *FormulaRule +} + +type FormulaRule struct { +} + +type DBMeta struct { + tableList []*ApaasTable + tableView map[string]*ApaasTable + lookupIDView map[string]*ApaasTable // each table has one one and only key to supprort lookup +} + +func (p *DBMeta) GetTableByLookupID(lookID string) (*ApaasTable, bool) { + v, ok := p.lookupIDView[lookID] + return v, ok +} + +func (p *DBMeta) GetTableByName(tableName string) (*ApaasTable, bool) { + v, ok := p.tableView[tableName] + return v, ok +} diff --git a/apaas/view.go b/apaas/view.go new file mode 100644 index 00000000..3b13e103 --- /dev/null +++ b/apaas/view.go @@ -0,0 +1 @@ +package apaas diff --git a/apaas_mode.go b/apaas_mode.go new file mode 100644 index 00000000..4c3ed70f --- /dev/null +++ b/apaas_mode.go @@ -0,0 +1,18 @@ +package gorm + +type ApaasModeType uint8 + +const ( + DirectMode ApaasModeType = iota + EngineMode +) + +func (m ApaasModeType) String() string { + switch m { + case DirectMode: + return "DirectMode" + case EngineMode: + return "EngineMode" + } + return "DirectMode" +} diff --git a/callbacks.go b/callbacks.go index 50b5b0e9..d0e3e6df 100644 --- a/callbacks.go +++ b/callbacks.go @@ -8,6 +8,7 @@ import ( "sort" "time" + "gorm.io/gorm/apaas" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -109,8 +110,8 @@ func (p *processor) Execute(db *DB) *DB { db.AddError(err) } } + parseLookupTagMeta(stmt) } - // assign stmt.ReflectValue if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) @@ -358,3 +359,16 @@ func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { } return callbacks } + +func parseLookupTagMeta(stmt *Statement) { + for _, field := range stmt.Schema.Fields { + lookupMeta, err := apaas.ParseLookupTagMeta(field.Tag.Get(apaas.TagValue), field.DBName, stmt.DBName, stmt.Schema.Table) + if err != nil { + stmt.DB.AddError(err) + } + if lookupMeta != nil { + field.LookupMeta = lookupMeta + stmt.ApaasMode = EngineMode + } + } +} diff --git a/callbacks/apaas_callbacks.go b/callbacks/apaas_callbacks.go new file mode 100644 index 00000000..edbe6289 --- /dev/null +++ b/callbacks/apaas_callbacks.go @@ -0,0 +1,79 @@ +package callbacks + +import ( + "fmt" + "reflect" + + "gorm.io/gorm" + "gorm.io/gorm/apaas" +) + +var dbNameCaller func(*gorm.DB) (string, error) + +func SetDBNameCaller(fn func(*gorm.DB) (string, error)) { + dbNameCaller = fn +} + +func ExtraCheckerCallBack(stage string) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + if db.Config.DBName == "" { + if dbNameCaller != nil { + db.Config.DBName, _ = dbNameCaller(db) + } else { + db.Config.DBName, _ = db.GetDBName() + } + } + dbName := db.Config.DBName + if dbName == "" { + //db.Error = db.AddError(GenError(fmt.Sprintf("%s ExtraCheckerCallBack(stage=%s) GetDBName nil", MSG_PREFIX, stage))) + return + } + /* + db.Logger.Info(db.Statement.Context, "===schema: %#v\n", db.Statement.Schema.Fields) + for i, s := range db.Statement.Schema.Fields { + db.Logger.Info(db.Statement.Context, "===schema[i=%d]: %#v\n", i, *s) + } + */ + dbCol := apaas.GetDBCol() + if dbCol == nil { + //db.Error = db.AddError(GenError(fmt.Sprintf("%s ExtraCheckerCallBack(stage=%s) GetDBCollection nil ", MSG_PREFIX, stage))) + return + } + dbMeta, ok := dbCol.GetDB(dbName) + if !ok { + //db.Error = db.AddError(GenError(fmt.Sprintf("%s ExtraCheckerCallBack(stage=%s) GetDB(db=%s) nil", dbName, MSG_PREFIX, stage))) + return + } + tableMeta, ok := dbMeta.GetTableByName(db.Statement.Table) + if !ok { + //db.Error = db.AddError(GenError(fmt.Sprintf("%s ExtraCheckerCallBack(stage=%s) GetTable(db=%s,table=%s) GetDB nil", MSG_PREFIX, stage, dbName, db.Statement.Table))) + return + } + db.Logger.Info(db.Statement.Context, "%s ExtraCheckerCallBack(stage=%s) db_name=%s, table=%s", apaas.MSG_PREFIX, stage, dbName, tableMeta.TableName) + val := reflect.ValueOf(db.Statement.Dest) + for _, extraField := range tableMeta.ExtraFields { + db.Logger.Info(db.Statement.Context, "%s ExtraCheckerCallBack(stage=%s) db_name=%s, table=%s, check extra field=%s begin", apaas.MSG_PREFIX, stage, dbName, db.Statement.Table, extraField.Name) + field, ok := db.Statement.Schema.FieldsByDBName[extraField.Name] + if !ok { + db.Error = db.AddError(apaas.GenError(fmt.Sprintf("ExtraCheckerCallBack(stage=%s) Extra Field(db=%s,table=%s,field=%s) not found value in DestValue", stage, dbName, db.Statement.Table, field.DBName))) + return + } + v, _ := field.ValueOf(db.Statement.Context, val) + strV, ok := v.(string) + pStrV, ok1 := v.(*string) + if !ok && !ok1 { + db.Error = db.AddError(apaas.GenError(fmt.Sprintf("ExtraCheckerCallBack(stage=%s) Extra Field(db=%s,table=%s,field=%s) is not string/*string value in DestValue", stage, dbName, db.Statement.Table, field.DBName))) + return + } + if ok1 { + strV = *pStrV + } + if err := extraField.GetApaasMeta().Check(strV); err != nil { + db.Error = db.AddError(apaas.GenError(fmt.Sprintf("ExtraCheckerCallBack(stage=%s) Extra Field(db=%s,table=%s,field=%s) check error=%s", stage, dbName, db.Statement.Table, field.DBName, err.Error()))) + return + } + } + } + } +} diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d681aef3..ac3b1827 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -38,12 +38,17 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { } createCallback := db.Callback().Create() + // =====apaas callback========== + // apaas create callback register before transaction + createCallback.Before("gorm:create").Register("apaas_plugin:before_create", ExtraCheckerCallBack("create")) + // =====apaas callback end====== createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) createCallback.Register("gorm:create", Create(config)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Clauses = config.CreateClauses @@ -63,6 +68,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() + // =====apaas callback========== + // apaas update callback register before transaction + updateCallback.Register("apaas_plugin:before_update", ExtraCheckerCallBack("update")) + // =====apaas callback end====== updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) @@ -70,6 +79,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:update", Update(config)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Clauses = config.UpdateClauses diff --git a/finisher_api.go b/finisher_api.go index 6802945c..5e546103 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -74,6 +74,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value + db.Logger.Info(db.Statement.Context, "-----gorm save value: %#v", tx.Statement.Dest) reflectValue := reflect.Indirect(reflect.ValueOf(value)) for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { diff --git a/go.mod b/go.mod index 3060fc8f..d846b1a9 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,7 @@ go 1.18 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.5 - golang.org/x/text v0.20.0 + golang.org/x/text v0.14.0 ) + +replace gorm.io/gorm => ./ diff --git a/go.sum b/go.sum index 9af11572..e3e29009 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,5 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/gorm.go b/gorm.go index 63a28b37..f43ca585 100644 --- a/gorm.go +++ b/gorm.go @@ -69,6 +69,10 @@ type Config struct { callbacks *callbacks cacheStore *sync.Map + + /* ====Apaas Begin==== */ + DBName string + /* ====Apaas End====== */ } // Apply update config to new config @@ -224,6 +228,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) } + dbName, _ := db.GetDBName() + db.Config.DBName = dbName + return } @@ -524,3 +531,26 @@ func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) } + +func (db *DB) GetDBName() (string, error) { + var dbName string + var err error + // must be create new db + db1 := db.WithContext(context.Background()) + // 检测数据库类型 + switch db1.Dialector.Name() { + case "mysql": + err = db1.Raw("SELECT DATABASE()").Scan(&dbName).Error + case "postgres": + err = db1.Raw("SELECT current_database()").Scan(&dbName).Error + case "sqlite": + // SQLite中数据库名通常是文件路径 + var path string + err = db1.Raw("PRAGMA database_list").Scan(&struct{ Name, File string }{}).Error + dbName = path + default: + return "", fmt.Errorf("unsupported database dialect") + } + + return dbName, err +} diff --git a/logger/logger.go b/logger/logger.go index 8088cde2..46e3af5c 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -74,7 +74,7 @@ var ( // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, - LogLevel: Warn, + LogLevel: Info, IgnoreRecordNotFoundError: false, Colorful: true, }) diff --git a/schema/field.go b/schema/field.go index d1a633ce..dc16879b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -12,6 +12,7 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/gorm/apaas" "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -96,6 +97,10 @@ type Field struct { // It causes field unnecessarily migration. // Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique. UniqueIndex string + + // ==========apaas engine field begin========== + LookupMeta *apaas.ApaasLookupMeta + // ==========apaas engine field end========== } func (field *Field) BindName() string { @@ -131,7 +136,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Comment: tagSetting["COMMENT"], AutoIncrementIncrement: DefaultAutoIncrementIncrement, } - for field.IndirectFieldType.Kind() == reflect.Ptr { field.IndirectFieldType = field.IndirectFieldType.Elem() } diff --git a/schema/schema.go b/schema/schema.go index db236797..2e75e184 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -57,6 +57,11 @@ type Schema struct { initialized chan struct{} namer Namer cacheStore *sync.Map + // functional for apaas engine add by fangming + // ==========apaas engine field begin========== + ApaasLookupFields []*Field + ApaasFormulaFields []*Field + // ==========apaas engine field end========== } func (schema Schema) String() string { diff --git a/statement.go b/statement.go index 39e05d09..904edc4c 100644 --- a/statement.go +++ b/statement.go @@ -47,6 +47,11 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + + // ==========apaas engine field begin========== + // apaas mode for gormDB + ApaasMode ApaasModeType + // ==========apaas engine field end========== } type join struct {