Implementatoin of possibility to reconnect

This commit is contained in:
Ruslan Fedoseenko 2018-01-13 15:18:25 +03:00
parent 0a51f6cdc5
commit 884eee4128
13 changed files with 128 additions and 26 deletions

View File

@ -50,6 +50,8 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database // createCallback the callback used to insert data into database
func createCallback(scope *Scope) { func createCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.db.reconnectGuard.Wait()
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
var ( var (

View File

@ -28,6 +28,8 @@ func beforeDeleteCallback(scope *Scope) {
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) { func deleteCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.db.reconnectGuard.Wait()
var extraOption string var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok { if str, ok := scope.Get("gorm:delete_option"); ok {
extraOption = fmt.Sprint(str) extraOption = fmt.Sprint(str)

View File

@ -16,6 +16,7 @@ func init() {
// queryCallback used to query data from database // queryCallback used to query data from database
func queryCallback(scope *Scope) { func queryCallback(scope *Scope) {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
scope.db.reconnectGuard.Wait()
var ( var (
isSlice, isPtr bool isSlice, isPtr bool

62
callback_reconnect.go Normal file
View File

@ -0,0 +1,62 @@
package gorm
import (
"time"
)
//define callbacks to reconnect in case of failure
func init() {
DefaultCallback.Create().After("gorm:create").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Update().After("gorm:update").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Delete().After("gorm:delete").Register("gorm:begin_reconnect", performReconnect)
DefaultCallback.Query().After("gorm:query").Register("gorm:begin_reconnect", performReconnect)
}
//May be do some kind of settings?
const reconnectAttempts = 5
const reconnectInterval = 5 * time.Second
//performReconnect the callback used to peform some reconnect attempts in case of disconnect
func performReconnect(scope *Scope) {
if scope.HasError() {
scope.db.reconnectGuard.Add(1)
defer scope.db.reconnectGuard.Done()
err := scope.db.Error
if scope.db.dialect.IsDisconnectError(err) {
for i := 0; i < reconnectAttempts; i++ {
newDb, openErr := Open(scope.db.dialectName, scope.db.dialectArgs...)
if openErr == nil {
oldDb := scope.db
if oldDb.parent != oldDb {
//In case of cloned db try to fix parents
//It is thread safe as we share mutex between instances
fixParentDbs(oldDb, newDb)
}
*scope.db = *newDb
break
} else {
//wait for interval and try to reconnect again
<-time.After(reconnectInterval)
}
}
}
}
}
func fixParentDbs(current, newDb *DB) {
iterator := current
parent := current.parent
for {
oldParent := parent
*parent = *newDb
parent = oldParent.parent
iterator = oldParent
if iterator == parent {
break
}
}
}

View File

@ -18,6 +18,8 @@ type RowsQueryResult struct {
// queryCallback used to query data from database // queryCallback used to query data from database
func rowQueryCallback(scope *Scope) { func rowQueryCallback(scope *Scope) {
scope.db.reconnectGuard.Wait()
if result, ok := scope.InstanceGet("row_query_result"); ok { if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL() scope.prepareQuerySQL()

View File

@ -56,6 +56,8 @@ func updateTimeStampForUpdateCallback(scope *Scope) {
// updateCallback the callback used to update data to database // updateCallback the callback used to update data to database
func updateCallback(scope *Scope) { func updateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.db.reconnectGuard.Wait()
var sqls []string var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {

View File

@ -46,6 +46,9 @@ type Dialect interface {
// CurrentDatabase return current database name // CurrentDatabase return current database name
CurrentDatabase() string CurrentDatabase() string
//
IsDisconnectError(err error) bool
} }
var dialectsMap = map[string]Dialect{} var dialectsMap = map[string]Dialect{}

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -21,7 +22,9 @@ type commonDialect struct {
func init() { func init() {
RegisterDialect("common", &commonDialect{}) RegisterDialect("common", &commonDialect{})
} }
func (commonDialect) IsDisconnectError(err error) bool {
return err == driver.ErrBadConn
}
func (commonDialect) GetName() string { func (commonDialect) GetName() string {
return "common" return "common"
} }

View File

@ -2,7 +2,9 @@ package gorm
import ( import (
"crypto/sha1" "crypto/sha1"
"database/sql/driver"
"fmt" "fmt"
"github.com/go-sql-driver/mysql"
"reflect" "reflect"
"regexp" "regexp"
"strconv" "strconv"
@ -11,24 +13,28 @@ import (
"unicode/utf8" "unicode/utf8"
) )
type mysql struct { type mysqlDialect struct {
commonDialect commonDialect
} }
func init() { func init() {
RegisterDialect("mysql", &mysql{}) RegisterDialect("mysql", &mysqlDialect{})
} }
func (mysql) GetName() string { func (mysqlDialect) IsDisconnectError(err error) bool {
return err == mysql.ErrInvalidConn || err == driver.ErrBadConn
}
func (mysqlDialect) GetName() string {
return "mysql" return "mysql"
} }
func (mysql) Quote(key string) string { func (mysqlDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }
// Get Data Type for MySQL Dialect // Get Data Type for MySQL Dialect
func (s *mysql) DataTypeOf(field *StructField) string { func (s *mysqlDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
// MySQL allows only one auto increment column per table, and it must // MySQL allows only one auto increment column per table, and it must
@ -122,12 +128,12 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
func (s mysql) RemoveIndex(tableName string, indexName string) error { func (s mysqlDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err return err
} }
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysqlDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit) sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
@ -142,22 +148,22 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
return return
} }
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { func (s mysqlDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0 return count > 0
} }
func (s mysql) CurrentDatabase() (name string) { func (s mysqlDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name) s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return return
} }
func (mysql) SelectFromDummyTable() string { func (mysqlDialect) SelectFromDummyTable() string {
return "FROM DUAL" return "FROM DUAL"
} }
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { func (s mysqlDialect) BuildForeignKeyName(tableName, field, dest string) string {
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
if utf8.RuneCountInString(keyName) <= 64 { if utf8.RuneCountInString(keyName) <= 64 {
return keyName return keyName

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -19,6 +20,9 @@ func init() {
func (postgres) GetName() string { func (postgres) GetName() string {
return "postgres" return "postgres"
} }
func (postgres) IsDisconnectError(err error) bool {
return err == driver.ErrBadConn
}
func (postgres) BindVar(i int) string { func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i) return fmt.Sprintf("$%v", i)

View File

@ -14,7 +14,10 @@ type sqlite3 struct {
func init() { func init() {
RegisterDialect("sqlite3", &sqlite3{}) RegisterDialect("sqlite3", &sqlite3{})
} }
func (sqlite3) IsDisconnectError(err error) bool {
//sqlite is a file db so i think no reconnect needed or it is possible to treat file errors as disconnect
return false
}
func (sqlite3) GetName() string { func (sqlite3) GetName() string {
return "sqlite3" return "sqlite3"
} }

View File

@ -4,11 +4,11 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
_ "github.com/lib/pq"
"github.com/lib/pq/hstore"
) )
type Hstore map[string]*string type Hstore map[string]*string

34
main.go
View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
) )
@ -22,7 +23,11 @@ type DB struct {
logger logger logger logger
search *search search *search
values map[string]interface{} values map[string]interface{}
//Mutex to block all operations during reconnection
reconnectGuard *sync.WaitGroup
//connection string passed to Open saved for reconnect
dialectName string
dialectArgs []interface{}
// global db // global db
parent *DB parent *DB
callbacks *Callback callbacks *Callback
@ -64,11 +69,14 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
} }
db = &DB{ db = &DB{
db: dbSQL, db: dbSQL,
logger: defaultLogger, logger: defaultLogger,
values: map[string]interface{}{}, values: map[string]interface{}{},
callbacks: DefaultCallback, callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL), dialect: newDialect(dialect, dbSQL),
reconnectGuard: &sync.WaitGroup{},
dialectArgs: args,
dialectName: dialect,
} }
db.parent = db db.parent = db
if err != nil { if err != nil {
@ -712,11 +720,15 @@ func (s *DB) GetErrors() []error {
func (s *DB) clone() *DB { func (s *DB) clone() *DB {
db := &DB{ db := &DB{
db: s.db, db: s.db,
parent: s.parent, parent: s.parent,
logger: s.logger, logger: s.logger,
logMode: s.logMode, logMode: s.logMode,
values: map[string]interface{}{}, values: map[string]interface{}{},
//share reconnection info with new copy
reconnectGuard: s.reconnectGuard,
dialectName: s.dialectName,
dialectArgs: s.dialectArgs,
Value: s.Value, Value: s.Value,
Error: s.Error, Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate, blockGlobalUpdate: s.blockGlobalUpdate,