diff --git a/migrator.go b/migrator.go index 7dddcabf..6e95ad71 100644 --- a/migrator.go +++ b/migrator.go @@ -33,6 +33,14 @@ type ViewOption struct { Query *DB } +// ALLTables get all database tables +const ALLTables = "migrator:all_tables" + +// Table Database table list info +type Table struct { + TableName string `gorm:"column:TABLE_NAME"` +} + type ColumnType interface { Name() string DatabaseTypeName() string @@ -54,7 +62,7 @@ type Migrator interface { DropTable(dst ...interface{}) error HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error - + GetTables(tables ...string) (tableList []Table, err error) // Columns AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 30586a8c..2ec220cf 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,6 +18,13 @@ var ( regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) ) +const ( + // allTableQuery query db table list + allTableQuery = "SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?" + // tablesQuery query tables is has + tablesQuery = "SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=? and TABLE_NAME in (?)" +) + // Migrator m struct type Migrator struct { Config @@ -155,6 +162,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +func (m Migrator) GetTables(tables ...string) (tableList []gorm.Table, err error) { + if len(tables) == 1 && tables[0] == gorm.ALLTables { + return tableList, m.DB.Raw(allTableQuery, m.CurrentDatabase()).Scan(&tableList).Error + } + return tableList, m.DB.Raw(tablesQuery, m.CurrentDatabase(), tables).Scan(&tableList).Error +} + func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0354e84e..a5efb874 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -16,6 +16,8 @@ func TestMigrate(t *testing.T) { rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") + DB.Migrator().GetTables("user_speaks", "user_friends", "ccc") + DB.Migrator().GetTables(gorm.ALLTables) if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err)