Add sqlite migration tests

This commit is contained in:
Jinzhu 2020-02-22 20:57:29 +08:00
parent 215f5e7765
commit 6d58b62fd4
16 changed files with 117 additions and 41 deletions

View File

@ -8,10 +8,13 @@ import (
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{}) if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err) fmt.Println(err)
fmt.Println(result) fmt.Println(result)

View File

@ -1,11 +1,14 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"github.com/jinzhu/gorm"
)
func RawExec(db *gorm.DB) { func RawExec(db *gorm.DB) {
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.RowsAffected, _ = result.RowsAffected()
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
} else {
db.RowsAffected, _ = result.RowsAffected()
} }
} }

View File

@ -6,10 +6,12 @@ import (
) )
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{}) if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
if _, ok := db.Get("rows"); ok { if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)

View File

@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
stmt := tx.Statement tx.Statement.SQL = strings.Builder{}
stmt.SQL = strings.Builder{} clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
clause.Expr{SQL: sql, Vars: values}.Build(stmt)
return return
} }

View File

@ -1,6 +1,8 @@
package clause package clause
import "strings" import (
"strings"
)
// Expression expression interface // Expression expression interface
type Expression interface { type Expression interface {
@ -22,7 +24,7 @@ type Expr struct {
func (expr Expr) Build(builder Builder) { func (expr Expr) Build(builder Builder) {
sql := expr.SQL sql := expr.SQL
for _, v := range expr.Vars { for _, v := range expr.Vars {
sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) sql = strings.Replace(sql, "?", builder.AddVar(v), 1)
} }
builder.Write(sql) builder.Write(sql)
} }

35
clause/expression_test.go Normal file
View File

@ -0,0 +1,35 @@
package clause_test
import (
"fmt"
"sync"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
)
func TestExpr(t *testing.T) {
results := []struct {
SQL string
Result string
Vars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
})
}
}

View File

@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
} }
return m.DB.Raw( return m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
stmt.Table, `%"`+name+`" %`, `%`+name+` %`, stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%",
).Row().Scan(&count) ).Row().Scan(&count)
}) })
return count > 0 return count > 0

View File

@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{ return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db, DB: db,
Dialector: dialector, Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}} }}}
} }
@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte {
func (dialector Dialector) DataTypeOf(field *schema.Field) string { func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType { switch field.DataType {
case schema.Bool: case schema.Bool:
return "NUMERIC" return "numeric"
case schema.Int, schema.Uint: case schema.Int, schema.Uint:
if field.AutoIncrement { if field.AutoIncrement {
// https://www.sqlite.org/autoinc.html // https://www.sqlite.org/autoinc.html
return "INTEGER PRIMARY KEY AUTOINCREMENT" return "integer PRIMARY KEY AUTOINCREMENT"
} else { } else {
return "INTEGER" return "integer"
} }
case schema.Float: case schema.Float:
return "REAL" return "real"
case schema.String, schema.Time: case schema.String, schema.Time:
return "TEXT" return "text"
case schema.Bytes: case schema.Bytes:
return "BLOB" return "blob"
} }
return "" return ""

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"strings"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
) )
@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) {
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
tx.callbacks.Raw().Execute(tx) tx.callbacks.Raw().Execute(tx)
return return
} }

1
go.mod
View File

@ -5,4 +5,5 @@ go 1.13
require ( require (
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.1 github.com/jinzhu/now v1.1.1
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
) )

View File

@ -18,7 +18,8 @@ type Migrator struct {
// Config schema config // Config schema config
type Config struct { type Config struct {
DB *gorm.DB CreateIndexAfterCreateTable bool
DB *gorm.DB
gorm.Dialector gorm.Dialector
} }
@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
// create join table // create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if rel.JoinTable != nil {
if !m.DB.Migrator().HasTable(joinValue) { joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
defer m.DB.Migrator().CreateTable(joinValue) if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
} }
} }
return nil return nil
@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
createTableSQL += "INDEX ? ?," if m.CreateIndexAfterCreateTable {
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) m.DB.Migrator().CreateIndex(value, idx.Name)
} else {
createTableSQL += "INDEX ? ?,"
values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
} }
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
// create join table // create join table
joinValue := reflect.New(rel.JoinTable.ModelType).Interface() if rel.JoinTable != nil {
if !m.DB.Migrator().HasTable(joinValue) { joinValue := reflect.New(rel.JoinTable.ModelType).Interface()
defer m.DB.Migrator().CreateTable(joinValue) if !m.DB.Migrator().HasTable(joinValue) {
defer m.DB.Migrator().CreateTable(joinValue)
}
} }
} }
@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
for _, field := range constraint.References { for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName}) references = append(references, clause.Column{Name: field.DBName})
} }
results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return return
} }
@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
err := fmt.Errorf("failed to create constraint with name %v", name) err := fmt.Errorf("failed to create constraint with name %v", name)
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {
for _, cc := range checkConstraints { for _, cc := range checkConstraints {
if err = m.CreateIndex(value, cc.Name); err != nil { if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil {
return err return err
} }
} }
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field {
if err = m.CreateIndex(value, constraint.Name); err != nil { if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil {
return err return err
} }
} }

View File

@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string {
// RelationshipFKName generate fk name for relation // RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name))
} }
// CheckerName generate checker name // CheckerName generate checker name

View File

@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
} }
} }
if constraint.ReferenceSchema == nil { if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
return nil return nil
} }

View File

@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string {
stmt.Vars = append(stmt.Vars, v.Value) stmt.Vars = append(stmt.Vars, v.Value)
placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value))
} }
case clause.Column: case clause.Column, clause.Table:
placeholders.WriteString(stmt.Quote(v)) placeholders.WriteString(stmt.Quote(v))
case clause.Expr:
placeholders.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
case []interface{}: case []interface{}:
if len(v) > 0 { if len(v) > 0 {
placeholders.WriteByte('(') placeholders.WriteByte('(')

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
) )
type DummyDialector struct { type DummyDialector struct {
@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error {
return nil return nil
} }
func (DummyDialector) Migrator() gorm.Migrator { func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator {
return nil return nil
} }
@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
func (DummyDialector) QuoteChars() [2]byte { func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name` return [2]byte{'`', '`'} // `name`
} }
func (DummyDialector) DataTypeOf(*schema.Field) string {
return ""
}

View File

@ -9,11 +9,21 @@ import (
func TestMigrate(t *testing.T, db *gorm.DB) { func TestMigrate(t *testing.T, db *gorm.DB) {
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}}
db.AutoMigrate(allModels...) for _, m := range allModels {
if db.Migrator().HasTable(m) {
if err := db.Migrator().DropTable(m); err != nil {
t.Errorf("Failed to drop table, got error %v", err)
}
}
}
if err := db.AutoMigrate(allModels...); err != nil {
t.Errorf("Failed to auto migrate, but got error %v", err)
}
for _, m := range allModels { for _, m := range allModels {
if !db.Migrator().HasTable(m) { if !db.Migrator().HasTable(m) {
t.Errorf("Failed to create table for %+v", m) t.Errorf("Failed to create table for %#v", m)
} }
} }
} }