Replace all use of *sql.DB with sqlCommon
Exporting sqlCommon as SQLCommon. This allows passing alternate implementations of the database connection, or wrapping the connection with middleware. This change didn't change any usages of the database variables. All usages were already only using the functions defined in SQLCommon. This does cause a breaking change in Dialect, since *sql.DB was referenced in the interface.
This commit is contained in:
		
							parent
							
								
									5409931a1b
								
							
						
					
					
						commit
						c3f3558f0e
					
				@ -14,7 +14,7 @@ type Dialect interface {
 | 
				
			|||||||
	GetName() string
 | 
						GetName() string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// SetDB set db for dialect
 | 
						// SetDB set db for dialect
 | 
				
			||||||
	SetDB(db *sql.DB)
 | 
						SetDB(db SQLCommon)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | 
						// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
 | 
				
			||||||
	BindVar(i int) string
 | 
						BindVar(i int) string
 | 
				
			||||||
@ -50,7 +50,7 @@ type Dialect interface {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var dialectsMap = map[string]Dialect{}
 | 
					var dialectsMap = map[string]Dialect{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newDialect(name string, db *sql.DB) Dialect {
 | 
					func newDialect(name string, db SQLCommon) Dialect {
 | 
				
			||||||
	if value, ok := dialectsMap[name]; ok {
 | 
						if value, ok := dialectsMap[name]; ok {
 | 
				
			||||||
		dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
 | 
							dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
 | 
				
			||||||
		dialect.SetDB(db)
 | 
							dialect.SetDB(db)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type commonDialect struct {
 | 
					type commonDialect struct {
 | 
				
			||||||
	db *sql.DB
 | 
						db SQLCommon
 | 
				
			||||||
	DefaultForeignKeyNamer
 | 
						DefaultForeignKeyNamer
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
 | 
				
			|||||||
	return "common"
 | 
						return "common"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *commonDialect) SetDB(db *sql.DB) {
 | 
					func (s *commonDialect) SetDB(db SQLCommon) {
 | 
				
			||||||
	s.db = db
 | 
						s.db = db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
package mssql
 | 
					package mssql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"database/sql"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
@ -24,7 +23,7 @@ func init() {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type mssql struct {
 | 
					type mssql struct {
 | 
				
			||||||
	db *sql.DB
 | 
						db gorm.SQLCommon
 | 
				
			||||||
	gorm.DefaultForeignKeyNamer
 | 
						gorm.DefaultForeignKeyNamer
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -32,7 +31,7 @@ func (mssql) GetName() string {
 | 
				
			|||||||
	return "mssql"
 | 
						return "mssql"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *mssql) SetDB(db *sql.DB) {
 | 
					func (s *mssql) SetDB(db gorm.SQLCommon) {
 | 
				
			||||||
	s.db = db
 | 
						s.db = db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,8 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "database/sql"
 | 
					import "database/sql"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type sqlCommon interface {
 | 
					// SQLCommon is the minimal database connection functionality gorm requires.  Implemented by *sql.DB.
 | 
				
			||||||
 | 
					type SQLCommon interface {
 | 
				
			||||||
	Exec(query string, args ...interface{}) (sql.Result, error)
 | 
						Exec(query string, args ...interface{}) (sql.Result, error)
 | 
				
			||||||
	Prepare(query string) (*sql.Stmt, error)
 | 
						Prepare(query string) (*sql.Stmt, error)
 | 
				
			||||||
	Query(query string, args ...interface{}) (*sql.Rows, error)
 | 
						Query(query string, args ...interface{}) (*sql.Rows, error)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										11
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								main.go
									
									
									
									
									
								
							@ -16,7 +16,7 @@ type DB struct {
 | 
				
			|||||||
	RowsAffected int64
 | 
						RowsAffected int64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// single db
 | 
						// single db
 | 
				
			||||||
	db                sqlCommon
 | 
						db                SQLCommon
 | 
				
			||||||
	blockGlobalUpdate bool
 | 
						blockGlobalUpdate bool
 | 
				
			||||||
	logMode           int
 | 
						logMode           int
 | 
				
			||||||
	logger            logger
 | 
						logger            logger
 | 
				
			||||||
@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var source string
 | 
						var source string
 | 
				
			||||||
	var dbSQL *sql.DB
 | 
						var dbSQL SQLCommon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	switch value := args[0].(type) {
 | 
						switch value := args[0].(type) {
 | 
				
			||||||
	case string:
 | 
						case string:
 | 
				
			||||||
@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
 | 
				
			|||||||
			source = args[1].(string)
 | 
								source = args[1].(string)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		dbSQL, err = sql.Open(driver, source)
 | 
							dbSQL, err = sql.Open(driver, source)
 | 
				
			||||||
	case *sql.DB:
 | 
						case SQLCommon:
 | 
				
			||||||
		source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
 | 
					 | 
				
			||||||
		dbSQL = value
 | 
							dbSQL = value
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -104,7 +103,7 @@ func (s *DB) DB() *sql.DB {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
 | 
					// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
 | 
				
			||||||
func (s *DB) CommonDB() sqlCommon {
 | 
					func (s *DB) CommonDB() SQLCommon {
 | 
				
			||||||
	return s.db
 | 
						return s.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -449,7 +448,7 @@ func (s *DB) Begin() *DB {
 | 
				
			|||||||
	c := s.clone()
 | 
						c := s.clone()
 | 
				
			||||||
	if db, ok := c.db.(sqlDb); ok {
 | 
						if db, ok := c.db.(sqlDb); ok {
 | 
				
			||||||
		tx, err := db.Begin()
 | 
							tx, err := db.Begin()
 | 
				
			||||||
		c.db = interface{}(tx).(sqlCommon)
 | 
							c.db = interface{}(tx).(SQLCommon)
 | 
				
			||||||
		c.AddError(err)
 | 
							c.AddError(err)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		c.AddError(ErrCantStartTransaction)
 | 
							c.AddError(ErrCantStartTransaction)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								scope.go
									
									
									
									
									
								
							@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SQLDB return *sql.DB
 | 
					// SQLDB return *sql.DB
 | 
				
			||||||
func (scope *Scope) SQLDB() sqlCommon {
 | 
					func (scope *Scope) SQLDB() SQLCommon {
 | 
				
			||||||
	return scope.db.db
 | 
						return scope.db.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
				
			|||||||
func (scope *Scope) Begin() *Scope {
 | 
					func (scope *Scope) Begin() *Scope {
 | 
				
			||||||
	if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
						if db, ok := scope.SQLDB().(sqlDb); ok {
 | 
				
			||||||
		if tx, err := db.Begin(); err == nil {
 | 
							if tx, err := db.Begin(); err == nil {
 | 
				
			||||||
			scope.db.db = interface{}(tx).(sqlCommon)
 | 
								scope.db.db = interface{}(tx).(SQLCommon)
 | 
				
			||||||
			scope.InstanceSet("gorm:started_transaction", true)
 | 
								scope.InstanceSet("gorm:started_transaction", true)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user