Fix RenameColumn for mssql, DropColumn for sqlite
This commit is contained in:
		
							parent
							
								
									58bc0f51c1
								
							
						
					
					
						commit
						24285060d5
					
				| @ -41,6 +41,23 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		if field := stmt.Schema.LookUpField(oldName); field != nil { | ||||||
|  | 			oldName = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field := stmt.Schema.LookUpField(newName); field != nil { | ||||||
|  | 			newName = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return m.DB.Exec( | ||||||
|  | 			"sp_rename @objname = ?, @newname = ?, @objtype = 'COLUMN';", | ||||||
|  | 			fmt.Sprintf("%s.%s", stmt.Table, oldName), clause.Column{Name: newName}, | ||||||
|  | 		).Error | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { | func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||||
| 	var count int | 	var count int | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package sqlite | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| @ -22,11 +23,10 @@ func (m Migrator) HasTable(value interface{}) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { | func (m Migrator) HasColumn(value interface{}, name string) bool { | ||||||
| 	var count int | 	var count int | ||||||
| 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		name := field | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { |  | ||||||
| 			name = field.DBName | 			name = field.DBName | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -38,6 +38,45 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m Migrator) DropColumn(value interface{}, name string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
|  | 			name = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var ( | ||||||
|  | 			createSQL    string | ||||||
|  | 			newTableName = stmt.Table + "__temp" | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) | ||||||
|  | 
 | ||||||
|  | 		if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { | ||||||
|  | 			tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) | ||||||
|  | 			createSQL = reg.ReplaceAllString(createSQL, "") | ||||||
|  | 
 | ||||||
|  | 			var columns []string | ||||||
|  | 			columnTypes, _ := m.DB.Migrator().ColumnTypes(value) | ||||||
|  | 			for _, columnType := range columnTypes { | ||||||
|  | 				if columnType.Name() != name { | ||||||
|  | 					columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) | ||||||
|  | 
 | ||||||
|  | 			return m.DB.Exec(createSQL).Error | ||||||
|  | 		} else { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (m Migrator) CreateConstraint(interface{}, string) error { | func (m Migrator) CreateConstraint(interface{}, string) error { | ||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							| @ -66,7 +66,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if config.NowFunc == nil { | 	if config.NowFunc == nil { | ||||||
| 		config.NowFunc = func() time.Time { return time.Now().Local().Round(time.Second) } | 		config.NowFunc = func() time.Time { return time.Now().Local() } | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if dialector != nil { | 	if dialector != nil { | ||||||
|  | |||||||
| @ -243,14 +243,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) DropColumn(value interface{}, field string) error { | func (m Migrator) DropColumn(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| 			return m.DB.Exec( | 			name = field.DBName | ||||||
| 				"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, |  | ||||||
| 			).Error |  | ||||||
| 		} | 		} | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | 
 | ||||||
|  | 		return m.DB.Exec( | ||||||
|  | 			"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||||
|  | 		).Error | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -284,16 +285,20 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { | func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(oldName); field != nil { | ||||||
| 			oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) | 			oldName = field.DBName | ||||||
| 			return m.DB.Exec( |  | ||||||
| 				"ALTER TABLE ? RENAME COLUMN ? TO ?", |  | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, |  | ||||||
| 			).Error |  | ||||||
| 		} | 		} | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | 
 | ||||||
|  | 		if field := stmt.Schema.LookUpField(newName); field != nil { | ||||||
|  | 			newName = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return m.DB.Exec( | ||||||
|  | 			"ALTER TABLE ? RENAME COLUMN ? TO ?", | ||||||
|  | 			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, | ||||||
|  | 		).Error | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -98,18 +98,38 @@ func TestColumns(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { | 	if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { | ||||||
| 		t.Errorf("Failed to add column, got %v", err) | 		t.Fatalf("Failed to add column, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { | 	if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { | ||||||
| 		t.Errorf("Failed to find added column") | 		t.Fatalf("Failed to find added column") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { | 	if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { | ||||||
| 		t.Errorf("Failed to add column, got %v", err) | 		t.Fatalf("Failed to add column, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { | 	if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { | ||||||
| 		t.Errorf("Found deleted column") | 		t.Fatalf("Found deleted column") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { | ||||||
|  | 		t.Fatalf("Failed to add column, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { | ||||||
|  | 		t.Fatalf("Failed to add column, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { | ||||||
|  | 		t.Fatalf("Failed to found renamed column") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { | ||||||
|  | 		t.Fatalf("Failed to add column, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { | ||||||
|  | 		t.Fatalf("Found deleted column") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -88,8 +88,8 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | |||||||
| 			if curTime, ok := got.(time.Time); ok { | 			if curTime, ok := got.(time.Time); ok { | ||||||
| 				format := "2006-01-02T15:04:05Z07:00" | 				format := "2006-01-02T15:04:05Z07:00" | ||||||
| 
 | 
 | ||||||
| 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) { | 				if curTime.Round(time.Second).Format(format) != expect.(time.Time).Round(time.Second).Format(format) && curTime.Truncate(time.Second).Format(format) != expect.(time.Time).Truncate(time.Second).Format(format) { | ||||||
| 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time).Round(time.Second).Format(format), curTime.Round(time.Second).Format(format)) | 					t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) | ||||||
| 				} | 				} | ||||||
| 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | 			} else if fmt.Sprint(got) != fmt.Sprint(expect) { | ||||||
| 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | 				t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu