Handle constraint dependencies smartly
This commit is contained in:
		
							parent
							
								
									1895d281bf
								
							
						
					
					
						commit
						d3c63a03cb
					
				| @ -48,7 +48,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { | |||||||
| // AutoMigrate
 | // AutoMigrate
 | ||||||
| func (m Migrator) AutoMigrate(values ...interface{}) error { | func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||||
| 	// TODO smart migrate data type
 | 	// TODO smart migrate data type
 | ||||||
| 	for _, value := range values { | 	for _, value := range m.ReorderModels(values, true) { | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| 		if !tx.Migrator().HasTable(value) { | 		if !tx.Migrator().HasTable(value) { | ||||||
| 			if err := tx.Migrator().CreateTable(value); err != nil { | 			if err := tx.Migrator().CreateTable(value); err != nil { | ||||||
| @ -100,7 +100,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) CreateTable(values ...interface{}) error { | func (m Migrator) CreateTable(values ...interface{}) error { | ||||||
| 	for _, value := range values { | 	for _, value := range m.ReorderModels(values, false) { | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 			var ( | 			var ( | ||||||
| @ -186,7 +186,9 @@ func (m Migrator) CreateTable(values ...interface{}) error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) DropTable(values ...interface{}) error { | func (m Migrator) DropTable(values ...interface{}) error { | ||||||
| 	for _, value := range values { | 	values = m.ReorderModels(values, false) | ||||||
|  | 	for i := len(values) - 1; i >= 0; i-- { | ||||||
|  | 		value := values[i] | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 			return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error | 			return tx.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error | ||||||
| @ -475,3 +477,72 @@ func (m Migrator) CurrentDatabase() (name string) { | |||||||
| 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) | 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ReorderModels reorder models according to constraint dependencies
 | ||||||
|  | func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { | ||||||
|  | 	type Dependency struct { | ||||||
|  | 		Table   string | ||||||
|  | 		Depends []*schema.Schema | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var ( | ||||||
|  | 		modelNames, orderedModelNames []string | ||||||
|  | 		orderedModelNamesMap          = map[string]bool{} | ||||||
|  | 		valuesMap                     = map[string]*gorm.Statement{} | ||||||
|  | 		dependencies                  = map[string]Dependency{} | ||||||
|  | 		insertIntoOrderedMap          func(name string) | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	parseDependence := func(value interface{}, addToMap bool) { | ||||||
|  | 		stmt := &gorm.Statement{DB: m.DB, Dest: value} | ||||||
|  | 		stmt.Parse(value) | ||||||
|  | 		dep := Dependency{Table: stmt.Schema.Table} | ||||||
|  | 
 | ||||||
|  | 		for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
|  | 			if constraint := rel.ParseConstraint(); constraint != nil { | ||||||
|  | 				dep.Depends = append(dep.Depends, constraint.ReferenceSchema) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		dependencies[stmt.Schema.Table] = dep | ||||||
|  | 
 | ||||||
|  | 		if addToMap { | ||||||
|  | 			modelNames = append(modelNames, stmt.Schema.Table) | ||||||
|  | 			valuesMap[stmt.Schema.Table] = stmt | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, value := range values { | ||||||
|  | 		parseDependence(value, true) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	insertIntoOrderedMap = func(name string) { | ||||||
|  | 		// avoid loop
 | ||||||
|  | 		if _, ok := orderedModelNamesMap[name]; ok { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		dep := dependencies[name] | ||||||
|  | 		for _, d := range dep.Depends { | ||||||
|  | 			if _, ok := valuesMap[d.Table]; ok { | ||||||
|  | 				if _, ok := orderedModelNamesMap[d.Table]; !ok && name != d.Table { | ||||||
|  | 					insertIntoOrderedMap(d.Table) | ||||||
|  | 				} | ||||||
|  | 			} else if autoAdd { | ||||||
|  | 				parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) | ||||||
|  | 				insertIntoOrderedMap(d.Table) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		orderedModelNames = append(orderedModelNames, name) | ||||||
|  | 		orderedModelNamesMap[name] = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, name := range modelNames { | ||||||
|  | 		insertIntoOrderedMap(name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, name := range orderedModelNames { | ||||||
|  | 		results = append(results, valuesMap[name].Dest) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | |||||||
| @ -1,20 +1,20 @@ | |||||||
| package tests | package tests | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"math/rand" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMigrate(t *testing.T, db *gorm.DB) { | func TestMigrate(t *testing.T, db *gorm.DB) { | ||||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} | 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} | ||||||
|  | 	rand.Seed(time.Now().UnixNano()) | ||||||
|  | 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | ||||||
| 
 | 
 | ||||||
| 	for _, m := range allModels { | 	if err := db.Migrator().DropTable(allModels...); err != nil { | ||||||
| 		if db.Migrator().HasTable(m) { | 		t.Errorf("Failed to drop table, got error %v", err) | ||||||
| 			if err := db.Migrator().DropTable(m); err != nil { |  | ||||||
| 				t.Errorf("Failed to drop table, got error %v", err) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := db.AutoMigrate(allModels...); err != nil { | 	if err := db.AutoMigrate(allModels...); err != nil { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu