From 2d8faf10b40cab0d3c370aebd65bb73b34585352 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 26 Feb 2018 22:53:38 +0800 Subject: [PATCH] Implement GetTable method --- dialects/sqlite/sqlite.go | 14 +++---- model/errors.go | 8 ++++ model/model.go | 80 +++++++++++++++++++++++---------------- 3 files changed, 61 insertions(+), 41 deletions(-) create mode 100644 model/errors.go diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 5c519d79..de425e52 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -4,7 +4,6 @@ import ( "bytes" "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/builder" "github.com/jinzhu/gorm/dialects/common/utils" // import sqlite3 driver @@ -16,13 +15,12 @@ type Dialect struct { } // Insert insert -func (*Dialect) Insert(stmt *builder.Statement) error { +func (*Dialect) Insert(tx *gorm.DB) error { var ( args []interface{} defaultValueColumns []string - errs = gorm.Errors{} - assignmentsChan = utils.GetCreatingAssignments(stmt, &errs) - tableNameChan = utils.GetTable(stmt, &errs) + assignmentsChan = utils.GetCreatingAssignments(tx) + tableNameChan = utils.GetTable(tx) ) s := bytes.NewBufferString("INSERT INTO ") @@ -40,16 +38,16 @@ func (*Dialect) Insert(stmt *builder.Statement) error { } // Query query -func (*Dialect) Query(*builder.Statement) error { +func (*Dialect) Query(tx *gorm.DB) error { return nil } // Update update -func (*Dialect) Update(*builder.Statement) error { +func (*Dialect) Update(tx *gorm.DB) error { return nil } // Delete delete -func (*Dialect) Delete(*builder.Statement) error { +func (*Dialect) Delete(tx *gorm.DB) error { return nil } diff --git a/model/errors.go b/model/errors.go new file mode 100644 index 00000000..f4f2be55 --- /dev/null +++ b/model/errors.go @@ -0,0 +1,8 @@ +package model + +import "errors" + +var ( + // ErrInvalidTable invalid table name + ErrInvalidTable = errors.New("invalid table name") +) diff --git a/model/model.go b/model/model.go index 833e916a..878d4c90 100644 --- a/model/model.go +++ b/model/model.go @@ -2,56 +2,70 @@ package model import ( "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/builder" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/inflection" ) // DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(stmt *builder.Statement, tableName string) string { - return tableName -} +// DefaultTableNameHandler = func(tx *gorm.DB, tableName string) string { +// return tableName +// } +var DefaultTableNameHandler func(tx *gorm.DB, tableName string) string // GetCreatingAssignments get creating assignments -func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []schema.Field { +func GetCreatingAssignments(tx *gorm.DB) chan []schema.Field { return nil } -// GetTable get table name -func GetTable(stmt *builder.Statement, errs *gorm.Errors) chan string { +// GetTable get table name for current db operation +func GetTable(tx *gorm.DB) chan string { tableChan := make(chan string) go func() { - if stmt.Table != nil { - if table, ok := stmt.Table.(string); ok { - tableChan <- DefaultTableNameHandler(stmt, table) - } else if tableSchema := schema.Parse(stmt.Table); tableSchema != nil { - if tableSchema.TableName != "" { - tableChan <- DefaultTableNameHandler(stmt, tableSchema.TableName) + var tableName string + if name, ok := tx.Statement.Table.(string); ok { + tableName = name + } else { + for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} { + if t, ok := v.(tabler); ok { + tableName = t.TableName() + } else if t, ok := v.(dbTabler); ok { + tableName = t.TableName(tx) + } else if s := schema.Parse(tx.Statement.Table); s != nil { + if s.TableName != "" { + tableName = s.TableName + } else { + tableName = schema.ToDBName(s.ModelType.Name()) + if !tx.Config.SingularTable { + tableName = inflection.Plural(tableName) + } + } + } + + if tableName != "" { + break } - tableSchema.ModelType.Name } } + + if tableName != "" { + if DefaultTableNameHandler != nil { + tableChan <- DefaultTableNameHandler(tx, tableName) + } else { + tableChan <- tableName + } + } else { + tx.AddError(ErrInvalidTable) + } }() return tableChan } -// if scope.Value == nil { -// return &modelStruct -// } -// TableName get model's table name -// func (schema *Schema) TableName(stmt *builder.Statement) string { -// if s.defaultTableName == "" && db != nil && s.ModelType != nil { -// // Set default table name -// if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { -// s.defaultTableName = tabler.TableName() -// } else { -// tableName := ToDBName(s.ModelType.Name()) -// if db == nil || !db.parent.singularTable { -// tableName = inflection.Plural(tableName) -// } -// s.defaultTableName = tableName -// } -// } -// return DefaultTableNameHandler(db, s.defaultTableName) -// } +type tabler interface { + TableName() string +} + +type dbTabler interface { + TableName(*gorm.DB) string +}