Almost finish Migrator
This commit is contained in:
		
							parent
							
								
									0be4817ff9
								
							
						
					
					
						commit
						0801cdf164
					
				| @ -33,6 +33,7 @@ type Migrator interface { | |||||||
| 	AddColumn(dst interface{}, field string) error | 	AddColumn(dst interface{}, field string) error | ||||||
| 	DropColumn(dst interface{}, field string) error | 	DropColumn(dst interface{}, field string) error | ||||||
| 	AlterColumn(dst interface{}, field string) error | 	AlterColumn(dst interface{}, field string) error | ||||||
|  | 	HasColumn(dst interface{}, field string) bool | ||||||
| 	RenameColumn(dst interface{}, oldName, field string) error | 	RenameColumn(dst interface{}, oldName, field string) error | ||||||
| 	ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) | 	ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) | ||||||
| 
 | 
 | ||||||
| @ -43,6 +44,7 @@ type Migrator interface { | |||||||
| 	// Constraints
 | 	// Constraints
 | ||||||
| 	CreateConstraint(dst interface{}, name string) error | 	CreateConstraint(dst interface{}, name string) error | ||||||
| 	DropConstraint(dst interface{}, name string) error | 	DropConstraint(dst interface{}, name string) error | ||||||
|  | 	HasConstraint(dst interface{}, name string) bool | ||||||
| 
 | 
 | ||||||
| 	// Indexes
 | 	// Indexes
 | ||||||
| 	CreateIndex(dst interface{}, name string) error | 	CreateIndex(dst interface{}, name string) error | ||||||
|  | |||||||
| @ -3,9 +3,12 @@ package migrator | |||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Migrator migrator struct
 | // Migrator migrator struct
 | ||||||
| @ -34,19 +37,133 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement | |||||||
| 
 | 
 | ||||||
| // AutoMigrate
 | // AutoMigrate
 | ||||||
| func (migrator Migrator) AutoMigrate(values ...interface{}) error { | func (migrator Migrator) AutoMigrate(values ...interface{}) error { | ||||||
| 	// if has table
 | 	// TODO smart migrate data type
 | ||||||
| 	// not -> create table
 |  | ||||||
| 	// check columns -> add column, change column type
 |  | ||||||
| 	// check foreign keys -> create indexes
 |  | ||||||
| 	// check indexes -> create indexes
 |  | ||||||
| 
 | 
 | ||||||
| 	return gorm.ErrNotImplemented | 	for _, value := range values { | ||||||
|  | 		if !migrator.DB.Migrator().HasTable(value) { | ||||||
|  | 			if err := migrator.DB.Migrator().CreateTable(value); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 				for _, field := range stmt.Schema.FieldsByDBName { | ||||||
|  | 					if !migrator.DB.Migrator().HasColumn(value, field.DBName) { | ||||||
|  | 						if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { | ||||||
|  | 							return err | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
|  | 					if constraint := rel.ParseConstraint(); constraint != nil { | ||||||
|  | 						if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { | ||||||
|  | 							if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { | ||||||
|  | 								return err | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||||
|  | 						if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { | ||||||
|  | 							if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { | ||||||
|  | 								return err | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					// create join table
 | ||||||
|  | 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 					if !migrator.DB.Migrator().HasTable(joinValue) { | ||||||
|  | 						defer migrator.DB.Migrator().CreateTable(joinValue) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CreateTable(values ...interface{}) error { | func (migrator Migrator) CreateTable(values ...interface{}) error { | ||||||
| 	// migrate
 | 	for _, value := range values { | ||||||
| 	// create join table
 | 		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 	return gorm.ErrNotImplemented | 			var ( | ||||||
|  | 				createTableSQL          = "CREATE TABLE ? (" | ||||||
|  | 				values                  = []interface{}{clause.Table{Name: stmt.Table}} | ||||||
|  | 				hasPrimaryKeyInDataType bool | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			for _, dbName := range stmt.Schema.DBNames { | ||||||
|  | 				field := stmt.Schema.FieldsByDBName[dbName] | ||||||
|  | 				createTableSQL += fmt.Sprintf("? ?") | ||||||
|  | 				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") | ||||||
|  | 				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) | ||||||
|  | 
 | ||||||
|  | 				if field.AutoIncrement { | ||||||
|  | 					createTableSQL += " AUTO_INCREMENT" | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if field.NotNull { | ||||||
|  | 					createTableSQL += " NOT NULL" | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if field.Unique { | ||||||
|  | 					createTableSQL += " UNIQUE" | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if field.DefaultValue != "" { | ||||||
|  | 					createTableSQL += " DEFAULT ?" | ||||||
|  | 					values = append(values, clause.Expr{SQL: field.DefaultValue}) | ||||||
|  | 				} | ||||||
|  | 				createTableSQL += "," | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !hasPrimaryKeyInDataType { | ||||||
|  | 				createTableSQL += "PRIMARY KEY ?," | ||||||
|  | 				primaryKeys := []interface{}{} | ||||||
|  | 				for _, field := range stmt.Schema.PrimaryFields { | ||||||
|  | 					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				values = append(values, primaryKeys) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for _, idx := range stmt.Schema.ParseIndexes() { | ||||||
|  | 				createTableSQL += "INDEX ? ?," | ||||||
|  | 				values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
|  | 				if constraint := rel.ParseConstraint(); constraint != nil { | ||||||
|  | 					sql, vars := buildConstraint(constraint) | ||||||
|  | 					createTableSQL += sql + "," | ||||||
|  | 					values = append(values, vars...) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// create join table
 | ||||||
|  | 				joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
|  | 				if !migrator.DB.Migrator().HasTable(joinValue) { | ||||||
|  | 					defer migrator.DB.Migrator().CreateTable(joinValue) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||||
|  | 				createTableSQL += "CONSTRAINT ? CHECK ?," | ||||||
|  | 				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			createTableSQL = strings.TrimSuffix(createTableSQL, ",") | ||||||
|  | 
 | ||||||
|  | 			createTableSQL += ")" | ||||||
|  | 			return migrator.DB.Exec(createTableSQL, values...).Error | ||||||
|  | 		}); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropTable(values ...interface{}) error { | func (migrator Migrator) DropTable(values ...interface{}) error { | ||||||
| @ -115,6 +232,27 @@ func (migrator Migrator) AlterColumn(value interface{}, field string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (migrator Migrator) HasColumn(value interface{}, field string) bool { | ||||||
|  | 	var count int64 | ||||||
|  | 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		currentDatabase := migrator.DB.Migrator().CurrentDatabase() | ||||||
|  | 		name := field | ||||||
|  | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
|  | 			name = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return migrator.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", | ||||||
|  | 			currentDatabase, stmt.Table, name, | ||||||
|  | 		).Scan(&count).Error | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if count != 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { | func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| @ -140,6 +278,28 @@ func (migrator Migrator) DropView(name string) error { | |||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { | ||||||
|  | 	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" | ||||||
|  | 	if constraint.OnDelete != "" { | ||||||
|  | 		sql += " ON DELETE " + constraint.OnDelete | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if constraint.OnUpdate != "" { | ||||||
|  | 		sql += " ON UPDATE  " + constraint.OnUpdate | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var foreignKeys, references []interface{} | ||||||
|  | 	for _, field := range constraint.ForeignKeys { | ||||||
|  | 		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, field := range constraint.References { | ||||||
|  | 		references = append(references, clause.Column{Name: field.DBName}) | ||||||
|  | 	} | ||||||
|  | 	results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (migrator Migrator) CreateConstraint(value interface{}, name string) error { | func (migrator Migrator) CreateConstraint(value interface{}, name string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		checkConstraints := stmt.Schema.ParseCheckConstraints() | 		checkConstraints := stmt.Schema.ParseCheckConstraints() | ||||||
| @ -152,26 +312,8 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error | |||||||
| 
 | 
 | ||||||
| 		for _, rel := range stmt.Schema.Relationships.Relations { | 		for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { | 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { | ||||||
| 				sql := "ALTER TABLE ? ADD CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" | 				sql, values := buildConstraint(constraint) | ||||||
| 				if constraint.OnDelete != "" { | 				return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error | ||||||
| 					sql += " ON DELETE " + constraint.OnDelete |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if constraint.OnUpdate != "" { |  | ||||||
| 					sql += " ON UPDATE  " + constraint.OnUpdate |  | ||||||
| 				} |  | ||||||
| 				var foreignKeys, references []interface{} |  | ||||||
| 				for _, field := range constraint.ForeignKeys { |  | ||||||
| 					foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				for _, field := range constraint.References { |  | ||||||
| 					references = append(references, clause.Column{Name: field.DBName}) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				return migrator.DB.Exec( |  | ||||||
| 					sql, clause.Table{Name: stmt.Table}, clause.Column{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references, |  | ||||||
| 				).Error |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -205,27 +347,47 @@ func (migrator Migrator) DropConstraint(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (migrator Migrator) HasConstraint(value interface{}, name string) bool { | ||||||
|  | 	var count int64 | ||||||
|  | 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		currentDatabase := migrator.DB.Migrator().CurrentDatabase() | ||||||
|  | 		return migrator.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", | ||||||
|  | 			currentDatabase, stmt.Table, name, | ||||||
|  | 		).Scan(&count).Error | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if count != 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		str := stmt.Quote(opt.DBName) | ||||||
|  | 		if opt.Expression != "" { | ||||||
|  | 			str = opt.Expression | ||||||
|  | 		} else if opt.Length > 0 { | ||||||
|  | 			str += fmt.Sprintf("(%d)", opt.Length) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if opt.Sort != "" { | ||||||
|  | 			str += " " + opt.Sort | ||||||
|  | 		} | ||||||
|  | 		results = append(results, clause.Expr{SQL: str}) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (migrator Migrator) CreateIndex(value interface{}, name string) error { | func (migrator Migrator) CreateIndex(value interface{}, name string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		err := fmt.Errorf("failed to create index with name %v", name) | 		err := fmt.Errorf("failed to create index with name %v", name) | ||||||
| 		indexes := stmt.Schema.ParseIndexes() | 		indexes := stmt.Schema.ParseIndexes() | ||||||
| 
 | 
 | ||||||
| 		if idx, ok := indexes[name]; ok { | 		if idx, ok := indexes[name]; ok { | ||||||
| 			fields := []interface{}{} | 			opts := buildIndexOptions(idx.Fields, stmt) | ||||||
| 			for _, field := range idx.Fields { | 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||||
| 				str := stmt.Quote(field.DBName) |  | ||||||
| 				if field.Expression != "" { |  | ||||||
| 					str = field.Expression |  | ||||||
| 				} else if field.Length > 0 { |  | ||||||
| 					str += fmt.Sprintf("(%d)", field.Length) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if field.Sort != "" { |  | ||||||
| 					str += " " + field.Sort |  | ||||||
| 				} |  | ||||||
| 				fields = append(fields, clause.Expr{SQL: str}) |  | ||||||
| 			} |  | ||||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, fields} |  | ||||||
| 
 | 
 | ||||||
| 			createIndexSQL := "CREATE " | 			createIndexSQL := "CREATE " | ||||||
| 			if idx.Class != "" { | 			if idx.Class != "" { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu