From d0764bead1bb0283c1f68842ce39cb4a001b8676 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 13:53:13 +0800 Subject: [PATCH] Test migrate with comment and check created constraints --- migrator.go | 4 ++++ migrator/migrator.go | 36 +++++++++++++++--------------------- schema/index.go | 18 ++++++++++-------- tests/go.mod | 4 ++-- tests/migrate_test.go | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 31 deletions(-) diff --git a/migrator.go b/migrator.go index d45e3ac2..37051f81 100644 --- a/migrator.go +++ b/migrator.go @@ -2,6 +2,9 @@ package gorm import ( "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) // Migrator returns migrator @@ -27,6 +30,7 @@ type Migrator interface { // Database CurrentDatabase() string + FullDataTypeOf(*schema.Field) clause.Expr // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 90ab7892..64e02ac7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,9 +18,8 @@ type Migrator struct { // Config schema config type Config struct { - CreateIndexAfterCreateTable bool - AllowDeferredConstraintsWhenAutoMigrate bool - DB *gorm.DB + CreateIndexAfterCreateTable bool + DB *gorm.DB gorm.Dialector } @@ -120,13 +119,13 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if rel.JoinTable != nil { joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if !tx.Migrator().HasTable(rel.JoinTable.Table) { - defer func() { - errr = tx.Table(rel.JoinTable.Table).Migrator().CreateTable(joinValue) - }() + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().CreateTable(joinValue) + }(rel.JoinTable.Table, joinValue) } else { - defer func() { - errr = tx.Table(rel.JoinTable.Table).Migrator().AutoMigrate(joinValue) - }() + defer func(table string, joinValue interface{}) { + errr = tx.Table(table).Migrator().AutoMigrate(joinValue) + }(rel.JoinTable.Table, joinValue) } } } @@ -154,7 +153,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += fmt.Sprintf("? ?") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } @@ -170,9 +169,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { - defer func() { - errr = tx.Migrator().CreateIndex(value, idx.Name) - }() + defer func(value interface{}, name string) { + errr = tx.Migrator().CreateIndex(value, name) + }(value, idx.Name) } else { createTableSQL += "INDEX ? ?," values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) @@ -277,7 +276,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -301,7 +300,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { if field := stmt.Schema.LookUpField(field); field != nil { return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), ).Error } return fmt.Errorf("failed to look up field with name: %s", field) @@ -436,7 +435,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", currentDatabase, stmt.Table, name, ).Row().Scan(&count) }) @@ -481,11 +480,6 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX ? ON ??" - if idx.Comment != "" { - values = append(values, idx.Comment) - createIndexSQL += " COMMENT ?" - } - if idx.Type != "" { createIndexSQL += " USING " + idx.Type } diff --git a/schema/index.go b/schema/index.go index 4228bba2..cf3338c3 100644 --- a/schema/index.go +++ b/schema/index.go @@ -53,16 +53,18 @@ func (schema *Schema) ParseIndexes() map[string]Index { } func (schema *Schema) LookIndex(name string) *Index { - indexes := schema.ParseIndexes() - for _, index := range indexes { - if index.Name == name { - return &index - } - - for _, field := range index.Fields { - if field.Name == name { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { return &index } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } } } diff --git a/tests/go.mod b/tests/go.mod index 85ef8dcb..abe32cd6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.2.3 - gorm.io/driver/postgres v0.2.2 - gorm.io/driver/sqlite v1.0.6 + gorm.io/driver/postgres v0.2.3 + gorm.io/driver/sqlite v1.0.7 gorm.io/driver/sqlserver v0.2.2 gorm.io/gorm v0.2.9 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 194b5cbf..fce4c4aa 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -15,6 +15,8 @@ func TestMigrate(t *testing.T) { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + DB.Migrator().DropTable("user_speaks", "user_friends") + if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err) } @@ -28,6 +30,36 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to create table for %#v---", m) } } + + for _, indexes := range [][2]string{ + {"user_speaks", "fk_user_speaks_user"}, + {"user_speaks", "fk_user_speaks_language"}, + {"user_friends", "fk_user_friends_user"}, + {"user_friends", "fk_user_friends_friends"}, + {"accounts", "fk_users_account"}, + {"users", "fk_users_team"}, + {"users", "fk_users_manager"}, + {"users", "fk_users_company"}, + } { + if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { + t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) + } + } +} + +func TestMigrateWithComment(t *testing.T) { + type UserWithComment struct { + gorm.Model + Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"` + } + + if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } } func TestTable(t *testing.T) {