Implementatoin of possibility to reconnect

This commit is contained in:
Ruslan Fedoseenko 2018-01-13 15:18:25 +03:00
parent 0a51f6cdc5
commit 884eee4128
13 changed files with 128 additions and 26 deletions

View File

@ -50,6 +50,8 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
scope.db.reconnectGuard.Wait()
defer scope.trace(NowFunc())
var (

View File

@ -28,6 +28,8 @@ func beforeDeleteCallback(scope *Scope) {
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) {
if !scope.HasError() {
scope.db.reconnectGuard.Wait()
var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok {
extraOption = fmt.Sprint(str)

View File

@ -16,6 +16,7 @@ func init() {
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
scope.db.reconnectGuard.Wait()
var (
isSlice, isPtr bool

62
callback_reconnect.go Normal file
View File

@ -0,0 +1,62 @@
package gorm
import (
"time"
)
//define callbacks to reconnect in case of failure
func init() {
DefaultCallback.Create().After("gorm:create").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Update().After("gorm:update").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Delete().After("gorm:delete").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Query().After("gorm:query").Register("gorm:begin_reconnect", performReconnect)
}
//May be do some kind of settings?
const reconnectAttempts = 5
const reconnectInterval = 5 * time.Second
//performReconnect the callback used to peform some reconnect attempts in case of disconnect
func performReconnect(scope *Scope) {
if scope.HasError() {
scope.db.reconnectGuard.Add(1)
defer scope.db.reconnectGuard.Done()
err := scope.db.Error
if scope.db.dialect.IsDisconnectError(err) {
for i := 0; i < reconnectAttempts; i++ {
newDb, openErr := Open(scope.db.dialectName, scope.db.dialectArgs...)
if openErr == nil {
oldDb := scope.db
if oldDb.parent != oldDb {
//In case of cloned db try to fix parents
//It is thread safe as we share mutex between instances
fixParentDbs(oldDb, newDb)
}
*scope.db = *newDb
break
} else {
//wait for interval and try to reconnect again
<-time.After(reconnectInterval)
}
}
}
}
}
func fixParentDbs(current, newDb *DB) {
iterator := current
parent := current.parent
for {
oldParent := parent
*parent = *newDb
parent = oldParent.parent
iterator = oldParent
if iterator == parent {
break
}
}
}

View File

@ -18,6 +18,8 @@ type RowsQueryResult struct {
// queryCallback used to query data from database
func rowQueryCallback(scope *Scope) {
scope.db.reconnectGuard.Wait()
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()

View File

@ -56,6 +56,8 @@ func updateTimeStampForUpdateCallback(scope *Scope) {
// updateCallback the callback used to update data to database
func updateCallback(scope *Scope) {
if !scope.HasError() {
scope.db.reconnectGuard.Wait()
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {

View File

@ -46,6 +46,9 @@ type Dialect interface {
// CurrentDatabase return current database name
CurrentDatabase() string
//
IsDisconnectError(err error) bool
}
var dialectsMap = map[string]Dialect{}

View File

@ -1,6 +1,7 @@
package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
@ -21,7 +22,9 @@ type commonDialect struct {
func init() {
RegisterDialect("common", &commonDialect{})
}
func (commonDialect) IsDisconnectError(err error) bool {
return err == driver.ErrBadConn
}
func (commonDialect) GetName() string {
return "common"
}

View File

@ -2,7 +2,9 @@ package gorm
import (
"crypto/sha1"
"database/sql/driver"
"fmt"
"github.com/go-sql-driver/mysql"
"reflect"
"regexp"
"strconv"
@ -11,24 +13,28 @@ import (
"unicode/utf8"
)
type mysql struct {
type mysqlDialect struct {
commonDialect
}
func init() {
RegisterDialect("mysql", &mysql{})
RegisterDialect("mysql", &mysqlDialect{})
}
func (mysql) GetName() string {
func (mysqlDialect) IsDisconnectError(err error) bool {
return err == mysql.ErrInvalidConn || err == driver.ErrBadConn
}
func (mysqlDialect) GetName() string {
return "mysql"
}
func (mysql) Quote(key string) string {
func (mysqlDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
// Get Data Type for MySQL Dialect
func (s *mysql) DataTypeOf(field *StructField) string {
func (s *mysqlDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
// MySQL allows only one auto increment column per table, and it must
@ -122,12 +128,12 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) RemoveIndex(tableName string, indexName string) error {
func (s mysqlDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysqlDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
@ -142,22 +148,22 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
return
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mysqlDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
func (s mysqlDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (mysql) SelectFromDummyTable() string {
func (mysqlDialect) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
func (s mysqlDialect) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if utf8.RuneCountInString(keyName) <= 64 {
return keyName

View File

@ -1,6 +1,7 @@
package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
@ -19,6 +20,9 @@ func init() {
func (postgres) GetName() string {
return "postgres"
}
func (postgres) IsDisconnectError(err error) bool {
return err == driver.ErrBadConn
}
func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)

View File

@ -14,7 +14,10 @@ type sqlite3 struct {
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
func (sqlite3) IsDisconnectError(err error) bool {
//sqlite is a file db so i think no reconnect needed or it is possible to treat file errors as disconnect
return false
}
func (sqlite3) GetName() string {
return "sqlite3"
}

View File

@ -4,11 +4,11 @@ import (
"database/sql"
"database/sql/driver"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json"
"errors"
"fmt"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
)
type Hstore map[string]*string

34
main.go
View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
)
@ -22,7 +23,11 @@ type DB struct {
logger logger
search *search
values map[string]interface{}
//Mutex to block all operations during reconnection
reconnectGuard *sync.WaitGroup
//connection string passed to Open saved for reconnect
dialectName string
dialectArgs []interface{}
// global db
parent *DB
callbacks *Callback
@ -64,11 +69,14 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
db = &DB{
db: dbSQL,
logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
db: dbSQL,
logger: defaultLogger,
values: map[string]interface{}{},
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
reconnectGuard: &sync.WaitGroup{},
dialectArgs: args,
dialectName: dialect,
}
db.parent = db
if err != nil {
@ -712,11 +720,15 @@ func (s *DB) GetErrors() []error {
func (s *DB) clone() *DB {
db := &DB{
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
values: map[string]interface{}{},
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
values: map[string]interface{}{},
//share reconnection info with new copy
reconnectGuard: s.reconnectGuard,
dialectName: s.dialectName,
dialectArgs: s.dialectArgs,
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,