modify unique to constraint

This commit is contained in:
black 2023-06-14 18:54:56 +08:00
parent fb2ec8caea
commit a9e2dfc503
3 changed files with 172 additions and 63 deletions

View File

@ -26,6 +26,8 @@ type Migrator struct {
// Config schema config // Config schema config
type Config struct { type Config struct {
CreateIndexAfterCreateTable bool CreateIndexAfterCreateTable bool
// Unique in ColumnType is affected by UniqueIndex, e.g. MySQL
UniqueAffectedByUniqueIndex bool
DB *gorm.DB DB *gorm.DB
gorm.Dialector gorm.Dialector
} }
@ -115,9 +117,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return err return err
} }
var ( var (
parseIndexes = stmt.Schema.ParseIndexes() parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints() parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
parseUniqueConstraints = stmt.Schema.ParseUniqueConstraints()
) )
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
@ -157,14 +158,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} }
for _, uni := range parseUniqueConstraints {
if !queryTx.Migrator().HasConstraint(value, uni.Name) {
if err := execTx.Migrator().CreateConstraint(value, uni.Name); err != nil {
return err
}
}
}
for _, chk := range parseCheckConstraints { for _, chk := range parseCheckConstraints {
if !queryTx.Migrator().HasConstraint(value, chk.Name) { if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
@ -438,6 +431,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
// MigrateColumn migrate column // MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate // found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
@ -497,14 +494,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
} }
// check unique
if unique, ok := columnType.Unique(); ok && !unique && (field.Unique || field.UniqueIndex) {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value // check default value
if !field.PrimaryKey { if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
@ -533,13 +522,101 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
} }
if alterColumn && !field.IgnoreMigration { if alterColumn {
return m.DB.Migrator().AlterColumn(value, field.DBName) if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.MigrateColumnUnique(value, field, columnType); err != nil {
return err
} }
return nil return nil
} }
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
queryTx := m.DB.Session(&gorm.Session{})
execTx := queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
index := m.DB.NamingStrategy.IndexName(stmt.Table, field.DBName)
if m.UniqueAffectedByUniqueIndex {
if unique {
hasConstraint := queryTx.Migrator().HasConstraint(value, constraint)
switch {
case field.Unique && !hasConstraint:
if field.Unique {
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
return err
}
}
// field isn't Unique but ColumnType's Unique is reported by UniqueConstraint.
case !field.Unique && hasConstraint:
if err := execTx.Migrator().DropConstraint(value, constraint); err != nil {
return err
}
if field.UniqueIndex {
if err := execTx.Migrator().CreateIndex(value, index); err != nil {
return err
}
}
}
hasIndex := queryTx.Migrator().HasIndex(value, index)
switch {
case field.UniqueIndex && !hasIndex:
if err := execTx.Migrator().CreateIndex(value, index); err != nil {
return err
}
// field isn't UniqueIndex but ColumnType's Unique is reported by UniqueIndex
case !field.UniqueIndex && hasIndex:
if err := execTx.Migrator().DropIndex(value, index); err != nil {
return err
}
// create normal index
if idx := stmt.Schema.LookIndex(index); idx != nil {
if err := execTx.Migrator().CreateIndex(value, index); err != nil {
return err
}
}
}
} else {
if field.Unique {
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
return err
}
}
if field.UniqueIndex {
if err := execTx.Migrator().CreateIndex(value, index); err != nil {
return err
}
}
return nil
}
} else {
if unique && !field.Unique {
return execTx.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return execTx.Migrator().CreateConstraint(value, constraint)
}
}
return nil
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0) columnTypes := make([]gorm.ColumnType, 0)
@ -610,7 +687,22 @@ func (m Migrator) DropView(name string) error {
} }
// GuessConstraintAndTable guess statement's constraint and it's table based on name // GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { //
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
}
}
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
if stmt.Schema == nil { if stmt.Schema == nil {
return nil, stmt.Table return nil, stmt.Table
} }
@ -669,7 +761,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
// CreateConstraint create constraint // CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error { func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil { if constraint != nil {
vars := []interface{}{clause.Table{Name: table}} vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil { if stmt.TableExpr != nil {
@ -685,7 +777,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
// DropConstraint drop constraint // DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.GetName() name = constraint.GetName()
} }
@ -698,7 +790,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.GetName() name = constraint.GetName()
} }

View File

@ -566,7 +566,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} }
type FKConstraint struct { // Constraint is ForeignKey Constraint
type Constraint struct {
Name string Name string
Field *Field Field *Field
Schema *Schema Schema *Schema
@ -577,9 +578,9 @@ type FKConstraint struct {
OnUpdate string OnUpdate string
} }
func (constraint *FKConstraint) GetName() string { return constraint.Name } func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *FKConstraint) Build() (sql string, vars []interface{}) { func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" { if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete sql += " ON DELETE " + constraint.OnDelete
@ -601,7 +602,7 @@ func (constraint *FKConstraint) Build() (sql string, vars []interface{}) {
return return
} }
func (rel *Relationship) ParseConstraint() *FKConstraint { func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"] str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" { if str == "-" {
return nil return nil
@ -641,7 +642,7 @@ func (rel *Relationship) ParseConstraint() *FKConstraint {
name = rel.Schema.namer.RelationshipFKName(*rel) name = rel.Schema.namer.RelationshipFKName(*rel)
} }
constraint := FKConstraint{ constraint := Constraint{
Name: name, Name: name,
Field: rel.Field, Field: rel.Field,
OnUpdate: settings["ONUPDATE"], OnUpdate: settings["ONUPDATE"],

View File

@ -1654,40 +1654,46 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
} }
// not unique // not unique
type UniqueStruct1 struct { type (
Name string `gorm:"size:20"` UniqueStruct1 struct {
} Name string `gorm:"size:10"`
}
UniqueStruct2 struct {
Name string `gorm:"size:20"`
}
)
checkField(&UniqueStruct1{}, "name", false, false) checkField(&UniqueStruct1{}, "name", false, false)
checkField(&UniqueStruct2{}, "name", false, false)
type ( // unique type ( // unique
UniqueStruct2 struct {
Name string `gorm:"size:20;unique"`
}
UniqueStruct3 struct { UniqueStruct3 struct {
Name string `gorm:"size:30;unique"` Name string `gorm:"size:30;unique"`
} }
UniqueStruct4 struct {
Name string `gorm:"size:40;unique"`
}
) )
checkField(&UniqueStruct2{}, "name", true, false)
checkField(&UniqueStruct3{}, "name", true, false) checkField(&UniqueStruct3{}, "name", true, false)
checkField(&UniqueStruct4{}, "name", true, false)
type ( // uniqueIndex type ( // uniqueIndex
UniqueStruct4 struct {
Name string `gorm:"size:40;uniqueIndex"`
}
UniqueStruct5 struct { UniqueStruct5 struct {
Name string `gorm:"size:50;uniqueIndex"` Name string `gorm:"size:50;uniqueIndex"`
} }
UniqueStruct6 struct { UniqueStruct6 struct {
Name string `gorm:"size:20;uniqueIndex:idx_us6_all_names"` Name string `gorm:"size:60;uniqueIndex"`
NickName string `gorm:"size:20;uniqueIndex:idx_us6_all_names"` }
UniqueStruct7 struct {
Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
} }
) )
checkField(&UniqueStruct4{}, "name", false, true)
checkField(&UniqueStruct5{}, "name", false, true) checkField(&UniqueStruct5{}, "name", false, true)
checkField(&UniqueStruct6{}, "name", false, true)
checkField(&UniqueStruct6{}, "name", false, false) checkField(&UniqueStruct7{}, "name", false, false)
checkField(&UniqueStruct6{}, "nick_name", false, false) checkField(&UniqueStruct7{}, "nick_name", false, false)
checkField(&UniqueStruct6{}, "nick_name", false, false) checkField(&UniqueStruct7{}, "nick_name", false, false)
type TestCase struct { type TestCase struct {
name string name string
@ -1703,36 +1709,46 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
checkColumnType(t, "name", true) checkColumnType(t, "name", true)
checkIndex(t, nil) checkIndex(t, nil)
} }
index := &migrator.Index{ uniqueIndex := &migrator.Index{
TableName: table, TableName: table,
NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), NameValue: DB.Config.NamingStrategy.IndexName(table, "name"),
ColumnList: []string{"name"}, ColumnList: []string{"name"},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true},
} }
if DB.Dialector.Name() != "mysql" { checkUniqueIndex := func(t *testing.T) {
checkColumnType(t, "name", false)
checkIndex(t, []gorm.Index{uniqueIndex})
}
if DB.Dialector.Name() == "mysql" {
// in mysql, unique equals uniqueIndex // in mysql, unique equals uniqueIndex
checkUnique = func(t *testing.T) { checkUnique = func(t *testing.T) {
checkColumnType(t, "name", true) checkColumnType(t, "name", true)
checkIndex(t, []gorm.Index{index}) checkIndex(t, []gorm.Index{&migrator.Index{
TableName: table,
NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"),
ColumnList: []string{"name"},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
}})
}
checkUniqueIndex = func(t *testing.T) {
checkColumnType(t, "name", true)
checkIndex(t, []gorm.Index{uniqueIndex})
} }
}
checkUniqueIndex := func(t *testing.T) {
checkColumnType(t, "name", false)
checkIndex(t, []gorm.Index{index})
} }
tests := []TestCase{ tests := []TestCase{
{name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
{name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkUnique}, {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
{name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct4{}, checkFunc: checkUniqueIndex}, {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
{name: "unique to notUnique", from: &UniqueStruct2{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique},
{name: "unique to unique", from: &UniqueStruct2{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique},
{name: "unique to uniqueIndex", from: &UniqueStruct2{}, to: &UniqueStruct4{}, checkFunc: checkUniqueIndex}, {name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
{name: "uniqueIndex to notUnique", from: &UniqueStruct4{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
{name: "uniqueIndex to unique", from: &UniqueStruct4{}, to: &UniqueStruct2{}, checkFunc: checkUnique}, {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
{name: "uniqueIndex to uniqueIndex", from: &UniqueStruct4{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex},
{name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct4{}, to: &UniqueStruct6{}, checkFunc: func(t *testing.T) { {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: func(t *testing.T) {
checkColumnType(t, "name", false) checkColumnType(t, "name", false)
checkColumnType(t, "nick_name", false) checkColumnType(t, "nick_name", false)
checkIndex(t, []gorm.Index{&migrator.Index{ checkIndex(t, []gorm.Index{&migrator.Index{