From 9340b97a0b1b3bf37f77df3d482673781b88a11d Mon Sep 17 00:00:00 2001 From: Reza Mohammadi Date: Sat, 8 Apr 2017 21:02:53 +0430 Subject: [PATCH] Utilize go1.8 context support in database/sql Fixes #1231 The related go1.8 release notes: https://golang.org/doc/go1.8#database_sql --- README.md | 1 + callback_create.go | 4 +-- callback_query.go | 2 +- callback_row_query.go | 4 +-- dialect_common.go | 36 ++++++------------------- dialect_common_go1.8.go | 36 +++++++++++++++++++++++++ dialect_common_go1.8pre.go | 33 +++++++++++++++++++++++ dialect_mysql.go | 22 +++++---------- dialect_mysql_go1.8.go | 24 +++++++++++++++++ dialect_mysql_go1.8pre.go | 21 +++++++++++++++ dialect_postgres.go | 37 ++++++------------------- dialect_postgres_go1.8.go | 34 +++++++++++++++++++++++ dialect_postgres_go1.8pre.go | 32 ++++++++++++++++++++++ dialect_sqlite3.go | 43 +++++------------------------ dialect_sqlite3_go1.8.go | 44 ++++++++++++++++++++++++++++++ dialect_sqlite3_go1.8pre.go | 41 ++++++++++++++++++++++++++++ dialects/mssql/mssql.go | 36 ++++++------------------- dialects/mssql/mssql_go1.8.go | 36 +++++++++++++++++++++++++ dialects/mssql/mssql_go1.8pre.go | 33 +++++++++++++++++++++++ interface.go | 14 ---------- interface_go1.8.go | 20 ++++++++++++++ interface_go1.8pre.go | 17 ++++++++++++ main.go | 18 ++++--------- main_go1.8.go | 46 ++++++++++++++++++++++++++++++++ main_go1.8_test.go | 19 +++++++++++++ main_go1.8pre.go | 16 +++++++++++ scope.go | 23 ++++++++-------- scope_go1.8.go | 37 +++++++++++++++++++++++++ scope_go1.8pre.go | 32 ++++++++++++++++++++++ 29 files changed, 580 insertions(+), 181 deletions(-) create mode 100644 dialect_common_go1.8.go create mode 100644 dialect_common_go1.8pre.go create mode 100644 dialect_mysql_go1.8.go create mode 100644 dialect_mysql_go1.8pre.go create mode 100644 dialect_postgres_go1.8.go create mode 100644 dialect_postgres_go1.8pre.go create mode 100644 dialect_sqlite3_go1.8.go create mode 100644 dialect_sqlite3_go1.8pre.go create mode 100644 dialects/mssql/mssql_go1.8.go create mode 100644 dialects/mssql/mssql_go1.8pre.go create mode 100644 interface_go1.8.go create mode 100644 interface_go1.8pre.go create mode 100644 main_go1.8.go create mode 100644 main_go1.8_test.go create mode 100644 main_go1.8pre.go create mode 100644 scope_go1.8.go create mode 100644 scope_go1.8pre.go diff --git a/README.md b/README.md index 44eb4a69..47e7fc89 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Extendable, write Plugins based on GORM callbacks * Every feature comes with tests * Developer Friendly +* Supports context.Context on golang 1.8 ## Getting Started diff --git a/callback_create.go b/callback_create.go index a4da39e8..95aead17 100644 --- a/callback_create.go +++ b/callback_create.go @@ -115,7 +115,7 @@ func createCallback(scope *Scope) { // execute create sql if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -128,7 +128,7 @@ func createCallback(scope *Scope) { } } else { if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + if err := scope.sqldbQueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { primaryField.IsBlank = false scope.db.RowsAffected = 1 } diff --git a/callback_query.go b/callback_query.go index 20e88161..6ba27f6a 100644 --- a/callback_query.go +++ b/callback_query.go @@ -55,7 +55,7 @@ func queryCallback(scope *Scope) { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if rows, err := scope.sqldbQuery(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { defer rows.Close() columns, _ := rows.Columns() diff --git a/callback_row_query.go b/callback_row_query.go index c2ff4a08..2e42ae80 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -22,9 +22,9 @@ func rowQueryCallback(scope *Scope) { scope.prepareQuerySQL() if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) + rowResult.Row = scope.sqldbQueryRow(scope.SQL, scope.SQLVars...) } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) + rowsResult.Rows, rowsResult.Error = scope.sqldbQuery(scope.SQL, scope.SQLVars...) } } } diff --git a/dialect_common.go b/dialect_common.go index a99627f2..63b61acd 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -9,6 +9,14 @@ import ( "time" ) +const ( + queryHasIndex = "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?" + queryRemoveIndex = "DROP INDEX %v" + queryHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?" + queryHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?" + queryCurrentDatabase = "SELECT DATABASE()" +) + // DefaultForeignKeyNamer contains the default foreign key name generator method type DefaultForeignKeyNamer struct { } @@ -90,38 +98,10 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { return false } -func (s commonDialect) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { diff --git a/dialect_common_go1.8.go b/dialect_common_go1.8.go new file mode 100644 index 00000000..d279e289 --- /dev/null +++ b/dialect_common_go1.8.go @@ -0,0 +1,36 @@ +// +build go1.8 + +package gorm + +import ( + "context" + "fmt" +) + +func (s commonDialect) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count) + return count > 0 +} + +func (s commonDialect) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryRemoveIndex, indexName)) + return err +} + +func (s commonDialect) HasTable(tableName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryHasTable, s.CurrentDatabase(), tableName).Scan(&count) + return count > 0 +} + +func (s commonDialect) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s commonDialect) CurrentDatabase() (name string) { + s.db.QueryRowContext(context.Background(), queryCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_common_go1.8pre.go b/dialect_common_go1.8pre.go new file mode 100644 index 00000000..2400fdb8 --- /dev/null +++ b/dialect_common_go1.8pre.go @@ -0,0 +1,33 @@ +// +build !go1.8 + +package gorm + +import "fmt" + +func (s commonDialect) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count) + return count > 0 +} + +func (s commonDialect) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf(queryRemoveIndex, indexName)) + return err +} + +func (s commonDialect) HasTable(tableName string) bool { + var count int + s.db.QueryRow(queryHasTable, s.CurrentDatabase(), tableName).Scan(&count) + return count > 0 +} + +func (s commonDialect) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s commonDialect) CurrentDatabase() (name string) { + s.db.QueryRow(queryCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_mysql.go b/dialect_mysql.go index 6fcd0079..4b802b21 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -11,6 +11,12 @@ import ( "unicode/utf8" ) +const ( + queryMySQLRemoveIndex = "DROP INDEX %v ON %v" + queryMySQLHasForeignKey = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'" + queryMySQLCurrentDatabase = "SELECT DATABASE()" +) + type mysql struct { commonDialect } @@ -122,11 +128,6 @@ func (s *mysql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { @@ -142,17 +143,6 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { return } -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } diff --git a/dialect_mysql_go1.8.go b/dialect_mysql_go1.8.go new file mode 100644 index 00000000..3be61248 --- /dev/null +++ b/dialect_mysql_go1.8.go @@ -0,0 +1,24 @@ +// +build go1.8 + +package gorm + +import ( + "context" + "fmt" +) + +func (s mysql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName))) + return err +} + +func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s mysql) CurrentDatabase() (name string) { + s.db.QueryRowContext(context.Background(), queryMySQLCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_mysql_go1.8pre.go b/dialect_mysql_go1.8pre.go new file mode 100644 index 00000000..cacf551d --- /dev/null +++ b/dialect_mysql_go1.8pre.go @@ -0,0 +1,21 @@ +// +build !go1.8 + +package gorm + +import "fmt" + +func (s mysql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName))) + return err +} + +func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow(queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s mysql) CurrentDatabase() (name string) { + s.db.QueryRow(queryMySQLCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..411525e3 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -7,6 +7,14 @@ import ( "time" ) +const ( + queryPostgresHasIndex = "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()" + queryPostgresHasForeignKey = "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'" + queryPostgresHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()" + queryPostgresHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()" + queryPostgresCurrentDatabase = "SELECT CURRENT_DATABASE()" +) + type postgres struct { commonDialect } @@ -85,35 +93,6 @@ func (s *postgres) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } diff --git a/dialect_postgres_go1.8.go b/dialect_postgres_go1.8.go new file mode 100644 index 00000000..35a66c03 --- /dev/null +++ b/dialect_postgres_go1.8.go @@ -0,0 +1,34 @@ +// +build go1.8 + +package gorm + +import "context" + +func (s postgres) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryPostgresHasIndex, tableName, indexName).Scan(&count) + return count > 0 +} + +func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s postgres) HasTable(tableName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryPostgresHasTable, tableName).Scan(&count) + return count > 0 +} + +func (s postgres) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryPostgresHasColumn, tableName, columnName).Scan(&count) + return count > 0 +} + +func (s postgres) CurrentDatabase() (name string) { + s.db.QueryRowContext(context.Background(), queryPostgresCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_postgres_go1.8pre.go b/dialect_postgres_go1.8pre.go new file mode 100644 index 00000000..1b31e37a --- /dev/null +++ b/dialect_postgres_go1.8pre.go @@ -0,0 +1,32 @@ +// +build !go1.8 + +package gorm + +func (s postgres) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(queryPostgresHasIndex, tableName, indexName).Scan(&count) + return count > 0 +} + +func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow(queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count) + return count > 0 +} + +func (s postgres) HasTable(tableName string) bool { + var count int + s.db.QueryRow(queryPostgresHasTable, tableName).Scan(&count) + return count > 0 +} + +func (s postgres) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(queryPostgresHasColumn, tableName, columnName).Scan(&count) + return count > 0 +} + +func (s postgres) CurrentDatabase() (name string) { + s.db.QueryRow(queryPostgresCurrentDatabase).Scan(&name) + return +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..6f852a74 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -7,6 +7,13 @@ import ( "time" ) +const ( + querySQLite3HasIndex = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'" + querySQLite3HasTable = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?" + querySQLite3HasColumn = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n" + querySQLite3CurrentDatabase = "PRAGMA database_list" +) + type sqlite3 struct { commonDialect } @@ -69,39 +76,3 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { } return fmt.Sprintf("%v %v", sqlType, additionalType) } - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/dialect_sqlite3_go1.8.go b/dialect_sqlite3_go1.8.go new file mode 100644 index 00000000..6a5671ad --- /dev/null +++ b/dialect_sqlite3_go1.8.go @@ -0,0 +1,44 @@ +// +build go1.8 + +package gorm + +import ( + "context" + "fmt" +) + +func (s sqlite3) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasTable(tableName string) bool { + var count int + s.db.QueryRowContext(context.Background(), querySQLite3HasTable, tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) CurrentDatabase() (name string) { + var ( + ifaces = make([]interface{}, 3) + pointers = make([]*string, 3) + i int + ) + for i = 0; i < 3; i++ { + ifaces[i] = &pointers[i] + } + if err := s.db.QueryRowContext(context.Background(), querySQLite3CurrentDatabase).Scan(ifaces...); err != nil { + return + } + if pointers[1] != nil { + name = *pointers[1] + } + return +} diff --git a/dialect_sqlite3_go1.8pre.go b/dialect_sqlite3_go1.8pre.go new file mode 100644 index 00000000..64fbcc37 --- /dev/null +++ b/dialect_sqlite3_go1.8pre.go @@ -0,0 +1,41 @@ +// +build !go1.8 + +package gorm + +import "fmt" + +func (s sqlite3) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasTable(tableName string) bool { + var count int + s.db.QueryRow(querySQLite3HasTable, tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) CurrentDatabase() (name string) { + var ( + ifaces = make([]interface{}, 3) + pointers = make([]*string, 3) + i int + ) + for i = 0; i < 3; i++ { + ifaces[i] = &pointers[i] + } + if err := s.db.QueryRow(querySQLite3CurrentDatabase).Scan(ifaces...); err != nil { + return + } + if pointers[1] != nil { + name = *pointers[1] + } + return +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index de2ae7ca..dc39d13a 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -11,6 +11,14 @@ import ( "github.com/jinzhu/gorm" ) +const ( + queryMSSQLHasIndex = "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)" + queryMSSQLRemoveIndex = "DROP INDEX %v ON %v" + queryMSSQLHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?" + queryMSSQLHasColumn = "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?" + queryMSSQLCurrentDatabase = "SELECT DB_NAME() AS [Current Database]" +) + func setIdentityInsert(scope *gorm.Scope) { if scope.Dialect().GetName() == "mssql" { for _, field := range scope.PrimaryFields() { @@ -111,38 +119,10 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { return false } -func (s mssql) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if offset != nil { if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { diff --git a/dialects/mssql/mssql_go1.8.go b/dialects/mssql/mssql_go1.8.go new file mode 100644 index 00000000..ba2e1f94 --- /dev/null +++ b/dialects/mssql/mssql_go1.8.go @@ -0,0 +1,36 @@ +// +build go1.8 + +package mssql + +import ( + "context" + "fmt" +) + +func (s mssql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryMSSQLHasIndex, indexName, tableName).Scan(&count) + return count > 0 +} + +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasTable(tableName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count) + return count > 0 +} + +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRowContext(context.Background(), queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s mssql) CurrentDatabase() (name string) { + s.db.QueryRowContext(context.Background(), queryMSSQLCurrentDatabase).Scan(&name) + return +} diff --git a/dialects/mssql/mssql_go1.8pre.go b/dialects/mssql/mssql_go1.8pre.go new file mode 100644 index 00000000..29c4d54f --- /dev/null +++ b/dialects/mssql/mssql_go1.8pre.go @@ -0,0 +1,33 @@ +// +build !go1.8 + +package mssql + +import "fmt" + +func (s mssql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(queryMSSQLHasIndex, indexName, tableName).Scan(&count) + return count > 0 +} + +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasTable(tableName string) bool { + var count int + s.db.QueryRow(queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count) + return count > 0 +} + +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s mssql) CurrentDatabase() (name string) { + s.db.QueryRow(queryMSSQLCurrentDatabase).Scan(&name) + return +} diff --git a/interface.go b/interface.go index 55128f7f..ec4af809 100644 --- a/interface.go +++ b/interface.go @@ -1,19 +1,5 @@ package gorm -import "database/sql" - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. -type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) -} - type sqlTx interface { Commit() error Rollback() error diff --git a/interface_go1.8.go b/interface_go1.8.go new file mode 100644 index 00000000..40970b8e --- /dev/null +++ b/interface_go1.8.go @@ -0,0 +1,20 @@ +// +build go1.8 + +package gorm + +import ( + "context" + "database/sql" +) + +// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. +type SQLCommon interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +type sqlDb interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} diff --git a/interface_go1.8pre.go b/interface_go1.8pre.go new file mode 100644 index 00000000..53d3c23a --- /dev/null +++ b/interface_go1.8pre.go @@ -0,0 +1,17 @@ +// +build !go1.8 + +package gorm + +import "database/sql" + +// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. +type SQLCommon interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +type sqlDb interface { + Begin() (*sql.Tx, error) +} diff --git a/main.go b/main.go index 16fa0b79..67511b13 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,9 @@ import ( "reflect" "strings" "time" + + // Using the old package to support older golangs + "golang.org/x/net/context" ) // DB contains information for current db connection @@ -22,6 +25,7 @@ type DB struct { logger logger search *search values map[string]interface{} + context context.Context // global db parent *DB @@ -460,19 +464,6 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() - c.db = interface{}(tx).(SQLCommon) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - // Commit commit a transaction func (s *DB) Commit() *DB { if db, ok := s.db.(sqlTx); ok && db != nil { @@ -717,6 +708,7 @@ func (s *DB) clone() *DB { logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, + context: s.context, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, diff --git a/main_go1.8.go b/main_go1.8.go new file mode 100644 index 00000000..0b7f5d68 --- /dev/null +++ b/main_go1.8.go @@ -0,0 +1,46 @@ +// +build go1.8 + +package gorm + +import ( + "context" + "database/sql" +) + +// WithContext specify context to be passed to the underlying `*sql.DB` or +// `*sql.Tx` query methods +func (s *DB) WithContext(ctx context.Context) *DB { + db := s.clone() + db.context = ctx + return db +} + +// Context returns the specified context for this instance, or nil if not set +func (s *DB) Context() context.Context { + return s.context +} + +func (s *DB) contextOrBackground() context.Context { + if s.context != nil { + return s.context + } + return context.Background() +} + +// BeginTx starts a transaction with the given options +func (s *DB) BeginTx(opts *sql.TxOptions) *DB { + c := s.clone() + if db, ok := c.db.(sqlDb); ok && db != nil { + tx, err := db.BeginTx(s.contextOrBackground(), opts) + c.db = interface{}(tx).(SQLCommon) + c.AddError(err) + } else { + c.AddError(ErrCantStartTransaction) + } + return c +} + +// Begin starts a transaction +func (s *DB) Begin() *DB { + return s.BeginTx(nil) +} diff --git a/main_go1.8_test.go b/main_go1.8_test.go new file mode 100644 index 00000000..1a0a9628 --- /dev/null +++ b/main_go1.8_test.go @@ -0,0 +1,19 @@ +// +build go1.8 + +package gorm_test + +import ( + "context" + "testing" + "time" +) + +func TestContext(t *testing.T) { + user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} + expiredCtx, cancel := context.WithDeadline(context.Background(), time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)) + err := DB.WithContext(expiredCtx).Save(&user1).Error + cancel() + if err.Error() != context.DeadlineExceeded.Error() { + t.Fatal("unexpected err:", err) + } +} diff --git a/main_go1.8pre.go b/main_go1.8pre.go new file mode 100644 index 00000000..42e79ebb --- /dev/null +++ b/main_go1.8pre.go @@ -0,0 +1,16 @@ +// +build !go1.8 + +package gorm + +// Begin starts a transaction +func (s *DB) Begin() *DB { + c := s.clone() + if db, ok := c.db.(sqlDb); ok && db != nil { + tx, err := db.Begin() + c.db = interface{}(tx).(SQLCommon) + c.AddError(err) + } else { + c.AddError(ErrCantStartTransaction) + } + return c +} diff --git a/scope.go b/scope.go index 51ebd5a0..92d4e302 100644 --- a/scope.go +++ b/scope.go @@ -359,7 +359,7 @@ func (scope *Scope) Exec() *Scope { defer scope.trace(NowFunc()) if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } @@ -397,17 +397,6 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { return scope.Get(name + scope.InstanceID()) } -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it func (scope *Scope) CommitOrRollback() *Scope { if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { @@ -1062,6 +1051,7 @@ func (scope *Scope) getTableOptions() string { return tableOptions.(string) } +// TODO: context variant func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler @@ -1098,6 +1088,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } +// TODO: context variant func (scope *Scope) createTable() *Scope { var tags []string var primaryKeys []string @@ -1133,19 +1124,23 @@ func (scope *Scope) createTable() *Scope { return scope } +// TODO: context variant func (scope *Scope) dropTable() *Scope { scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() return scope } +// TODO: context variant func (scope *Scope) modifyColumn(column string, typ string) { scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() } +// TODO: context variant func (scope *Scope) dropColumn(column string) { scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() } +// TODO: context variant func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { if scope.Dialect().HasIndex(scope.TableName(), indexName) { return @@ -1164,6 +1159,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() } +// TODO: context variant func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) @@ -1174,10 +1170,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +// TODO: context variant func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } +// TODO: context variant func (scope *Scope) autoMigrate() *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName() @@ -1199,6 +1197,7 @@ func (scope *Scope) autoMigrate() *Scope { return scope } +// TODO: context variant func (scope *Scope) autoIndex() *Scope { var indexes = map[string][]string{} var uniqueIndexes = map[string][]string{} diff --git a/scope_go1.8.go b/scope_go1.8.go new file mode 100644 index 00000000..2f53ada5 --- /dev/null +++ b/scope_go1.8.go @@ -0,0 +1,37 @@ +// +build go1.8 + +package gorm + +import "database/sql" + +// BeginTx start a transaction with the given options +func (scope *Scope) BeginTx(opts *sql.TxOptions) *Scope { + if db, ok := scope.SQLDB().(sqlDb); ok { + if tx, err := db.BeginTx(scope.DB().contextOrBackground(), opts); err == nil { + scope.db.db = interface{}(tx).(SQLCommon) + scope.InstanceSet("gorm:started_transaction", true) + } + } + return scope +} + +// Begin start a transaction +func (scope *Scope) Begin() *Scope { + return scope.BeginTx(nil) +} + +func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) { + return scope.SQLDB().ExecContext(scope.db.contextOrBackground(), query, args...) +} + +func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) { + return scope.SQLDB().PrepareContext(scope.db.contextOrBackground(), query) +} + +func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) { + return scope.SQLDB().QueryContext(scope.db.contextOrBackground(), query, args...) +} + +func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row { + return scope.SQLDB().QueryRowContext(scope.db.contextOrBackground(), query, args...) +} diff --git a/scope_go1.8pre.go b/scope_go1.8pre.go new file mode 100644 index 00000000..dc5194bd --- /dev/null +++ b/scope_go1.8pre.go @@ -0,0 +1,32 @@ +// +build !go1.8 + +package gorm + +import "database/sql" + +// Begin start a transaction +func (scope *Scope) Begin() *Scope { + if db, ok := scope.SQLDB().(sqlDb); ok { + if tx, err := db.Begin(); err == nil { + scope.db.db = interface{}(tx).(SQLCommon) + scope.InstanceSet("gorm:started_transaction", true) + } + } + return scope +} + +func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) { + return scope.SQLDB().Exec(query, args...) +} + +func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) { + return scope.SQLDB().Prepare(query) +} + +func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) { + return scope.SQLDB().Query(query, args...) +} + +func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row { + return scope.SQLDB().QueryRow(query, args...) +}