diff --git a/migrator/migrator.go b/migrator/migrator.go index 189a141f..bd36c394 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -17,25 +17,15 @@ import ( "gorm.io/gorm/schema" ) -// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), -// with a possible trailing non-digit character (\D?). - -// For example, values that can pass this regular expression are: -// - "123" -// - "abc456" -// -"%$#@789" +// Regular expression to match sequences of digits in data types. var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) -// TODO:? Create const vars for raw sql queries ? - -var _ gorm.Migrator = (*Migrator)(nil) - -// Migrator m struct +// Migrator struct implements the gorm.Migrator interface. type Migrator struct { Config } -// Config schema config +// Config holds the configuration for the Migrator. type Config struct { CreateIndexAfterCreateTable bool DB *gorm.DB @@ -46,18 +36,18 @@ type printSQLLogger struct { logger.Interface } -func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { sql, _ := fc() fmt.Println(sql + ";") l.Interface.Trace(ctx, begin, fc, err) } -// GormDataTypeInterface gorm data type interface +// GormDataTypeInterface allows custom data types to define their own database data type. type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } -// RunWithValue run migration with statement value +// RunWithValue executes a function with a prepared statement. func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -74,7 +64,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error return fc(stmt) } -// DataTypeOf return field's db data type +// DataTypeOf returns the database data type for a field. func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { @@ -86,9 +76,9 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } -// FullDataTypeOf returns field's db full data type -func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { - expr.SQL = m.DataTypeOf(field) +// FullDataTypeOf returns the full database data type for a field, including constraints. +func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr { + expr := clause.Expr{SQL: m.DataTypeOf(field)} if field.NotNull { expr.SQL += " NOT NULL" @@ -104,12 +94,13 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { } } - return + return expr } -func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { - queryTx = m.DB.Session(&gorm.Session{}) - execTx = queryTx +// GetQueryAndExecTx returns query and execution transactions. +func (m Migrator) GetQueryAndExecTx() (*gorm.DB, *gorm.DB) { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx if m.DB.DryRun { queryTx.DryRun = false execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) @@ -117,7 +108,7 @@ func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { return queryTx, execTx } -// AutoMigrate auto migrate values +// AutoMigrate automatically migrates the schema, adding tables, columns, and indexes as needed. func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { queryTx, execTx := m.GetQueryAndExecTx() @@ -126,76 +117,83 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + if err := m.migrateSchema(value, queryTx, execTx); err != nil { + return err + } + } + } + return nil +} - if stmt.Schema == nil { - return errors.New("failed to get schema") +func (m Migrator) migrateSchema(value interface{}, queryTx, execTx *gorm.DB) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + + columnTypes, err := queryTx.Migrator().ColumnTypes(value) + if err != nil { + return err + } + + parseIndexes := stmt.Schema.ParseIndexes() + parseCheckConstraints := stmt.Schema.ParseCheckConstraints() + + for _, dbName := range stmt.Schema.DBNames { + var foundColumn gorm.ColumnType + for _, columnType := range columnTypes { + if columnType.Name() == dbName { + foundColumn = columnType + break } + } - columnTypes, err := queryTx.Migrator().ColumnTypes(value) - if err != nil { + if foundColumn == nil { + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { return err } - var ( - parseIndexes = stmt.Schema.ParseIndexes() - parseCheckConstraints = stmt.Schema.ParseCheckConstraints() - ) - for _, dbName := range stmt.Schema.DBNames { - var foundColumn gorm.ColumnType - - for _, columnType := range columnTypes { - if columnType.Name() == dbName { - foundColumn = columnType - break - } - } - - if foundColumn == nil { - // not found, add column - if err = execTx.Migrator().AddColumn(value, dbName); err != nil { - return err - } - } else { - // found, smartly migrate - field := stmt.Schema.FieldsByDBName[dbName] - if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { - return err - } - } + } else { + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + return err } + } + } - if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { - for _, rel := range stmt.Schema.Relationships.Relations { - if rel.Field.IgnoreMigration { - continue - } - if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { - if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } - } + if err := m.handleConstraints(value, stmt, queryTx, execTx, parseIndexes, parseCheckConstraints); err != nil { + return err + } + + return nil + }) +} + +func (m Migrator) handleConstraints(value interface{}, stmt *gorm.Statement, queryTx, execTx *gorm.DB, parseIndexes []schema.Index, parseCheckConstraints []schema.CheckConstraint) error { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err } + } + } + } - for _, chk := range parseCheckConstraints { - if !queryTx.Migrator().HasConstraint(value, chk.Name) { - if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { - return err - } - } - } + for _, chk := range parseCheckConstraints { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } - for _, idx := range parseIndexes { - if !queryTx.Migrator().HasIndex(value, idx.Name) { - if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { - return err - } - } - } - - return nil - }); err != nil { + for _, idx := range parseIndexes { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } @@ -204,117 +202,116 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } -// GetTables returns tables -func (m Migrator) GetTables() (tableList []string, err error) { - err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). +// GetTables returns a list of table names in the current database. +func (m Migrator) GetTables() ([]string, error) { + var tableList []string + err := m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = ?", m.CurrentDatabase()). Scan(&tableList).Error - return + return tableList, err } -// CreateTable create table in database for values +// CreateTable creates tables for the given values. func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { - - if stmt.Schema == nil { - return errors.New("failed to get schema") - } - - var ( - createTableSQL = "CREATE TABLE ? (" - values = []interface{}{m.CurrentTable(stmt)} - hasPrimaryKeyInDataType bool - ) - - for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.FieldsByDBName[dbName] - if !field.IgnoreMigration { - createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) - createTableSQL += "," - } - } - - if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { - createTableSQL += "PRIMARY KEY ?," - primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) - for _, field := range stmt.Schema.PrimaryFields { - primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) - } - - values = append(values, primaryKeys) - } - - for _, idx := range stmt.Schema.ParseIndexes() { - if m.CreateIndexAfterCreateTable { - defer func(value interface{}, name string) { - if err == nil { - err = tx.Migrator().CreateIndex(value, name) - } - }(value, idx.Name) - } else { - if idx.Class != "" { - createTableSQL += idx.Class + " " - } - createTableSQL += "INDEX ? ?" - - if idx.Comment != "" { - createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) - } - - if idx.Option != "" { - createTableSQL += " " + idx.Option - } - - createTableSQL += "," - values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) - } - } - - if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { - for _, rel := range stmt.Schema.Relationships.Relations { - if rel.Field.IgnoreMigration { - continue - } - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - sql, vars := constraint.Build() - createTableSQL += sql + "," - values = append(values, vars...) - } - } - } - } - - for _, uni := range stmt.Schema.ParseUniqueConstraints() { - createTableSQL += "CONSTRAINT ? UNIQUE (?)," - values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) - } - - for _, chk := range stmt.Schema.ParseCheckConstraints() { - createTableSQL += "CONSTRAINT ? CHECK (?)," - values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) - } - - createTableSQL = strings.TrimSuffix(createTableSQL, ",") - - createTableSQL += ")" - - if tableOption, ok := m.DB.Get("gorm:table_options"); ok { - createTableSQL += fmt.Sprint(tableOption) - } - - err = tx.Exec(createTableSQL, values...).Error - return err - }); err != nil { + if err := m.createTableForValue(value); err != nil { return err } } return nil } +func (m Migrator) createTableForValue(value interface{}) error { + tx := m.DB.Session(&gorm.Session{}) + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema == nil { + return errors.New("failed to get schema") + } + + createTableSQL, values := m.buildCreateTableSQL(stmt) + return tx.Exec(createTableSQL, values...).Error + }) +} + +func (m Migrator) buildCreateTableSQL(stmt *gorm.Statement) (string, []interface{}) { + createTableSQL := "CREATE TABLE ? (" + values := []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType := false + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + if !field.IgnoreMigration { + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.FullDataTypeOf(field)) + createTableSQL += "," + } + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := make([]interface{}, len(stmt.Schema.PrimaryFields)) + for i, field := range stmt.Schema.PrimaryFields { + primaryKeys[i] = clause.Column{Name: field.DBName} + } + values = append(values, primaryKeys) + } + + m.appendConstraints(stmt, &createTableSQL, &values) + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + return createTableSQL, values +} + +func (m Migrator) appendConstraints(stmt *gorm.Statement, createTableSQL *string, values *[]interface{}) { + for _, idx := range stmt.Schema.ParseIndexes() { + if m.CreateIndexAfterCreateTable { + defer m.DB.Migrator().CreateIndex(stmt.Table, idx.Name) + } else { + if idx.Class != "" { + *createTableSQL += idx.Class + " " + } + *createTableSQL += "INDEX ? ?" + if idx.Comment != "" { + *createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + if idx.Option != "" { + *createTableSQL += " " + idx.Option + } + *createTableSQL += "," + *values = append(*values, clause.Column{Name: idx.Name}, m.BuildIndexOptions(idx.Fields, stmt)) + } + } + + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema { + sql, vars := constraint.Build() + *createTableSQL += sql + "," + *values = append(*values, vars...) + } + } + } + + for _, uni := range stmt.Schema.ParseUniqueConstraints() { + *createTableSQL += "CONSTRAINT ? UNIQUE (?)," + *values = append(*values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + *createTableSQL += "CONSTRAINT ? CHECK (?)," + *values = append(*values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } +} + return nil +} + // DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false)