Add sqlite migration tests
This commit is contained in:
		
							parent
							
								
									215f5e7765
								
							
						
					
					
						commit
						6d58b62fd4
					
				| @ -8,10 +8,13 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func Query(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 	db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 	if db.Statement.SQL.String() == "" { | ||||
| 		db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 		db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 
 | ||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	} | ||||
| 
 | ||||
| 	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	fmt.Println(err) | ||||
| 	fmt.Println(result) | ||||
|  | ||||
| @ -1,11 +1,14 @@ | ||||
| package callbacks | ||||
| 
 | ||||
| import "github.com/jinzhu/gorm" | ||||
| import ( | ||||
| 	"github.com/jinzhu/gorm" | ||||
| ) | ||||
| 
 | ||||
| func RawExec(db *gorm.DB) { | ||||
| 	result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 	db.RowsAffected, _ = result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		db.AddError(err) | ||||
| 	} else { | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -6,10 +6,12 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func RowQuery(db *gorm.DB) { | ||||
| 	db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 	db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 	if db.Statement.SQL.String() == "" { | ||||
| 		db.Statement.AddClauseIfNotExists(clause.Select{}) | ||||
| 		db.Statement.AddClauseIfNotExists(clause.From{}) | ||||
| 
 | ||||
| 	db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 		db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := db.Get("rows"); ok { | ||||
| 		db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
|  | ||||
| @ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) { | ||||
| 
 | ||||
| func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	stmt := tx.Statement | ||||
| 	stmt.SQL = strings.Builder{} | ||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(stmt) | ||||
| 	tx.Statement.SQL = strings.Builder{} | ||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| package clause | ||||
| 
 | ||||
| import "strings" | ||||
| import ( | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Expression expression interface
 | ||||
| type Expression interface { | ||||
| @ -22,7 +24,7 @@ type Expr struct { | ||||
| func (expr Expr) Build(builder Builder) { | ||||
| 	sql := expr.SQL | ||||
| 	for _, v := range expr.Vars { | ||||
| 		sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) | ||||
| 		sql = strings.Replace(sql, "?", builder.AddVar(v), 1) | ||||
| 	} | ||||
| 	builder.Write(sql) | ||||
| } | ||||
|  | ||||
							
								
								
									
										35
									
								
								clause/expression_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								clause/expression_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,35 @@ | ||||
| package clause_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| 	"github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestExpr(t *testing.T) { | ||||
| 	results := []struct { | ||||
| 		SQL    string | ||||
| 		Result string | ||||
| 		Vars   []interface{} | ||||
| 	}{{ | ||||
| 		SQL:    "create table ? (? ?, ? ?)", | ||||
| 		Vars:   []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, | ||||
| 		Result: "create table `users` (`id` int, `name` text)", | ||||
| 	}} | ||||
| 
 | ||||
| 	for idx, result := range results { | ||||
| 		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { | ||||
| 			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) | ||||
| 			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} | ||||
| 			clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) | ||||
| 			if stmt.SQL.String() != result.Result { | ||||
| 				t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||
| 		} | ||||
| 
 | ||||
| 		return m.DB.Raw( | ||||
| 			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", | ||||
| 			stmt.Table, `%"`+name+`" %`, `%`+name+` %`, | ||||
| 			"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", | ||||
| 			stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", | ||||
| 		).Row().Scan(&count) | ||||
| 	}) | ||||
| 	return count > 0 | ||||
|  | ||||
| @ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { | ||||
| 
 | ||||
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { | ||||
| 	return Migrator{migrator.Migrator{Config: migrator.Config{ | ||||
| 		DB:        db, | ||||
| 		Dialector: dialector, | ||||
| 		DB:                          db, | ||||
| 		Dialector:                   dialector, | ||||
| 		CreateIndexAfterCreateTable: true, | ||||
| 	}}} | ||||
| } | ||||
| 
 | ||||
| @ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte { | ||||
| func (dialector Dialector) DataTypeOf(field *schema.Field) string { | ||||
| 	switch field.DataType { | ||||
| 	case schema.Bool: | ||||
| 		return "NUMERIC" | ||||
| 		return "numeric" | ||||
| 	case schema.Int, schema.Uint: | ||||
| 		if field.AutoIncrement { | ||||
| 			// https://www.sqlite.org/autoinc.html
 | ||||
| 			return "INTEGER PRIMARY KEY AUTOINCREMENT" | ||||
| 			return "integer PRIMARY KEY AUTOINCREMENT" | ||||
| 		} else { | ||||
| 			return "INTEGER" | ||||
| 			return "integer" | ||||
| 		} | ||||
| 	case schema.Float: | ||||
| 		return "REAL" | ||||
| 		return "real" | ||||
| 	case schema.String, schema.Time: | ||||
| 		return "TEXT" | ||||
| 		return "text" | ||||
| 	case schema.Bytes: | ||||
| 		return "BLOB" | ||||
| 		return "blob" | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
|  | ||||
| @ -2,6 +2,7 @@ package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| ) | ||||
| @ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) { | ||||
| 
 | ||||
| func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { | ||||
| 	tx = db.getInstance() | ||||
| 	tx.Statement.SQL = strings.Builder{} | ||||
| 	clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) | ||||
| 	tx.callbacks.Raw().Execute(tx) | ||||
| 	return | ||||
| } | ||||
|  | ||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @ -5,4 +5,5 @@ go 1.13 | ||||
| require ( | ||||
| 	github.com/jinzhu/inflection v1.0.0 | ||||
| 	github.com/jinzhu/now v1.1.1 | ||||
| 	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect | ||||
| ) | ||||
|  | ||||
| @ -18,7 +18,8 @@ type Migrator struct { | ||||
| 
 | ||||
| // Config schema config
 | ||||
| type Config struct { | ||||
| 	DB *gorm.DB | ||||
| 	CreateIndexAfterCreateTable bool | ||||
| 	DB                          *gorm.DB | ||||
| 	gorm.Dialector | ||||
| } | ||||
| 
 | ||||
| @ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||
| 					} | ||||
| 
 | ||||
| 					// create join table
 | ||||
| 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 					if !m.DB.Migrator().HasTable(joinValue) { | ||||
| 						defer m.DB.Migrator().CreateTable(joinValue) | ||||
| 					if rel.JoinTable != nil { | ||||
| 						joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 						if !m.DB.Migrator().HasTable(joinValue) { | ||||
| 							defer m.DB.Migrator().CreateTable(joinValue) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				return nil | ||||
| @ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error { | ||||
| 			} | ||||
| 
 | ||||
| 			for _, idx := range stmt.Schema.ParseIndexes() { | ||||
| 				createTableSQL += "INDEX ? ?," | ||||
| 				values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) | ||||
| 				if m.CreateIndexAfterCreateTable { | ||||
| 					m.DB.Migrator().CreateIndex(value, idx.Name) | ||||
| 				} else { | ||||
| 					createTableSQL += "INDEX ? ?," | ||||
| 					values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| @ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { | ||||
| 				} | ||||
| 
 | ||||
| 				// create join table
 | ||||
| 				joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 				if !m.DB.Migrator().HasTable(joinValue) { | ||||
| 					defer m.DB.Migrator().CreateTable(joinValue) | ||||
| 				if rel.JoinTable != nil { | ||||
| 					joinValue := reflect.New(rel.JoinTable.ModelType).Interface() | ||||
| 					if !m.DB.Migrator().HasTable(joinValue) { | ||||
| 						defer m.DB.Migrator().CreateTable(joinValue) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| @ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter | ||||
| 	for _, field := range constraint.References { | ||||
| 		references = append(references, clause.Column{Name: field.DBName}) | ||||
| 	} | ||||
| 	results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) | ||||
| 	results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| @ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { | ||||
| 		err := fmt.Errorf("failed to create constraint with name %v", name) | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			for _, cc := range checkConstraints { | ||||
| 				if err = m.CreateIndex(value, cc.Name); err != nil { | ||||
| 				if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, rel := range stmt.Schema.Relationships.Relations { | ||||
| 				if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { | ||||
| 					if err = m.CreateIndex(value, constraint.Name); err != nil { | ||||
| 					if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { | ||||
| 						return err | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| @ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string { | ||||
| 
 | ||||
| // RelationshipFKName generate fk name for relation
 | ||||
| func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { | ||||
| 	return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) | ||||
| 	return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) | ||||
| } | ||||
| 
 | ||||
| // CheckerName generate checker name
 | ||||
|  | ||||
| @ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if constraint.ReferenceSchema == nil { | ||||
| 	if rel.JoinTable != nil || constraint.ReferenceSchema == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { | ||||
| 				stmt.Vars = append(stmt.Vars, v.Value) | ||||
| 				placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) | ||||
| 			} | ||||
| 		case clause.Column: | ||||
| 		case clause.Column, clause.Table: | ||||
| 			placeholders.WriteString(stmt.Quote(v)) | ||||
| 		case clause.Expr: | ||||
| 			placeholders.WriteString(v.SQL) | ||||
| 			stmt.Vars = append(stmt.Vars, v.Vars...) | ||||
| 		case []interface{}: | ||||
| 			if len(v) > 0 { | ||||
| 				placeholders.WriteByte('(') | ||||
|  | ||||
| @ -2,6 +2,7 @@ package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| type DummyDialector struct { | ||||
| @ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (DummyDialector) Migrator() gorm.Migrator { | ||||
| func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| @ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { | ||||
| func (DummyDialector) QuoteChars() [2]byte { | ||||
| 	return [2]byte{'`', '`'} // `name`
 | ||||
| } | ||||
| 
 | ||||
| func (DummyDialector) DataTypeOf(*schema.Field) string { | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| @ -9,11 +9,21 @@ import ( | ||||
| func TestMigrate(t *testing.T, db *gorm.DB) { | ||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} | ||||
| 
 | ||||
| 	db.AutoMigrate(allModels...) | ||||
| 	for _, m := range allModels { | ||||
| 		if db.Migrator().HasTable(m) { | ||||
| 			if err := db.Migrator().DropTable(m); err != nil { | ||||
| 				t.Errorf("Failed to drop table, got error %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := db.AutoMigrate(allModels...); err != nil { | ||||
| 		t.Errorf("Failed to auto migrate, but got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, m := range allModels { | ||||
| 		if !db.Migrator().HasTable(m) { | ||||
| 			t.Errorf("Failed to create table for %+v", m) | ||||
| 			t.Errorf("Failed to create table for %#v", m) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu