Add DataTypeOf for dialector
This commit is contained in:
		
							parent
							
								
									0801cdf164
								
							
						
					
					
						commit
						fab7d96da5
					
				
							
								
								
									
										37
									
								
								dialects/mssql/migrator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								dialects/mssql/migrator.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | |||||||
|  | package mssql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Migrator struct { | ||||||
|  | 	migrator.Migrator | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||||
|  | 	var count int | ||||||
|  | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		return m.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", | ||||||
|  | 			name, stmt.Table, | ||||||
|  | 		).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasConstraint(value interface{}, name string) bool { | ||||||
|  | 	var count int64 | ||||||
|  | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		return m.DB.Raw( | ||||||
|  | 			`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ?  AND T.Name = ? AND I.TABLE_CATALOG = ?;`, | ||||||
|  | 			name, stmt.Table, m.CurrentDatabase(), | ||||||
|  | 		).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CurrentDatabase() (name string) { | ||||||
|  | 	m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										75
									
								
								dialects/mssql/mssql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								dialects/mssql/mssql.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,75 @@ | |||||||
|  | package mssql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	_ "github.com/denisenkom/go-mssqldb" | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/callbacks" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Dialector struct { | ||||||
|  | 	DSN string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Open(dsn string) gorm.Dialector { | ||||||
|  | 	return &Dialector{DSN: dsn} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||||
|  | 	// register callbacks
 | ||||||
|  | 	callbacks.RegisterDefaultCallbacks(db) | ||||||
|  | 
 | ||||||
|  | 	db.DB, err = sql.Open("sqlserver", dialector.DSN) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||||
|  | 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||||
|  | 	return "?" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) QuoteChars() [2]byte { | ||||||
|  | 	return [2]byte{'[', ']'} // `name`
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||||
|  | 	switch field.DataType { | ||||||
|  | 	case schema.Bool: | ||||||
|  | 		return "bit" | ||||||
|  | 	case schema.Int, schema.Uint: | ||||||
|  | 		var sqlType string | ||||||
|  | 		switch { | ||||||
|  | 		case field.Size < 16: | ||||||
|  | 			sqlType = "smallint" | ||||||
|  | 		case field.Size < 31: | ||||||
|  | 			sqlType = "int" | ||||||
|  | 		default: | ||||||
|  | 			sqlType = "bigint" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field.AutoIncrement { | ||||||
|  | 			return sqlType + " IDENTITY(1,1)" | ||||||
|  | 		} | ||||||
|  | 		return sqlType | ||||||
|  | 	case schema.Float: | ||||||
|  | 		return "decimal" | ||||||
|  | 	case schema.String: | ||||||
|  | 		if field.Size > 0 && field.Size <= 4000 { | ||||||
|  | 			return fmt.Sprintf("nvarchar(%d)", field.Size) | ||||||
|  | 		} | ||||||
|  | 		return "ntext" | ||||||
|  | 	case schema.Time: | ||||||
|  | 		return "datetimeoffset" | ||||||
|  | 	case schema.Bytes: | ||||||
|  | 		return "binary" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
							
								
								
									
										43
									
								
								dialects/mysql/migrator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								dialects/mysql/migrator.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | |||||||
|  | package mysql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Migrator struct { | ||||||
|  | 	migrator.Migrator | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) AlterColumn(value interface{}, field string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
|  | 			return m.DB.Exec( | ||||||
|  | 				"ALTER TABLE ? MODIFY COLUMN ? TYPE ?", | ||||||
|  | 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, | ||||||
|  | 			).Error | ||||||
|  | 		} | ||||||
|  | 		return fmt.Errorf("failed to look up field with name: %s", field) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) DropConstraint(value interface{}, name string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||||
|  | 			if chk.Name == name { | ||||||
|  | 				return m.DB.Exec( | ||||||
|  | 					"ALTER TABLE ? DROP CHECK ?", | ||||||
|  | 					clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||||
|  | 				).Error | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return m.DB.Exec( | ||||||
|  | 			"ALTER TABLE ? DROP FOREIGN KEY ?", | ||||||
|  | 			clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||||
|  | 		).Error | ||||||
|  | 	}) | ||||||
|  | } | ||||||
| @ -1,33 +1,104 @@ | |||||||
| package mysql | package mysql | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 
 | ||||||
| 	_ "github.com/go-sql-driver/mysql" | 	_ "github.com/go-sql-driver/mysql" | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/callbacks" | 	"github.com/jinzhu/gorm/callbacks" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Dialector struct { | type Dialector struct { | ||||||
|  | 	DSN string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Open(dsn string) gorm.Dialector { | func Open(dsn string) gorm.Dialector { | ||||||
| 	return &Dialector{} | 	return &Dialector{DSN: dsn} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) Initialize(db *gorm.DB) error { | func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||||
| 	// register callbacks
 | 	// register callbacks
 | ||||||
| 	callbacks.RegisterDefaultCallbacks(db) | 	callbacks.RegisterDefaultCallbacks(db) | ||||||
|  | 	db.DB, err = sql.Open("sqlite3", dialector.DSN) | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) Migrator() gorm.Migrator { | func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||||
| 	return nil | 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { | func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||||
| 	return "?" | 	return "?" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) QuoteChars() [2]byte { | func (dialector Dialector) QuoteChars() [2]byte { | ||||||
| 	return [2]byte{'`', '`'} // `name`
 | 	return [2]byte{'`', '`'} // `name`
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||||
|  | 	switch field.DataType { | ||||||
|  | 	case schema.Bool: | ||||||
|  | 		return "boolean" | ||||||
|  | 	case schema.Int, schema.Uint: | ||||||
|  | 		sqlType := "int" | ||||||
|  | 		switch { | ||||||
|  | 		case field.Size <= 8: | ||||||
|  | 			sqlType = "tinyint" | ||||||
|  | 		case field.Size <= 16: | ||||||
|  | 			sqlType = "smallint" | ||||||
|  | 		case field.Size <= 32: | ||||||
|  | 			sqlType = "int" | ||||||
|  | 		default: | ||||||
|  | 			sqlType = "bigint" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field.DataType == schema.Uint { | ||||||
|  | 			sqlType += " unsigned" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field.AutoIncrement { | ||||||
|  | 			sqlType += " AUTO_INCREMENT" | ||||||
|  | 		} | ||||||
|  | 		return sqlType | ||||||
|  | 	case schema.Float: | ||||||
|  | 		if field.Size <= 32 { | ||||||
|  | 			return "float" | ||||||
|  | 		} | ||||||
|  | 		return "double" | ||||||
|  | 	case schema.String: | ||||||
|  | 		size := field.Size | ||||||
|  | 		if size >= 65536 && size <= int(math.Pow(2, 24)) { | ||||||
|  | 			return "mediumtext" | ||||||
|  | 		} else if size > int(math.Pow(2, 24)) || size < 0 { | ||||||
|  | 			return "longtext" | ||||||
|  | 		} | ||||||
|  | 		return fmt.Sprintf("varchar(%d)", size) | ||||||
|  | 	case schema.Time: | ||||||
|  | 		precision := "" | ||||||
|  | 		if field.Precision > 0 { | ||||||
|  | 			precision = fmt.Sprintf("(%d)", field.Precision) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field.NotNull || field.PrimaryKey { | ||||||
|  | 			return "datetime" + precision | ||||||
|  | 		} | ||||||
|  | 		return "datetime" + precision + " NULL" | ||||||
|  | 	case schema.Bytes: | ||||||
|  | 		if field.Size > 0 && field.Size < 65536 { | ||||||
|  | 			return fmt.Sprintf("varbinary(%d)", field.Size) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { | ||||||
|  | 			return "mediumblob" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return "longblob" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										89
									
								
								dialects/postgres/migrator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								dialects/postgres/migrator.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,89 @@ | |||||||
|  | package postgres | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Migrator struct { | ||||||
|  | 	migrator.Migrator | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CurrentDatabase() (name string) { | ||||||
|  | 	m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		str := stmt.Quote(opt.DBName) | ||||||
|  | 		if opt.Expression != "" { | ||||||
|  | 			str = opt.Expression | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if opt.Collate != "" { | ||||||
|  | 			str += " COLLATE " + opt.Collate | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if opt.Sort != "" { | ||||||
|  | 			str += " " + opt.Sort | ||||||
|  | 		} | ||||||
|  | 		results = append(results, clause.Expr{SQL: str}) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasIndex(value interface{}, indexName string) bool { | ||||||
|  | 	var count int64 | ||||||
|  | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		return m.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, indexName, | ||||||
|  | 		).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		err := fmt.Errorf("failed to create index with name %v", name) | ||||||
|  | 		indexes := stmt.Schema.ParseIndexes() | ||||||
|  | 
 | ||||||
|  | 		if idx, ok := indexes[name]; ok { | ||||||
|  | 			opts := m.BuildIndexOptions(idx.Fields, stmt) | ||||||
|  | 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||||
|  | 
 | ||||||
|  | 			createIndexSQL := "CREATE " | ||||||
|  | 			if idx.Class != "" { | ||||||
|  | 				createIndexSQL += idx.Class + " " | ||||||
|  | 			} | ||||||
|  | 			createIndexSQL += "INDEX ?" | ||||||
|  | 
 | ||||||
|  | 			if idx.Type != "" { | ||||||
|  | 				createIndexSQL += " USING " + idx.Type | ||||||
|  | 			} | ||||||
|  | 			createIndexSQL += " ON ??" | ||||||
|  | 
 | ||||||
|  | 			if idx.Where != "" { | ||||||
|  | 				createIndexSQL += " WHERE " + idx.Where | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return m.DB.Exec(createIndexSQL, values...).Error | ||||||
|  | 		} else if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
|  | 			for _, idx := range indexes { | ||||||
|  | 				for _, idxOpt := range idx.Fields { | ||||||
|  | 					if idxOpt.Field == field { | ||||||
|  | 						if err = m.CreateIndex(value, idx.Name); err != nil { | ||||||
|  | 							return err | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | } | ||||||
| @ -2,9 +2,12 @@ package postgres | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/callbacks" | 	"github.com/jinzhu/gorm/callbacks" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| 	_ "github.com/lib/pq" | 	_ "github.com/lib/pq" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -24,14 +27,54 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) Migrator() gorm.Migrator { | func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||||
| 	return nil | 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||||
| 	return "?" | 	return "?" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) QuoteChars() [2]byte { | func (dialector Dialector) QuoteChars() [2]byte { | ||||||
| 	return [2]byte{'"', '"'} // "name"
 | 	return [2]byte{'"', '"'} // "name"
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||||
|  | 	switch field.DataType { | ||||||
|  | 	case schema.Bool: | ||||||
|  | 		return "boolean" | ||||||
|  | 	case schema.Int, schema.Uint: | ||||||
|  | 		if field.AutoIncrement { | ||||||
|  | 			switch { | ||||||
|  | 			case field.Size < 16: | ||||||
|  | 				return "smallserial" | ||||||
|  | 			case field.Size < 31: | ||||||
|  | 				return "serial" | ||||||
|  | 			default: | ||||||
|  | 				return "bigserial" | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			switch { | ||||||
|  | 			case field.Size < 16: | ||||||
|  | 				return "smallint" | ||||||
|  | 			case field.Size < 31: | ||||||
|  | 				return "integer" | ||||||
|  | 			default: | ||||||
|  | 				return "bigint" | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	case schema.Float: | ||||||
|  | 		return "decimal" | ||||||
|  | 	case schema.String: | ||||||
|  | 		if field.Size > 0 { | ||||||
|  | 			return fmt.Sprintf("varchar(%d)", field.Size) | ||||||
|  | 		} | ||||||
|  | 		return "text" | ||||||
|  | 	case schema.Time: | ||||||
|  | 		return "timestamp with time zone" | ||||||
|  | 	case schema.Bytes: | ||||||
|  | 		return "bytea" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										122
									
								
								dialects/sqlite/migrator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								dialects/sqlite/migrator.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,122 @@ | |||||||
|  | package sqlite | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	"github.com/jinzhu/gorm/clause" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Migrator struct { | ||||||
|  | 	migrator.Migrator | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasTable(value interface{}) bool { | ||||||
|  | 	var count int | ||||||
|  | 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||||
|  | 	var count int | ||||||
|  | 	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		name := field | ||||||
|  | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
|  | 			name = field.DBName | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return m.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", | ||||||
|  | 			stmt.Table, `%"`+name+`" %`, `%`+name+` %`, | ||||||
|  | 		).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||||
|  | 	var count int | ||||||
|  | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		return m.DB.Raw( | ||||||
|  | 			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE ?", | ||||||
|  | 			stmt.Table, "%INDEX "+name+" ON%", | ||||||
|  | 		).Row().Scan(&count) | ||||||
|  | 	}) | ||||||
|  | 	return count > 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CreateConstraint(interface{}, string) error { | ||||||
|  | 	return gorm.ErrNotImplemented | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) DropConstraint(interface{}, string) error { | ||||||
|  | 	return gorm.ErrNotImplemented | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CurrentDatabase() (name string) { | ||||||
|  | 	var null interface{} | ||||||
|  | 	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		str := stmt.Quote(opt.DBName) | ||||||
|  | 		if opt.Expression != "" { | ||||||
|  | 			str = opt.Expression | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if opt.Collate != "" { | ||||||
|  | 			str += " COLLATE " + opt.Collate | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if opt.Sort != "" { | ||||||
|  | 			str += " " + opt.Sort | ||||||
|  | 		} | ||||||
|  | 		results = append(results, clause.Expr{SQL: str}) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
|  | 		err := fmt.Errorf("failed to create index with name %v", name) | ||||||
|  | 		indexes := stmt.Schema.ParseIndexes() | ||||||
|  | 
 | ||||||
|  | 		if idx, ok := indexes[name]; ok { | ||||||
|  | 			opts := m.BuildIndexOptions(idx.Fields, stmt) | ||||||
|  | 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||||
|  | 
 | ||||||
|  | 			createIndexSQL := "CREATE " | ||||||
|  | 			if idx.Class != "" { | ||||||
|  | 				createIndexSQL += idx.Class + " " | ||||||
|  | 			} | ||||||
|  | 			createIndexSQL += "INDEX ?" | ||||||
|  | 
 | ||||||
|  | 			if idx.Type != "" { | ||||||
|  | 				createIndexSQL += " USING " + idx.Type | ||||||
|  | 			} | ||||||
|  | 			createIndexSQL += " ON ??" | ||||||
|  | 
 | ||||||
|  | 			if idx.Where != "" { | ||||||
|  | 				createIndexSQL += " WHERE " + idx.Where | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return m.DB.Exec(createIndexSQL, values...).Error | ||||||
|  | 		} else if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
|  | 			for _, idx := range indexes { | ||||||
|  | 				for _, idxOpt := range idx.Fields { | ||||||
|  | 					if idxOpt.Field == field { | ||||||
|  | 						if err = m.CreateIndex(value, idx.Name); err != nil { | ||||||
|  | 							return err | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | } | ||||||
| @ -5,6 +5,8 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/callbacks" | 	"github.com/jinzhu/gorm/callbacks" | ||||||
|  | 	"github.com/jinzhu/gorm/migrator" | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| 	_ "github.com/mattn/go-sqlite3" | 	_ "github.com/mattn/go-sqlite3" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -24,14 +26,36 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) Migrator() gorm.Migrator { | func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||||
| 	return nil | 	return Migrator{migrator.Migrator{Config: migrator.Config{DB: db}}} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | func (dialector Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||||
| 	return "?" | 	return "?" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (Dialector) QuoteChars() [2]byte { | func (dialector Dialector) QuoteChars() [2]byte { | ||||||
| 	return [2]byte{'`', '`'} // `name`
 | 	return [2]byte{'`', '`'} // `name`
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||||
|  | 	switch field.DataType { | ||||||
|  | 	case schema.Bool: | ||||||
|  | 		return "NUMERIC" | ||||||
|  | 	case schema.Int, schema.Uint: | ||||||
|  | 		if field.AutoIncrement { | ||||||
|  | 			// https://www.sqlite.org/autoinc.html
 | ||||||
|  | 			return "INTEGER PRIMARY KEY AUTOINCREMENT" | ||||||
|  | 		} else { | ||||||
|  | 			return "INTEGER" | ||||||
|  | 		} | ||||||
|  | 	case schema.Float: | ||||||
|  | 		return "REAL" | ||||||
|  | 	case schema.String, schema.Time: | ||||||
|  | 		return "TEXT" | ||||||
|  | 	case schema.Bytes: | ||||||
|  | 		return "BLOB" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | |||||||
| @ -3,12 +3,15 @@ package gorm | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Dialector GORM database dialector
 | // Dialector GORM database dialector
 | ||||||
| type Dialector interface { | type Dialector interface { | ||||||
| 	Initialize(*DB) error | 	Initialize(*DB) error | ||||||
| 	Migrator() Migrator | 	Migrator(db *DB) Migrator | ||||||
|  | 	DataTypeOf(*schema.Field) string | ||||||
| 	BindVar(stmt *Statement, v interface{}) string | 	BindVar(stmt *Statement, v interface{}) string | ||||||
| 	QuoteChars() [2]byte | 	QuoteChars() [2]byte | ||||||
| } | } | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| // Migrator returns migrator
 | // Migrator returns migrator
 | ||||||
| func (db *DB) Migrator() Migrator { | func (db *DB) Migrator() Migrator { | ||||||
| 	return db.Dialector.Migrator() | 	return db.Dialector.Migrator(db) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ViewOption view option
 | // ViewOption view option
 | ||||||
| @ -26,7 +26,7 @@ type Migrator interface { | |||||||
| 	// Tables
 | 	// Tables
 | ||||||
| 	CreateTable(dst ...interface{}) error | 	CreateTable(dst ...interface{}) error | ||||||
| 	DropTable(dst ...interface{}) error | 	DropTable(dst ...interface{}) error | ||||||
| 	HasTable(dst ...interface{}) bool | 	HasTable(dst interface{}) bool | ||||||
| 	RenameTable(oldName, newName string) error | 	RenameTable(oldName, newName string) error | ||||||
| 
 | 
 | ||||||
| 	// Columns
 | 	// Columns
 | ||||||
|  | |||||||
| @ -11,21 +11,21 @@ import ( | |||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Migrator migrator struct
 | // Migrator m struct
 | ||||||
| type Migrator struct { | type Migrator struct { | ||||||
| 	*Config | 	Config | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Config schema config
 | // Config schema config
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	CheckExistsBeforeDropping bool | 	DB *gorm.DB | ||||||
| 	DB                        *gorm.DB | 	gorm.Dialector | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | ||||||
| 	stmt := migrator.DB.Statement | 	stmt := m.DB.Statement | ||||||
| 	if stmt == nil { | 	if stmt == nil { | ||||||
| 		stmt = &gorm.Statement{DB: migrator.DB} | 		stmt = &gorm.Statement{DB: m.DB} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := stmt.Parse(value); err != nil { | 	if err := stmt.Parse(value); err != nil { | ||||||
| @ -35,20 +35,28 @@ func (migrator Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement | |||||||
| 	return fc(stmt) | 	return fc(stmt) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m Migrator) DataTypeOf(field *schema.Field) string { | ||||||
|  | 	if field.DBDataType != "" { | ||||||
|  | 		return field.DBDataType | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return m.Dialector.DataTypeOf(field) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // AutoMigrate
 | // AutoMigrate
 | ||||||
| func (migrator Migrator) AutoMigrate(values ...interface{}) error { | func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||||
| 	// TODO smart migrate data type
 | 	// TODO smart migrate data type
 | ||||||
| 
 | 
 | ||||||
| 	for _, value := range values { | 	for _, value := range values { | ||||||
| 		if !migrator.DB.Migrator().HasTable(value) { | 		if !m.DB.Migrator().HasTable(value) { | ||||||
| 			if err := migrator.DB.Migrator().CreateTable(value); err != nil { | 			if err := m.DB.Migrator().CreateTable(value); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 			if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 				for _, field := range stmt.Schema.FieldsByDBName { | 				for _, field := range stmt.Schema.FieldsByDBName { | ||||||
| 					if !migrator.DB.Migrator().HasColumn(value, field.DBName) { | 					if !m.DB.Migrator().HasColumn(value, field.DBName) { | ||||||
| 						if err := migrator.DB.Migrator().AddColumn(value, field.DBName); err != nil { | 						if err := m.DB.Migrator().AddColumn(value, field.DBName); err != nil { | ||||||
| 							return err | 							return err | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| @ -56,16 +64,16 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 
 | 
 | ||||||
| 				for _, rel := range stmt.Schema.Relationships.Relations { | 				for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| 					if constraint := rel.ParseConstraint(); constraint != nil { | 					if constraint := rel.ParseConstraint(); constraint != nil { | ||||||
| 						if !migrator.DB.Migrator().HasConstraint(value, constraint.Name) { | 						if !m.DB.Migrator().HasConstraint(value, constraint.Name) { | ||||||
| 							if err := migrator.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { | 							if err := m.DB.Migrator().CreateConstraint(value, constraint.Name); err != nil { | ||||||
| 								return err | 								return err | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					for _, chk := range stmt.Schema.ParseCheckConstraints() { | 					for _, chk := range stmt.Schema.ParseCheckConstraints() { | ||||||
| 						if !migrator.DB.Migrator().HasConstraint(value, chk.Name) { | 						if !m.DB.Migrator().HasConstraint(value, chk.Name) { | ||||||
| 							if err := migrator.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { | 							if err := m.DB.Migrator().CreateConstraint(value, chk.Name); err != nil { | ||||||
| 								return err | 								return err | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| @ -73,8 +81,8 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 
 | 
 | ||||||
| 					// create join table
 | 					// create join table
 | ||||||
| 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
| 					if !migrator.DB.Migrator().HasTable(joinValue) { | 					if !m.DB.Migrator().HasTable(joinValue) { | ||||||
| 						defer migrator.DB.Migrator().CreateTable(joinValue) | 						defer m.DB.Migrator().CreateTable(joinValue) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 				return nil | 				return nil | ||||||
| @ -87,9 +95,9 @@ func (migrator Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CreateTable(values ...interface{}) error { | func (m Migrator) CreateTable(values ...interface{}) error { | ||||||
| 	for _, value := range values { | 	for _, value := range values { | ||||||
| 		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 			var ( | 			var ( | ||||||
| 				createTableSQL          = "CREATE TABLE ? (" | 				createTableSQL          = "CREATE TABLE ? (" | ||||||
| 				values                  = []interface{}{clause.Table{Name: stmt.Table}} | 				values                  = []interface{}{clause.Table{Name: stmt.Table}} | ||||||
| @ -100,7 +108,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { | |||||||
| 				field := stmt.Schema.FieldsByDBName[dbName] | 				field := stmt.Schema.FieldsByDBName[dbName] | ||||||
| 				createTableSQL += fmt.Sprintf("? ?") | 				createTableSQL += fmt.Sprintf("? ?") | ||||||
| 				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") | 				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(field.DBDataType), "PRIMARY KEY") | ||||||
| 				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: field.DBDataType}) | 				values = append(values, clause.Column{Name: dbName}, clause.Expr{SQL: m.DataTypeOf(field)}) | ||||||
| 
 | 
 | ||||||
| 				if field.AutoIncrement { | 				if field.AutoIncrement { | ||||||
| 					createTableSQL += " AUTO_INCREMENT" | 					createTableSQL += " AUTO_INCREMENT" | ||||||
| @ -133,7 +141,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { | |||||||
| 
 | 
 | ||||||
| 			for _, idx := range stmt.Schema.ParseIndexes() { | 			for _, idx := range stmt.Schema.ParseIndexes() { | ||||||
| 				createTableSQL += "INDEX ? ?," | 				createTableSQL += "INDEX ? ?," | ||||||
| 				values = append(values, clause.Expr{SQL: idx.Name}, buildIndexOptions(idx.Fields, stmt)) | 				values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| @ -145,8 +153,8 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { | |||||||
| 
 | 
 | ||||||
| 				// create join table
 | 				// create join table
 | ||||||
| 				joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | 				joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||||
| 				if !migrator.DB.Migrator().HasTable(joinValue) { | 				if !m.DB.Migrator().HasTable(joinValue) { | ||||||
| 					defer migrator.DB.Migrator().CreateTable(joinValue) | 					defer m.DB.Migrator().CreateTable(joinValue) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -158,7 +166,7 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { | |||||||
| 			createTableSQL = strings.TrimSuffix(createTableSQL, ",") | 			createTableSQL = strings.TrimSuffix(createTableSQL, ",") | ||||||
| 
 | 
 | ||||||
| 			createTableSQL += ")" | 			createTableSQL += ")" | ||||||
| 			return migrator.DB.Exec(createTableSQL, values...).Error | 			return m.DB.Exec(createTableSQL, values...).Error | ||||||
| 		}); err != nil { | 		}); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -166,10 +174,10 @@ func (migrator Migrator) CreateTable(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropTable(values ...interface{}) error { | func (m Migrator) DropTable(values ...interface{}) error { | ||||||
| 	for _, value := range values { | 	for _, value := range values { | ||||||
| 		if err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 			return migrator.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error | 			return m.DB.Exec("DROP TABLE ?", clause.Table{Name: stmt.Table}).Error | ||||||
| 		}); err != nil { | 		}); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @ -177,42 +185,36 @@ func (migrator Migrator) DropTable(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) HasTable(values ...interface{}) bool { | func (m Migrator) HasTable(value interface{}) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	for _, value := range values { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		err := migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 		currentDatabase := m.DB.Migrator().CurrentDatabase() | ||||||
| 			currentDatabase := migrator.DB.Migrator().CurrentDatabase() | 		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) | ||||||
| 			return migrator.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Scan(&count).Error | 	}) | ||||||
| 		}) |  | ||||||
| 
 | 
 | ||||||
| 		if err != nil || count == 0 { | 	return count > 0 | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return true |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) RenameTable(oldName, newName string) error { | func (m Migrator) RenameTable(oldName, newName string) error { | ||||||
| 	return migrator.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error | 	return m.DB.Exec("RENAME TABLE ? TO ?", oldName, newName).Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) AddColumn(value interface{}, field string) error { | func (m Migrator) AddColumn(value interface{}, field string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| 			return migrator.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? ADD ? ?", | 				"ALTER TABLE ? ADD ? ?", | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, | 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, | ||||||
| 			).Error | 			).Error | ||||||
| 		} | 		} | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | 		return fmt.Errorf("failed to look up field with name: %s", field) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropColumn(value interface{}, field string) error { | func (m Migrator) DropColumn(value interface{}, field string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| 			return migrator.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, | 				"ALTER TABLE ? DROP COLUMN ?", clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, | ||||||
| 			).Error | 			).Error | ||||||
| 		} | 		} | ||||||
| @ -220,44 +222,41 @@ func (migrator Migrator) DropColumn(value interface{}, field string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) AlterColumn(value interface{}, field string) error { | func (m Migrator) AlterColumn(value interface{}, field string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| 			return migrator.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? ALTER COLUMN ? TYPE ?", | 				"ALTER TABLE ? ALTER COLUMN ? TYPE ?", | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DBDataType}, | 				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, clause.Expr{SQL: m.DataTypeOf(field)}, | ||||||
| 			).Error | 			).Error | ||||||
| 		} | 		} | ||||||
| 		return fmt.Errorf("failed to look up field with name: %s", field) | 		return fmt.Errorf("failed to look up field with name: %s", field) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) HasColumn(value interface{}, field string) bool { | func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		currentDatabase := migrator.DB.Migrator().CurrentDatabase() | 		currentDatabase := m.DB.Migrator().CurrentDatabase() | ||||||
| 		name := field | 		name := field | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| 			name = field.DBName | 			name = field.DBName | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return migrator.DB.Raw( | 		return m.DB.Raw( | ||||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", | 			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", | ||||||
| 			currentDatabase, stmt.Table, name, | 			currentDatabase, stmt.Table, name, | ||||||
| 		).Scan(&count).Error | 		).Row().Scan(&count) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if count != 0 { | 	return count > 0 | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) error { | func (m Migrator) RenameColumn(value interface{}, oldName, field string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| 			oldName = migrator.DB.NamingStrategy.ColumnName(stmt.Table, oldName) | 			oldName = m.DB.NamingStrategy.ColumnName(stmt.Table, oldName) | ||||||
| 			return migrator.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? RENAME COLUMN ? TO ?", | 				"ALTER TABLE ? RENAME COLUMN ? TO ?", | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, | 				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: field.DBName}, | ||||||
| 			).Error | 			).Error | ||||||
| @ -266,15 +265,15 @@ func (migrator Migrator) RenameColumn(value interface{}, oldName, field string) | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { | func (m Migrator) ColumnTypes(value interface{}) ([]*sql.ColumnType, error) { | ||||||
| 	return nil, gorm.ErrNotImplemented | 	return nil, gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CreateView(name string, option gorm.ViewOption) error { | func (m Migrator) CreateView(name string, option gorm.ViewOption) error { | ||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropView(name string) error { | func (m Migrator) DropView(name string) error { | ||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -300,11 +299,11 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CreateConstraint(value interface{}, name string) error { | func (m Migrator) CreateConstraint(value interface{}, name string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		checkConstraints := stmt.Schema.ParseCheckConstraints() | 		checkConstraints := stmt.Schema.ParseCheckConstraints() | ||||||
| 		if chk, ok := checkConstraints[name]; ok { | 		if chk, ok := checkConstraints[name]; ok { | ||||||
| 			return migrator.DB.Exec( | 			return m.DB.Exec( | ||||||
| 				"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", | 				"ALTER TABLE ? ADD CONSTRAINT ? CHECK ?", | ||||||
| 				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, | 				clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, | ||||||
| 			).Error | 			).Error | ||||||
| @ -313,21 +312,21 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error | |||||||
| 		for _, rel := range stmt.Schema.Relationships.Relations { | 		for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { | 			if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { | ||||||
| 				sql, values := buildConstraint(constraint) | 				sql, values := buildConstraint(constraint) | ||||||
| 				return migrator.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error | 				return m.DB.Exec("ALTER TABLE ? ADD "+sql, append([]interface{}{clause.Table{Name: stmt.Table}}, values...)...).Error | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		err := fmt.Errorf("failed to create constraint with name %v", name) | 		err := fmt.Errorf("failed to create constraint with name %v", name) | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| 			for _, cc := range checkConstraints { | 			for _, cc := range checkConstraints { | ||||||
| 				if err = migrator.CreateIndex(value, cc.Name); err != nil { | 				if err = m.CreateIndex(value, cc.Name); err != nil { | ||||||
| 					return err | 					return err | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||||
| 				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { | 				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { | ||||||
| 					if err = migrator.CreateIndex(value, constraint.Name); err != nil { | 					if err = m.CreateIndex(value, constraint.Name); err != nil { | ||||||
| 						return err | 						return err | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| @ -338,32 +337,29 @@ func (migrator Migrator) CreateConstraint(value interface{}, name string) error | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropConstraint(value interface{}, name string) error { | func (m Migrator) DropConstraint(value interface{}, name string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		return migrator.DB.Exec( | 		return m.DB.Exec( | ||||||
| 			"ALTER TABLE ? DROP CONSTRAINT ?", | 			"ALTER TABLE ? DROP CONSTRAINT ?", | ||||||
| 			clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | 			clause.Table{Name: stmt.Table}, clause.Column{Name: name}, | ||||||
| 		).Error | 		).Error | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) HasConstraint(value interface{}, name string) bool { | func (m Migrator) HasConstraint(value interface{}, name string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		currentDatabase := migrator.DB.Migrator().CurrentDatabase() | 		currentDatabase := m.DB.Migrator().CurrentDatabase() | ||||||
| 		return migrator.DB.Raw( | 		return m.DB.Raw( | ||||||
| 			"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", | 			"SELECT count(*) FROM INFORMATION_SCHEMA.referential_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?", | ||||||
| 			currentDatabase, stmt.Table, name, | 			currentDatabase, stmt.Table, name, | ||||||
| 		).Scan(&count).Error | 		).Row().Scan(&count) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if count != 0 { | 	return count > 0 | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||||
| 	for _, opt := range opts { | 	for _, opt := range opts { | ||||||
| 		str := stmt.Quote(opt.DBName) | 		str := stmt.Quote(opt.DBName) | ||||||
| 		if opt.Expression != "" { | 		if opt.Expression != "" { | ||||||
| @ -372,6 +368,10 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results | |||||||
| 			str += fmt.Sprintf("(%d)", opt.Length) | 			str += fmt.Sprintf("(%d)", opt.Length) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if opt.Collate != "" { | ||||||
|  | 			str += " COLLATE " + opt.Collate | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if opt.Sort != "" { | 		if opt.Sort != "" { | ||||||
| 			str += " " + opt.Sort | 			str += " " + opt.Sort | ||||||
| 		} | 		} | ||||||
| @ -380,13 +380,17 @@ func buildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CreateIndex(value interface{}, name string) error { | type BuildIndexOptionsInterface interface { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||||
|  | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		err := fmt.Errorf("failed to create index with name %v", name) | 		err := fmt.Errorf("failed to create index with name %v", name) | ||||||
| 		indexes := stmt.Schema.ParseIndexes() | 		indexes := stmt.Schema.ParseIndexes() | ||||||
| 
 | 
 | ||||||
| 		if idx, ok := indexes[name]; ok { | 		if idx, ok := indexes[name]; ok { | ||||||
| 			opts := buildIndexOptions(idx.Fields, stmt) | 			opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) | ||||||
| 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | 			values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} | ||||||
| 
 | 
 | ||||||
| 			createIndexSQL := "CREATE " | 			createIndexSQL := "CREATE " | ||||||
| @ -404,12 +408,12 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { | |||||||
| 				createIndexSQL += " USING " + idx.Type | 				createIndexSQL += " USING " + idx.Type | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			return migrator.DB.Raw(createIndexSQL, values...).Error | 			return m.DB.Exec(createIndexSQL, values...).Error | ||||||
| 		} else if field := stmt.Schema.LookUpField(name); field != nil { | 		} else if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| 			for _, idx := range indexes { | 			for _, idx := range indexes { | ||||||
| 				for _, idxOpt := range idx.Fields { | 				for _, idxOpt := range idx.Fields { | ||||||
| 					if idxOpt.Field == field { | 					if idxOpt.Field == field { | ||||||
| 						if err = migrator.CreateIndex(value, idx.Name); err != nil { | 						if err = m.CreateIndex(value, idx.Name); err != nil { | ||||||
| 							return err | 							return err | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| @ -420,38 +424,35 @@ func (migrator Migrator) CreateIndex(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) DropIndex(value interface{}, name string) error { | func (m Migrator) DropIndex(value interface{}, name string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		return migrator.DB.Raw("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error | 		return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) HasIndex(value interface{}, name string) bool { | func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		currentDatabase := migrator.DB.Migrator().CurrentDatabase() | 		currentDatabase := m.DB.Migrator().CurrentDatabase() | ||||||
| 		return migrator.DB.Raw( | 		return m.DB.Raw( | ||||||
| 			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", | 			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", | ||||||
| 			currentDatabase, stmt.Table, name, | 			currentDatabase, stmt.Table, name, | ||||||
| 		).Scan(&count).Error | 		).Row().Scan(&count) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if count != 0 { | 	return count > 0 | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) RenameIndex(value interface{}, oldName, newName string) error { | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | ||||||
| 	return migrator.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		return migrator.DB.Exec( | 		return m.DB.Exec( | ||||||
| 			"ALTER TABLE ? RENAME INDEX ? TO ?", | 			"ALTER TABLE ? RENAME INDEX ? TO ?", | ||||||
| 			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, | 			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, | ||||||
| 		).Error | 		).Error | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (migrator Migrator) CurrentDatabase() (name string) { | func (m Migrator) CurrentDatabase() (name string) { | ||||||
| 	migrator.DB.Raw("SELECT DATABASE()").Scan(&name) | 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  | |||||||
| @ -138,7 +138,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if num, ok := field.TagSettings["SIZE"]; ok { | 	if num, ok := field.TagSettings["SIZE"]; ok { | ||||||
| 		field.Size, _ = strconv.Atoi(num) | 		var err error | ||||||
|  | 		if field.Size, err = strconv.Atoi(num); err != nil { | ||||||
|  | 			field.Size = -1 | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if p, ok := field.TagSettings["PRECISION"]; ok { | 	if p, ok := field.TagSettings["PRECISION"]; ok { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu