diff --git a/callbacks/query.go b/callbacks/query.go index 8d13095e..a4ed3adb 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,10 +8,13 @@ import ( ) func Query(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + } - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) fmt.Println(err) fmt.Println(result) diff --git a/callbacks/raw.go b/callbacks/raw.go index 6d0a5aac..e8cad25d 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -1,11 +1,14 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func RawExec(db *gorm.DB) { result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) - db.RowsAffected, _ = result.RowsAffected() if err != nil { db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() } } diff --git a/callbacks/row.go b/callbacks/row.go index 04fe4f48..f7d6752d 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -6,10 +6,12 @@ import ( ) func RowQuery(db *gorm.DB) { - db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{}) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Select{}) + db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") + } if _, ok := db.Get("rows"); ok { db.Statement.Dest, db.Error = db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/chainable_api.go b/chainable_api.go index ccd61716..770b2236 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -222,8 +222,7 @@ func (db *DB) Unscoped() (tx *DB) { func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() - stmt := tx.Statement - stmt.SQL = strings.Builder{} - clause.Expr{SQL: sql, Vars: values}.Build(stmt) + tx.Statement.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) return } diff --git a/clause/expression.go b/clause/expression.go index 6b3575df..d72db08d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -1,6 +1,8 @@ package clause -import "strings" +import ( + "strings" +) // Expression expression interface type Expression interface { @@ -22,7 +24,7 @@ type Expr struct { func (expr Expr) Build(builder Builder) { sql := expr.SQL for _, v := range expr.Vars { - sql = strings.Replace(sql, " ?", " "+builder.AddVar(v), 1) + sql = strings.Replace(sql, "?", builder.AddVar(v), 1) } builder.Write(sql) } diff --git a/clause/expression_test.go b/clause/expression_test.go new file mode 100644 index 00000000..e51d189e --- /dev/null +++ b/clause/expression_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } +} diff --git a/dialects/sqlite/migrator.go b/dialects/sqlite/migrator.go index 07e189ad..4ddcbb5d 100644 --- a/dialects/sqlite/migrator.go +++ b/dialects/sqlite/migrator.go @@ -30,8 +30,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } return m.DB.Raw( - "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ?)", - stmt.Table, `%"`+name+`" %`, `%`+name+` %`, + "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", ).Row().Scan(&count) }) return count > 0 diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 804016a5..38cd760b 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -28,8 +28,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ - DB: db, - Dialector: dialector, + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, }}} } @@ -44,20 +45,20 @@ func (dialector Dialector) QuoteChars() [2]byte { func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: - return "NUMERIC" + return "numeric" case schema.Int, schema.Uint: if field.AutoIncrement { // https://www.sqlite.org/autoinc.html - return "INTEGER PRIMARY KEY AUTOINCREMENT" + return "integer PRIMARY KEY AUTOINCREMENT" } else { - return "INTEGER" + return "integer" } case schema.Float: - return "REAL" + return "real" case schema.String, schema.Time: - return "TEXT" + return "text" case schema.Bytes: - return "BLOB" + return "blob" } return "" diff --git a/finisher_api.go b/finisher_api.go index 8b824d12..c9b58861 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "strings" "github.com/jinzhu/gorm/clause" ) @@ -166,6 +167,8 @@ func (db *DB) Rollback() (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) tx.callbacks.Raw().Execute(tx) return } diff --git a/go.mod b/go.mod index cdb7e574..9046ea99 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.13 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.1 + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/migrator/migrator.go b/migrator/migrator.go index 5debc600..e3097abd 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -18,7 +18,8 @@ type Migrator struct { // Config schema config type Config struct { - DB *gorm.DB + CreateIndexAfterCreateTable bool + DB *gorm.DB gorm.Dialector } @@ -80,9 +81,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } // create join table - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if rel.JoinTable != nil { + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) + } } } return nil @@ -140,8 +143,12 @@ func (m Migrator) CreateTable(values ...interface{}) error { } for _, idx := range stmt.Schema.ParseIndexes() { - createTableSQL += "INDEX ? ?," - values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + if m.CreateIndexAfterCreateTable { + m.DB.Migrator().CreateIndex(value, idx.Name) + } else { + createTableSQL += "INDEX ? ?," + values = append(values, clause.Expr{SQL: idx.Name}, m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } } for _, rel := range stmt.Schema.Relationships.Relations { @@ -152,9 +159,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } // create join table - joinValue := reflect.New(rel.JoinTable.ModelType).Interface() - if !m.DB.Migrator().HasTable(joinValue) { - defer m.DB.Migrator().CreateTable(joinValue) + if rel.JoinTable != nil { + joinValue := reflect.New(rel.JoinTable.ModelType).Interface() + if !m.DB.Migrator().HasTable(joinValue) { + defer m.DB.Migrator().CreateTable(joinValue) + } } } @@ -302,7 +311,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter for _, field := range constraint.References { references = append(references, clause.Column{Name: field.DBName}) } - results = append(results, constraint.Name, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) return } @@ -326,14 +335,14 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { err := fmt.Errorf("failed to create constraint with name %v", name) if field := stmt.Schema.LookUpField(name); field != nil { for _, cc := range checkConstraints { - if err = m.CreateIndex(value, cc.Name); err != nil { + if err = m.DB.Migrator().CreateIndex(value, cc.Name); err != nil { return err } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Field == field { - if err = m.CreateIndex(value, constraint.Name); err != nil { + if err = m.DB.Migrator().CreateIndex(value, constraint.Name); err != nil { return err } } diff --git a/schema/naming.go b/schema/naming.go index d6f26e9f..f7c82f32 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -46,7 +46,7 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, rel.FieldSchema.Table) + return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name)) } // CheckerName generate checker name diff --git a/schema/relationship.go b/schema/relationship.go index 6606d77e..4ffea8b3 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -339,7 +339,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { } } - if constraint.ReferenceSchema == nil { + if rel.JoinTable != nil || constraint.ReferenceSchema == nil { return nil } diff --git a/statement.go b/statement.go index 8c75c90d..d486a1c7 100644 --- a/statement.go +++ b/statement.go @@ -152,8 +152,11 @@ func (stmt *Statement) AddVar(vars ...interface{}) string { stmt.Vars = append(stmt.Vars, v.Value) placeholders.WriteString(stmt.DB.Dialector.BindVar(stmt, v.Value)) } - case clause.Column: + case clause.Column, clause.Table: placeholders.WriteString(stmt.Quote(v)) + case clause.Expr: + placeholders.WriteString(v.SQL) + stmt.Vars = append(stmt.Vars, v.Vars...) case []interface{}: if len(v) > 0 { placeholders.WriteByte('(') diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go index e2cda8fc..b4e3361b 100644 --- a/tests/dummy_dialecter.go +++ b/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" ) type DummyDialector struct { @@ -11,7 +12,7 @@ func (DummyDialector) Initialize(*gorm.DB) error { return nil } -func (DummyDialector) Migrator() gorm.Migrator { +func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } @@ -22,3 +23,7 @@ func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { func (DummyDialector) QuoteChars() [2]byte { return [2]byte{'`', '`'} // `name` } + +func (DummyDialector) DataTypeOf(*schema.Field) string { + return "" +} diff --git a/tests/migrate.go b/tests/migrate.go index 0466fe11..9f7e2d67 100644 --- a/tests/migrate.go +++ b/tests/migrate.go @@ -9,11 +9,21 @@ import ( func TestMigrate(t *testing.T, db *gorm.DB) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Toy{}, &Company{}, &Language{}} - db.AutoMigrate(allModels...) + for _, m := range allModels { + if db.Migrator().HasTable(m) { + if err := db.Migrator().DropTable(m); err != nil { + t.Errorf("Failed to drop table, got error %v", err) + } + } + } + + if err := db.AutoMigrate(allModels...); err != nil { + t.Errorf("Failed to auto migrate, but got error %v", err) + } for _, m := range allModels { if !db.Migrator().HasTable(m) { - t.Errorf("Failed to create table for %+v", m) + t.Errorf("Failed to create table for %#v", m) } } }