From 62dcd7896accb4cedfd9428a03a99332281da2a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Feb 2020 23:04:03 +0800 Subject: [PATCH] Add Migrator --- callbacks.go | 5 +- helpers.go | 2 + migrator.go | 12 +++- migrator/migrator.go | 153 ++++++++++++++++++++++++++++++++++++++++++- statement.go | 7 ++ 5 files changed, 172 insertions(+), 7 deletions(-) diff --git a/callbacks.go b/callbacks.go index 8546ae16..4f19a681 100644 --- a/callbacks.go +++ b/callbacks.go @@ -75,13 +75,10 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - var err error - stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy) + err := stmt.Parse(stmt.Model) if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") { db.AddError(err) - } else if stmt.Table == "" && stmt.Schema != nil { - stmt.Table = stmt.Schema.Table } } } diff --git a/helpers.go b/helpers.go index 2e5c8ed1..d7177ba7 100644 --- a/helpers.go +++ b/helpers.go @@ -15,6 +15,8 @@ var ( ErrInvalidTransaction = errors.New("no valid transaction") // ErrUnaddressable unaddressable value ErrUnaddressable = errors.New("using unaddressable value") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") ) // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt diff --git a/migrator.go b/migrator.go index c21cda42..b6d273e7 100644 --- a/migrator.go +++ b/migrator.go @@ -4,6 +4,11 @@ import ( "database/sql" ) +// Migrator returns migrator +func (db *DB) Migrator() Migrator { + return db.Dialector.Migrator() +} + // ViewOption view option type ViewOption struct { Replace bool @@ -15,10 +20,13 @@ type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error + // Database + CurrentDatabase() string + // Tables CreateTable(dst ...interface{}) error DropTable(dst ...interface{}) error - HasTable(dst ...interface{}) error + HasTable(dst ...interface{}) bool RenameTable(oldName, newName string) error // Columns @@ -39,6 +47,6 @@ type Migrator interface { // Indexes CreateIndex(dst interface{}, name string) error DropIndex(dst interface{}, name string) error - HasIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error } diff --git a/migrator/migrator.go b/migrator/migrator.go index 0ff83ac1..e9725935 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,6 +1,11 @@ package migrator -import "github.com/jinzhu/gorm" +import ( + "database/sql" + "fmt" + + "github.com/jinzhu/gorm" +) // Migrator migrator struct type Migrator struct { @@ -12,3 +17,149 @@ type Config struct { CheckExistsBeforeDropping bool DB *gorm.DB } + +func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := migrator.DB.Statement + if stmt == nil { + stmt = &gorm.Statement{DB: migrator.DB} + } + + if err := stmt.Parse(value); err != nil { + return err + } + + return fc(stmt) +} + +// AutoMigrate +func (migrator Migrator) AutoMigrate(values ...interface{}) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateTable(values ...interface{}) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropTable(values ...interface{}) error { + for _, value := range values { + if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Exec("DROP TABLE " + stmt.Quote(stmt.Table)).Error + }); err != nil { + return err + } + } + return nil +} + +func (migrator Migrator) HasTable(values ...interface{}) bool { + var count int64 + for _, value := range values { + err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error + }) + + if err != nil || count == 0 { + return false + } + } + + return true +} + +func (migrator Migrator) RenameTable(oldName, newName string) error { + return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error +} + +func (migrator Migrator) AddColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ADD ? %s", field.DBDataType), stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) DropColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec("ALTER TABLE ? DROP COLUMN ?", stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) AlterColumn(value interface{}, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + return migrator.DB.Exec(fmt.Sprintf("ALTER TABLE ? ALTER COLUMN ? TYPE %s", field.DBDataType), stmt.Table, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) + return migrator.DB.Exec("ALTER TABLE ? RENAME COLUMN ? TO ?", stmt.Table, oldName, field.DBName).Error + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { + return nil, gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropView(name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) CreateConstraint(value interface{}, name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropConstraint(value interface{}, name string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Raw("ALTER TABLE ? DROP CONSTRAINT ?", stmt.Table, name).Error + }) +} + +func (migrator Migrator) CreateIndex(value interface{}, name string) error { + return gorm.ErrNotImplemented +} + +func (migrator Migrator) DropIndex(value interface{}, name string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Raw("DROP INDEX ? ON ?", name, stmt.Table).Error + }) +} + +func (migrator Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := migrator.DB.Migrator().CurrentDatabase() + return migrator.DB.Raw("SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, stmt.Table, name).Scan(&count).Error + }) + + if count != 0 { + return true + } + return false +} + +func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return migrator.DB.Exec("ALTER TABLE ? RENAME INDEX ? TO ?", stmt.Table, oldName, newName).Error + }) +} + +func (migrator Migrator) CurrentDatabase() (name string) { + migrator.DB.Raw("SELECT DATABASE()").Scan(&name) + return +} diff --git a/statement.go b/statement.go index b2626d95..8c75c90d 100644 --- a/statement.go +++ b/statement.go @@ -267,3 +267,10 @@ func (stmt *Statement) Build(clauses ...string) { } // TODO handle named vars } + +func (stmt *Statement) Parse(value interface{}) (err error) { + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + stmt.Table = stmt.Schema.Table + } + return err +}