Utilize go1.8 context support in database/sql
Fixes #1231 The related go1.8 release notes: https://golang.org/doc/go1.8#database_sql
This commit is contained in:
		
							parent
							
								
									0a51f6cdc5
								
							
						
					
					
						commit
						9340b97a0b
					
				@ -20,6 +20,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 | 
			
		||||
* Extendable, write Plugins based on GORM callbacks
 | 
			
		||||
* Every feature comes with tests
 | 
			
		||||
* Developer Friendly
 | 
			
		||||
* Supports context.Context on golang 1.8
 | 
			
		||||
 | 
			
		||||
## Getting Started
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -115,7 +115,7 @@ func createCallback(scope *Scope) {
 | 
			
		||||
 | 
			
		||||
		// execute create sql
 | 
			
		||||
		if lastInsertIDReturningSuffix == "" || primaryField == nil {
 | 
			
		||||
			if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
				// set rows affected count
 | 
			
		||||
				scope.db.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
 | 
			
		||||
@ -128,7 +128,7 @@ func createCallback(scope *Scope) {
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			if primaryField.Field.CanAddr() {
 | 
			
		||||
				if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
 | 
			
		||||
				if err := scope.sqldbQueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
 | 
			
		||||
					primaryField.IsBlank = false
 | 
			
		||||
					scope.db.RowsAffected = 1
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -55,7 +55,7 @@ func queryCallback(scope *Scope) {
 | 
			
		||||
			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
		if rows, err := scope.sqldbQuery(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			defer rows.Close()
 | 
			
		||||
 | 
			
		||||
			columns, _ := rows.Columns()
 | 
			
		||||
 | 
			
		||||
@ -22,9 +22,9 @@ func rowQueryCallback(scope *Scope) {
 | 
			
		||||
		scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
		if rowResult, ok := result.(*RowQueryResult); ok {
 | 
			
		||||
			rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
 | 
			
		||||
			rowResult.Row = scope.sqldbQueryRow(scope.SQL, scope.SQLVars...)
 | 
			
		||||
		} else if rowsResult, ok := result.(*RowsQueryResult); ok {
 | 
			
		||||
			rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
 | 
			
		||||
			rowsResult.Rows, rowsResult.Error = scope.sqldbQuery(scope.SQL, scope.SQLVars...)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,14 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	queryHasIndex        = "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?"
 | 
			
		||||
	queryRemoveIndex     = "DROP INDEX %v"
 | 
			
		||||
	queryHasTable        = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?"
 | 
			
		||||
	queryHasColumn       = "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?"
 | 
			
		||||
	queryCurrentDatabase = "SELECT DATABASE()"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
 | 
			
		||||
type DefaultForeignKeyNamer struct {
 | 
			
		||||
}
 | 
			
		||||
@ -90,38 +98,10 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if limit != nil {
 | 
			
		||||
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										36
									
								
								dialect_common_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								dialect_common_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryRemoveIndex, indexName))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										33
									
								
								dialect_common_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								dialect_common_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,33 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf(queryRemoveIndex, indexName))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s commonDialect) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow(queryCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -11,6 +11,12 @@ import (
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	queryMySQLRemoveIndex     = "DROP INDEX %v ON %v"
 | 
			
		||||
	queryMySQLHasForeignKey   = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'"
 | 
			
		||||
	queryMySQLCurrentDatabase = "SELECT DATABASE()"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type mysql struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
@ -122,11 +128,6 @@ func (s *mysql) DataTypeOf(field *StructField) string {
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if limit != nil {
 | 
			
		||||
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
 | 
			
		||||
@ -142,17 +143,6 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) SelectFromDummyTable() string {
 | 
			
		||||
	return "FROM DUAL"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										24
									
								
								dialect_mysql_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								dialect_mysql_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,24 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMySQLCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										21
									
								
								dialect_mysql_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								dialect_mysql_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow(queryMySQLCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -7,6 +7,14 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	queryPostgresHasIndex        = "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()"
 | 
			
		||||
	queryPostgresHasForeignKey   = "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'"
 | 
			
		||||
	queryPostgresHasTable        = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()"
 | 
			
		||||
	queryPostgresHasColumn       = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()"
 | 
			
		||||
	queryPostgresCurrentDatabase = "SELECT CURRENT_DATABASE()"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type postgres struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
@ -85,35 +93,6 @@ func (s *postgres) DataTypeOf(field *StructField) string {
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
 | 
			
		||||
	return fmt.Sprintf("RETURNING %v.%v", tableName, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								dialect_postgres_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								dialect_postgres_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "context"
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryPostgresHasIndex, tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryPostgresHasTable, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryPostgresHasColumn, tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryPostgresCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								dialect_postgres_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								dialect_postgres_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,32 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryPostgresHasIndex, tableName, indexName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryPostgresHasTable, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryPostgresHasColumn, tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s postgres) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow(queryPostgresCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -7,6 +7,13 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	querySQLite3HasIndex        = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'"
 | 
			
		||||
	querySQLite3HasTable        = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?"
 | 
			
		||||
	querySQLite3HasColumn       = "SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n"
 | 
			
		||||
	querySQLite3CurrentDatabase = "PRAGMA database_list"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type sqlite3 struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
@ -69,39 +76,3 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) CurrentDatabase() (name string) {
 | 
			
		||||
	var (
 | 
			
		||||
		ifaces   = make([]interface{}, 3)
 | 
			
		||||
		pointers = make([]*string, 3)
 | 
			
		||||
		i        int
 | 
			
		||||
	)
 | 
			
		||||
	for i = 0; i < 3; i++ {
 | 
			
		||||
		ifaces[i] = &pointers[i]
 | 
			
		||||
	}
 | 
			
		||||
	if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if pointers[1] != nil {
 | 
			
		||||
		name = *pointers[1]
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										44
									
								
								dialect_sqlite3_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								dialect_sqlite3_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,44 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), querySQLite3HasTable, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) CurrentDatabase() (name string) {
 | 
			
		||||
	var (
 | 
			
		||||
		ifaces   = make([]interface{}, 3)
 | 
			
		||||
		pointers = make([]*string, 3)
 | 
			
		||||
		i        int
 | 
			
		||||
	)
 | 
			
		||||
	for i = 0; i < 3; i++ {
 | 
			
		||||
		ifaces[i] = &pointers[i]
 | 
			
		||||
	}
 | 
			
		||||
	if err := s.db.QueryRowContext(context.Background(), querySQLite3CurrentDatabase).Scan(ifaces...); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if pointers[1] != nil {
 | 
			
		||||
		name = *pointers[1]
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										41
									
								
								dialect_sqlite3_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								dialect_sqlite3_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf(querySQLite3HasIndex, indexName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(querySQLite3HasTable, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(fmt.Sprintf(querySQLite3HasColumn, columnName, columnName), tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s sqlite3) CurrentDatabase() (name string) {
 | 
			
		||||
	var (
 | 
			
		||||
		ifaces   = make([]interface{}, 3)
 | 
			
		||||
		pointers = make([]*string, 3)
 | 
			
		||||
		i        int
 | 
			
		||||
	)
 | 
			
		||||
	for i = 0; i < 3; i++ {
 | 
			
		||||
		ifaces[i] = &pointers[i]
 | 
			
		||||
	}
 | 
			
		||||
	if err := s.db.QueryRow(querySQLite3CurrentDatabase).Scan(ifaces...); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if pointers[1] != nil {
 | 
			
		||||
		name = *pointers[1]
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -11,6 +11,14 @@ import (
 | 
			
		||||
	"github.com/jinzhu/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	queryMSSQLHasIndex        = "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)"
 | 
			
		||||
	queryMSSQLRemoveIndex     = "DROP INDEX %v ON %v"
 | 
			
		||||
	queryMSSQLHasTable        = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?"
 | 
			
		||||
	queryMSSQLHasColumn       = "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?"
 | 
			
		||||
	queryMSSQLCurrentDatabase = "SELECT DB_NAME() AS [Current Database]"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func setIdentityInsert(scope *gorm.Scope) {
 | 
			
		||||
	if scope.Dialect().GetName() == "mssql" {
 | 
			
		||||
		for _, field := range scope.PrimaryFields() {
 | 
			
		||||
@ -111,38 +119,10 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if offset != nil {
 | 
			
		||||
		if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										36
									
								
								dialects/mssql/mssql_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								dialects/mssql/mssql_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package mssql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMSSQLHasIndex, indexName, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRowContext(context.Background(), queryMSSQLCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										33
									
								
								dialects/mssql/mssql_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								dialects/mssql/mssql_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,33 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package mssql
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryMSSQLHasIndex, indexName, tableName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf(queryMSSQLRemoveIndex, indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasTable(tableName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryMSSQLHasTable, tableName, s.CurrentDatabase()).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow(queryMSSQLHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mssql) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow(queryMSSQLCurrentDatabase).Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										14
									
								
								interface.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								interface.go
									
									
									
									
									
								
							@ -1,19 +1,5 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | 
			
		||||
type SQLCommon interface {
 | 
			
		||||
	Exec(query string, args ...interface{}) (sql.Result, error)
 | 
			
		||||
	Prepare(query string) (*sql.Stmt, error)
 | 
			
		||||
	Query(query string, args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRow(query string, args ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlDb interface {
 | 
			
		||||
	Begin() (*sql.Tx, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlTx interface {
 | 
			
		||||
	Commit() error
 | 
			
		||||
	Rollback() error
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								interface_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								interface_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,20 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | 
			
		||||
type SQLCommon interface {
 | 
			
		||||
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
 | 
			
		||||
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
 | 
			
		||||
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlDb interface {
 | 
			
		||||
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								interface_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								interface_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,17 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | 
			
		||||
type SQLCommon interface {
 | 
			
		||||
	Exec(query string, args ...interface{}) (sql.Result, error)
 | 
			
		||||
	Prepare(query string) (*sql.Stmt, error)
 | 
			
		||||
	Query(query string, args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRow(query string, args ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sqlDb interface {
 | 
			
		||||
	Begin() (*sql.Tx, error)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								main.go
									
									
									
									
									
								
							@ -7,6 +7,9 @@ import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	// Using the old package to support older golangs
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DB contains information for current db connection
 | 
			
		||||
@ -22,6 +25,7 @@ type DB struct {
 | 
			
		||||
	logger            logger
 | 
			
		||||
	search            *search
 | 
			
		||||
	values            map[string]interface{}
 | 
			
		||||
	context           context.Context
 | 
			
		||||
 | 
			
		||||
	// global db
 | 
			
		||||
	parent        *DB
 | 
			
		||||
@ -460,19 +464,6 @@ func (s *DB) Debug() *DB {
 | 
			
		||||
	return s.clone().LogMode(true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Begin begin a transaction
 | 
			
		||||
func (s *DB) Begin() *DB {
 | 
			
		||||
	c := s.clone()
 | 
			
		||||
	if db, ok := c.db.(sqlDb); ok && db != nil {
 | 
			
		||||
		tx, err := db.Begin()
 | 
			
		||||
		c.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
		c.AddError(err)
 | 
			
		||||
	} else {
 | 
			
		||||
		c.AddError(ErrCantStartTransaction)
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Commit commit a transaction
 | 
			
		||||
func (s *DB) Commit() *DB {
 | 
			
		||||
	if db, ok := s.db.(sqlTx); ok && db != nil {
 | 
			
		||||
@ -717,6 +708,7 @@ func (s *DB) clone() *DB {
 | 
			
		||||
		logger:            s.logger,
 | 
			
		||||
		logMode:           s.logMode,
 | 
			
		||||
		values:            map[string]interface{}{},
 | 
			
		||||
		context:           s.context,
 | 
			
		||||
		Value:             s.Value,
 | 
			
		||||
		Error:             s.Error,
 | 
			
		||||
		blockGlobalUpdate: s.blockGlobalUpdate,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										46
									
								
								main_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								main_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,46 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// WithContext specify context to be passed to the underlying `*sql.DB` or
 | 
			
		||||
// `*sql.Tx` query methods
 | 
			
		||||
func (s *DB) WithContext(ctx context.Context) *DB {
 | 
			
		||||
	db := s.clone()
 | 
			
		||||
	db.context = ctx
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Context returns the specified context for this instance, or nil if not set
 | 
			
		||||
func (s *DB) Context() context.Context {
 | 
			
		||||
	return s.context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *DB) contextOrBackground() context.Context {
 | 
			
		||||
	if s.context != nil {
 | 
			
		||||
		return s.context
 | 
			
		||||
	}
 | 
			
		||||
	return context.Background()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BeginTx starts a transaction with the given options
 | 
			
		||||
func (s *DB) BeginTx(opts *sql.TxOptions) *DB {
 | 
			
		||||
	c := s.clone()
 | 
			
		||||
	if db, ok := c.db.(sqlDb); ok && db != nil {
 | 
			
		||||
		tx, err := db.BeginTx(s.contextOrBackground(), opts)
 | 
			
		||||
		c.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
		c.AddError(err)
 | 
			
		||||
	} else {
 | 
			
		||||
		c.AddError(ErrCantStartTransaction)
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Begin starts a transaction
 | 
			
		||||
func (s *DB) Begin() *DB {
 | 
			
		||||
	return s.BeginTx(nil)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								main_go1.8_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								main_go1.8_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestContext(t *testing.T) {
 | 
			
		||||
	user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
 | 
			
		||||
	expiredCtx, cancel := context.WithDeadline(context.Background(), time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC))
 | 
			
		||||
	err := DB.WithContext(expiredCtx).Save(&user1).Error
 | 
			
		||||
	cancel()
 | 
			
		||||
	if err.Error() != context.DeadlineExceeded.Error() {
 | 
			
		||||
		t.Fatal("unexpected err:", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										16
									
								
								main_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								main_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
// Begin starts a transaction
 | 
			
		||||
func (s *DB) Begin() *DB {
 | 
			
		||||
	c := s.clone()
 | 
			
		||||
	if db, ok := c.db.(sqlDb); ok && db != nil {
 | 
			
		||||
		tx, err := db.Begin()
 | 
			
		||||
		c.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
		c.AddError(err)
 | 
			
		||||
	} else {
 | 
			
		||||
		c.AddError(ErrCantStartTransaction)
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								scope.go
									
									
									
									
									
								
							@ -359,7 +359,7 @@ func (scope *Scope) Exec() *Scope {
 | 
			
		||||
	defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
		if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
 | 
			
		||||
			if count, err := result.RowsAffected(); scope.Err(err) == nil {
 | 
			
		||||
				scope.db.RowsAffected = count
 | 
			
		||||
			}
 | 
			
		||||
@ -397,17 +397,6 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
			
		||||
	return scope.Get(name + scope.InstanceID())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Begin start a transaction
 | 
			
		||||
func (scope *Scope) Begin() *Scope {
 | 
			
		||||
	if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
			
		||||
		if tx, err := db.Begin(); err == nil {
 | 
			
		||||
			scope.db.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
			scope.InstanceSet("gorm:started_transaction", true)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it
 | 
			
		||||
func (scope *Scope) CommitOrRollback() *Scope {
 | 
			
		||||
	if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
 | 
			
		||||
@ -1062,6 +1051,7 @@ func (scope *Scope) getTableOptions() string {
 | 
			
		||||
	return tableOptions.(string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) createJoinTable(field *StructField) {
 | 
			
		||||
	if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
 | 
			
		||||
		joinTableHandler := relationship.JoinTableHandler
 | 
			
		||||
@ -1098,6 +1088,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) createTable() *Scope {
 | 
			
		||||
	var tags []string
 | 
			
		||||
	var primaryKeys []string
 | 
			
		||||
@ -1133,19 +1124,23 @@ func (scope *Scope) createTable() *Scope {
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) dropTable() *Scope {
 | 
			
		||||
	scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) modifyColumn(column string, typ string) {
 | 
			
		||||
	scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) dropColumn(column string) {
 | 
			
		||||
	scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
 | 
			
		||||
	if scope.Dialect().HasIndex(scope.TableName(), indexName) {
 | 
			
		||||
		return
 | 
			
		||||
@ -1164,6 +1159,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
 | 
			
		||||
	scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
 | 
			
		||||
	keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest)
 | 
			
		||||
 | 
			
		||||
@ -1174,10 +1170,12 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
 | 
			
		||||
	scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) removeIndex(indexName string) {
 | 
			
		||||
	scope.Dialect().RemoveIndex(scope.TableName(), indexName)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) autoMigrate() *Scope {
 | 
			
		||||
	tableName := scope.TableName()
 | 
			
		||||
	quotedTableName := scope.QuotedTableName()
 | 
			
		||||
@ -1199,6 +1197,7 @@ func (scope *Scope) autoMigrate() *Scope {
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: context variant
 | 
			
		||||
func (scope *Scope) autoIndex() *Scope {
 | 
			
		||||
	var indexes = map[string][]string{}
 | 
			
		||||
	var uniqueIndexes = map[string][]string{}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										37
									
								
								scope_go1.8.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								scope_go1.8.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// BeginTx start a transaction with the given options
 | 
			
		||||
func (scope *Scope) BeginTx(opts *sql.TxOptions) *Scope {
 | 
			
		||||
	if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
			
		||||
		if tx, err := db.BeginTx(scope.DB().contextOrBackground(), opts); err == nil {
 | 
			
		||||
			scope.db.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
			scope.InstanceSet("gorm:started_transaction", true)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Begin start a transaction
 | 
			
		||||
func (scope *Scope) Begin() *Scope {
 | 
			
		||||
	return scope.BeginTx(nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	return scope.SQLDB().ExecContext(scope.db.contextOrBackground(), query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) {
 | 
			
		||||
	return scope.SQLDB().PrepareContext(scope.db.contextOrBackground(), query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	return scope.SQLDB().QueryContext(scope.db.contextOrBackground(), query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	return scope.SQLDB().QueryRowContext(scope.db.contextOrBackground(), query, args...)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								scope_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								scope_go1.8pre.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,32 @@
 | 
			
		||||
// +build !go1.8
 | 
			
		||||
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// Begin start a transaction
 | 
			
		||||
func (scope *Scope) Begin() *Scope {
 | 
			
		||||
	if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
			
		||||
		if tx, err := db.Begin(); err == nil {
 | 
			
		||||
			scope.db.db = interface{}(tx).(SQLCommon)
 | 
			
		||||
			scope.InstanceSet("gorm:started_transaction", true)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return scope
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbExec(query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	return scope.SQLDB().Exec(query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbPrepare(query string) (*sql.Stmt, error) {
 | 
			
		||||
	return scope.SQLDB().Prepare(query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbQuery(query string, args ...interface{}) (*sql.Rows, error) {
 | 
			
		||||
	return scope.SQLDB().Query(query, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (scope *Scope) sqldbQueryRow(query string, args ...interface{}) *sql.Row {
 | 
			
		||||
	return scope.SQLDB().QueryRow(query, args...)
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user