Merge 9340b97a0b1b3bf37f77df3d482673781b88a11d into 802104cc7cfe58153cccc9bc76e5b9078296c16b

This commit is contained in:
Reza Mohammadi 2018-02-03 08:55:17 +00:00 committed by GitHub
commit f61cd769da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 580 additions and 181 deletions

View File

@ -21,6 +21,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Extendable, write Plugins based on GORM callbacks * Extendable, write Plugins based on GORM callbacks
* Every feature comes with tests * Every feature comes with tests
* Developer Friendly * Developer Friendly
* Supports context.Context on golang 1.8
## Getting Started ## Getting Started

View File

@ -115,7 +115,7 @@ func createCallback(scope *Scope) {
// execute create sql // execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil { if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count // set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -128,7 +128,7 @@ func createCallback(scope *Scope) {
} }
} else { } else {
if primaryField.Field.CanAddr() { if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { if err := scope.sqldbQueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false primaryField.IsBlank = false
scope.db.RowsAffected = 1 scope.db.RowsAffected = 1
} }

View File

@ -55,7 +55,7 @@ func queryCallback(scope *Scope) {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
} }
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if rows, err := scope.sqldbQuery(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close() defer rows.Close()
columns, _ := rows.Columns() columns, _ := rows.Columns()

View File

@ -22,9 +22,9 @@ func rowQueryCallback(scope *Scope) {
scope.prepareQuerySQL() scope.prepareQuerySQL()
if rowResult, ok := result.(*RowQueryResult); ok { if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) rowResult.Row = scope.sqldbQueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok { } else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) rowsResult.Rows, rowsResult.Error = scope.sqldbQuery(scope.SQL, scope.SQLVars...)
} }
} }
} }

View File

@ -9,6 +9,14 @@ import (
"time" "time"
) )
const (
queryHasIndex = "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?"
queryRemoveIndex = "DROP INDEX %v"
queryHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?"
queryHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?"
queryCurrentDatabase = "SELECT DATABASE()"
)
// DefaultForeignKeyNamer contains the default foreign key name generator method // DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct { type DefaultForeignKeyNamer struct {
} }
@ -90,38 +98,10 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false return false
} }
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {

36
dialect_common_go1.8.go Normal file
View File

@ -0,0 +1,36 @@
// +build go1.8
package gorm
import (
"context"
"fmt"
)
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryRemoveIndex, indexName))
return err
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryCurrentDatabase).Scan(&name)
return
}

View File

@ -0,0 +1,33 @@
// +build !go1.8
package gorm
import "fmt"
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf(queryRemoveIndex, indexName))
return err
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow(queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow(queryCurrentDatabase).Scan(&name)
return
}

View File

@ -11,6 +11,12 @@ import (
"unicode/utf8" "unicode/utf8"
) )
const (
queryMySQLRemoveIndex = "DROP INDEX %v ON %v"
queryMySQLHasForeignKey = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'"
queryMySQLCurrentDatabase = "SELECT DATABASE()"
)
type mysql struct { type mysql struct {
commonDialect commonDialect
} }
@ -122,11 +128,6 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
@ -142,17 +143,6 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
return return
} }
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (mysql) SelectFromDummyTable() string { func (mysql) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }

24
dialect_mysql_go1.8.go Normal file
View File

@ -0,0 +1,24 @@
// +build go1.8
package gorm
import (
"context"
"fmt"
)
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryMySQLCurrentDatabase).Scan(&name)
return
}

21
dialect_mysql_go1.8pre.go Normal file
View File

@ -0,0 +1,21 @@
// +build !go1.8
package gorm
import "fmt"
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow(queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow(queryMySQLCurrentDatabase).Scan(&name)
return
}

View File

@ -7,6 +7,14 @@ import (
"time" "time"
) )
const (
queryPostgresHasIndex = "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()"
queryPostgresHasForeignKey = "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'"
queryPostgresHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()"
queryPostgresHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()"
queryPostgresCurrentDatabase = "SELECT CURRENT_DATABASE()"
)
type postgres struct { type postgres struct {
commonDialect commonDialect
} }
@ -85,35 +93,6 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key) return fmt.Sprintf("RETURNING %v.%v", tableName, key)
} }

34
dialect_postgres_go1.8.go Normal file
View File

@ -0,0 +1,34 @@
// +build go1.8
package gorm
import "context"
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasIndex, tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasTable, tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasColumn, tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryPostgresCurrentDatabase).Scan(&name)
return
}

View File

@ -0,0 +1,32 @@
// +build !go1.8
package gorm
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(queryPostgresHasIndex, tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow(queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow(queryPostgresHasTable, tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(queryPostgresHasColumn, tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow(queryPostgresCurrentDatabase).Scan(&name)
return
}

View File

@ -7,6 +7,13 @@ import (
"time" "time"
) )
const (
querySQLite3HasIndex = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'"
querySQLite3HasTable = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?"
querySQLite3HasColumn = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n"
querySQLite3CurrentDatabase = "PRAGMA database_list"
)
type sqlite3 struct { type sqlite3 struct {
commonDialect commonDialect
} }
@ -69,39 +76,3 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
} }
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

44
dialect_sqlite3_go1.8.go Normal file
View File

@ -0,0 +1,44 @@
// +build go1.8
package gorm
import (
"context"
"fmt"
)
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), querySQLite3HasTable, tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRowContext(context.Background(), querySQLite3CurrentDatabase).Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@ -0,0 +1,41 @@
// +build !go1.8
package gorm
import "fmt"
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRow(querySQLite3HasTable, tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow(querySQLite3CurrentDatabase).Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@ -11,6 +11,14 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
const (
queryMSSQLHasIndex = "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)"
queryMSSQLRemoveIndex = "DROP INDEX %v ON %v"
queryMSSQLHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?"
queryMSSQLHasColumn = "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?"
queryMSSQLCurrentDatabase = "SELECT DB_NAME() AS [Current Database]"
)
func setIdentityInsert(scope *gorm.Scope) { func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" { if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() { for _, field := range scope.PrimaryFields() {
@ -111,38 +119,10 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false return false
} }
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if offset != nil { if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {

View File

@ -0,0 +1,36 @@
// +build go1.8
package mssql
import (
"context"
"fmt"
)
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryMSSQLHasIndex, indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryMSSQLCurrentDatabase).Scan(&name)
return
}

View File

@ -0,0 +1,33 @@
// +build !go1.8
package mssql
import "fmt"
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(queryMSSQLHasIndex, indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow(queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow(queryMSSQLCurrentDatabase).Scan(&name)
return
}

View File

@ -1,19 +1,5 @@
package gorm package gorm
import "database/sql"
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}
type sqlTx interface { type sqlTx interface {
Commit() error Commit() error
Rollback() error Rollback() error

20
interface_go1.8.go Normal file
View File

@ -0,0 +1,20 @@
// +build go1.8
package gorm
import (
"context"
"database/sql"
)
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

17
interface_go1.8pre.go Normal file
View File

@ -0,0 +1,17 @@
// +build !go1.8
package gorm
import "database/sql"
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}

18
main.go
View File

@ -7,6 +7,9 @@ import (
"reflect" "reflect"
"strings" "strings"
"time" "time"
// Using the old package to support older golangs
"golang.org/x/net/context"
) )
// DB contains information for current db connection // DB contains information for current db connection
@ -22,6 +25,7 @@ type DB struct {
logger logger logger logger
search *search search *search
values map[string]interface{} values map[string]interface{}
context context.Context
// global db // global db
parent *DB parent *DB
@ -460,19 +464,6 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
// Begin begin a transaction
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
}
return c
}
// Commit commit a transaction // Commit commit a transaction
func (s *DB) Commit() *DB { func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok && db != nil { if db, ok := s.db.(sqlTx); ok && db != nil {
@ -717,6 +708,7 @@ func (s *DB) clone() *DB {
logger: s.logger, logger: s.logger,
logMode: s.logMode, logMode: s.logMode,
values: map[string]interface{}{}, values: map[string]interface{}{},
context: s.context,
Value: s.Value, Value: s.Value,
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,

46
main_go1.8.go Normal file
View File

@ -0,0 +1,46 @@
// +build go1.8
package gorm
import (
"context"
"database/sql"
)
// WithContext specify context to be passed to the underlying `*sql.DB` or
// `*sql.Tx` query methods
func (s *DB) WithContext(ctx context.Context) *DB {
db := s.clone()
db.context = ctx
return db
}
// Context returns the specified context for this instance, or nil if not set
func (s *DB) Context() context.Context {
return s.context
}
func (s *DB) contextOrBackground() context.Context {
if s.context != nil {
return s.context
}
return context.Background()
}
// BeginTx starts a transaction with the given options
func (s *DB) BeginTx(opts *sql.TxOptions) *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.BeginTx(s.contextOrBackground(), opts)
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
}
return c
}
// Begin starts a transaction
func (s *DB) Begin() *DB {
return s.BeginTx(nil)
}

19
main_go1.8_test.go Normal file
View File

@ -0,0 +1,19 @@
// +build go1.8
package gorm_test
import (
"context"
"testing"
"time"
)
func TestContext(t *testing.T) {
user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
expiredCtx, cancel := context.WithDeadline(context.Background(), time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC))
err := DB.WithContext(expiredCtx).Save(&user1).Error
cancel()
if err.Error() != context.DeadlineExceeded.Error() {
t.Fatal("unexpected err:", err)
}
}

16
main_go1.8pre.go Normal file
View File

@ -0,0 +1,16 @@
// +build !go1.8
package gorm
// Begin starts a transaction
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
}
return c
}

View File

@ -359,7 +359,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count scope.db.RowsAffected = count
} }
@ -397,17 +397,6 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
return scope.Get(name + scope.InstanceID()) return scope.Get(name + scope.InstanceID())
} }
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}
return scope
}
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it // CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
func (scope *Scope) CommitOrRollback() *Scope { func (scope *Scope) CommitOrRollback() *Scope {
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
@ -1062,6 +1051,7 @@ func (scope *Scope) getTableOptions() string {
return tableOptions.(string) return tableOptions.(string)
} }
// TODO: context variant
func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
@ -1098,6 +1088,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
} }
} }
// TODO: context variant
func (scope *Scope) createTable() *Scope { func (scope *Scope) createTable() *Scope {
var tags []string var tags []string
var primaryKeys []string var primaryKeys []string
@ -1133,19 +1124,23 @@ func (scope *Scope) createTable() *Scope {
return scope return scope
} }
// TODO: context variant
func (scope *Scope) dropTable() *Scope { func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
return scope return scope
} }
// TODO: context variant
func (scope *Scope) modifyColumn(column string, typ string) { func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
} }
// TODO: context variant
func (scope *Scope) dropColumn(column string) { func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
} }
// TODO: context variant
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope.TableName(), indexName) { if scope.Dialect().HasIndex(scope.TableName(), indexName) {
return return
@ -1164,6 +1159,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
} }
// TODO: context variant
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
// Compatible with old generated key // Compatible with old generated key
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
@ -1175,10 +1171,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
} }
// TODO: context variant
func (scope *Scope) removeIndex(indexName string) { func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope.TableName(), indexName) scope.Dialect().RemoveIndex(scope.TableName(), indexName)
} }
// TODO: context variant
func (scope *Scope) autoMigrate() *Scope { func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName() tableName := scope.TableName()
quotedTableName := scope.QuotedTableName() quotedTableName := scope.QuotedTableName()
@ -1200,6 +1198,7 @@ func (scope *Scope) autoMigrate() *Scope {
return scope return scope
} }
// TODO: context variant
func (scope *Scope) autoIndex() *Scope { func (scope *Scope) autoIndex() *Scope {
var indexes = map[string][]string{} var indexes = map[string][]string{}
var uniqueIndexes = map[string][]string{} var uniqueIndexes = map[string][]string{}

37
scope_go1.8.go Normal file
View File

@ -0,0 +1,37 @@
// +build go1.8
package gorm
import "database/sql"
// BeginTx start a transaction with the given options
func (scope *Scope) BeginTx(opts *sql.TxOptions) *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.BeginTx(scope.DB().contextOrBackground(), opts); err == nil {
scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}
return scope
}
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
return scope.BeginTx(nil)
}
func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) {
return scope.SQLDB().ExecContext(scope.db.contextOrBackground(), query, args...)
}
func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) {
return scope.SQLDB().PrepareContext(scope.db.contextOrBackground(), query)
}
func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) {
return scope.SQLDB().QueryContext(scope.db.contextOrBackground(), query, args...)
}
func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row {
return scope.SQLDB().QueryRowContext(scope.db.contextOrBackground(), query, args...)
}

32
scope_go1.8pre.go Normal file
View File

@ -0,0 +1,32 @@
// +build !go1.8
package gorm
import "database/sql"
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}
return scope
}
func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) {
return scope.SQLDB().Exec(query, args...)
}
func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) {
return scope.SQLDB().Prepare(query)
}
func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) {
return scope.SQLDB().Query(query, args...)
}
func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row {
return scope.SQLDB().QueryRow(query, args...)
}