diff --git a/dialect.go b/dialect.go index e879588b..1b97da2b 100644 --- a/dialect.go +++ b/dialect.go @@ -29,6 +29,8 @@ type Dialect interface { HasForeignKey(tableName string, foreignKeyName string) bool // RemoveIndex remove index RemoveIndex(tableName string, indexName string) error + // RemoveForeignKey remove foreign key + RemoveForeignKey(tableName string, foreignKeyName string) error // HasTable check has table or not HasTable(tableName string) bool // HasColumn check has column or not diff --git a/dialect_common.go b/dialect_common.go index a99627f2..b3b0afbd 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -105,6 +105,11 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo return false } +func (s commonDialect) RemoveForeignKey(tableName string, foreignKeyName string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", tableName, foreignKeyName)) + return err +} + func (s commonDialect) HasTable(tableName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) diff --git a/dialect_mysql.go b/dialect_mysql.go index 271670b8..6cf2585e 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -132,6 +132,11 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { return count > 0 } +func (s mysql) RemoveForeignKey(tableName string, foreignKeyName string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v DROP FOREIGN KEY %v", s.Quote(tableName), foreignKeyName)) + return err +} + func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return diff --git a/main.go b/main.go index 97cff7db..a46a4102 100644 --- a/main.go +++ b/main.go @@ -602,6 +602,14 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate return scope.db } +// RemoveForeignKey Removes foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("user_city_id_city_id_foreign") +func (s *DB) RemoveForeignKey(keyName string) *DB { + scope := s.NewScope(s.Value) + scope.removeForeignKey(keyName) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/migration_test.go b/migration_test.go index 9fc14fa0..576eb761 100644 --- a/migration_test.go +++ b/migration_test.go @@ -191,6 +191,16 @@ type Comment struct { Post Post } +type Class struct { + Id int64 + Year string +} + +type Student struct { + Id int64 + ClassID int64 +} + // Scanner type NullValue struct { Id int64 @@ -254,7 +264,7 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} + values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}, &Class{}, &Student{}} for _, value := range values { DB.DropTable(value) } @@ -332,6 +342,25 @@ func TestIndexes(t *testing.T) { } } +func TestForeignKeys(t *testing.T) { + if err := DB.Model(&Student{}).AddForeignKey("class_id", "classes (id)", "RESTRICT", "RESTRICT").Error; err != nil { + t.Errorf("Got error while trying to create foreign key: %+v", err) + } + + scope := DB.NewScope(&Student{}) + if !scope.Dialect().HasForeignKey(scope.TableName(), "student_class_id_class_id_foreign") { + t.Errorf("Student should have foreign key students_class_id_classes_id_foreign") + } + + if err := DB.Model(&Student{}).RemoveForeignKey("students_class_id_classes_id_foreign").Error; err != nil { + t.Errorf("Got error while trying to remove foreign key: %+v", err) + } + + if scope.Dialect().HasForeignKey(scope.TableName(), "student_class_id_class_id_foreign") { + t.Errorf("Student should no longer have foreign key students_class_id_classes_id_foreign") + } +} + type EmailWithIdx struct { Id int64 UserId int64 diff --git a/scope.go b/scope.go index 9a237998..ad19a9b2 100644 --- a/scope.go +++ b/scope.go @@ -1168,6 +1168,14 @@ func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } +func (scope *Scope) removeForeignKey(keyName string) { + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + + scope.Dialect().RemoveForeignKey(scope.TableName(), keyName) +} + func (scope *Scope) autoMigrate() *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName()