Implementatoin of possibility to reconnect
This commit is contained in:
parent
0a51f6cdc5
commit
884eee4128
@ -50,6 +50,8 @@ func updateTimeStampForCreateCallback(scope *Scope) {
|
|||||||
// createCallback the callback used to insert data into database
|
// createCallback the callback used to insert data into database
|
||||||
func createCallback(scope *Scope) {
|
func createCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
|
scope.db.reconnectGuard.Wait()
|
||||||
|
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -28,6 +28,8 @@ func beforeDeleteCallback(scope *Scope) {
|
|||||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||||
func deleteCallback(scope *Scope) {
|
func deleteCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
|
scope.db.reconnectGuard.Wait()
|
||||||
|
|
||||||
var extraOption string
|
var extraOption string
|
||||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||||
extraOption = fmt.Sprint(str)
|
extraOption = fmt.Sprint(str)
|
||||||
|
@ -16,6 +16,7 @@ func init() {
|
|||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
func queryCallback(scope *Scope) {
|
func queryCallback(scope *Scope) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
|
scope.db.reconnectGuard.Wait()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
isSlice, isPtr bool
|
isSlice, isPtr bool
|
||||||
|
62
callback_reconnect.go
Normal file
62
callback_reconnect.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//define callbacks to reconnect in case of failure
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Create().After("gorm:create").Register("gorm:begin_reconnect", performReconnect)
|
||||||
|
DefaultCallback.Update().After("gorm:update").Register("gorm:begin_reconnect", performReconnect)
|
||||||
|
DefaultCallback.Delete().After("gorm:delete").Register("gorm:begin_reconnect", performReconnect)
|
||||||
|
DefaultCallback.Query().After("gorm:query").Register("gorm:begin_reconnect", performReconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
//May be do some kind of settings?
|
||||||
|
const reconnectAttempts = 5
|
||||||
|
const reconnectInterval = 5 * time.Second
|
||||||
|
|
||||||
|
//performReconnect the callback used to peform some reconnect attempts in case of disconnect
|
||||||
|
func performReconnect(scope *Scope) {
|
||||||
|
if scope.HasError() {
|
||||||
|
|
||||||
|
scope.db.reconnectGuard.Add(1)
|
||||||
|
defer scope.db.reconnectGuard.Done()
|
||||||
|
|
||||||
|
err := scope.db.Error
|
||||||
|
|
||||||
|
if scope.db.dialect.IsDisconnectError(err) {
|
||||||
|
for i := 0; i < reconnectAttempts; i++ {
|
||||||
|
newDb, openErr := Open(scope.db.dialectName, scope.db.dialectArgs...)
|
||||||
|
if openErr == nil {
|
||||||
|
oldDb := scope.db
|
||||||
|
if oldDb.parent != oldDb {
|
||||||
|
//In case of cloned db try to fix parents
|
||||||
|
//It is thread safe as we share mutex between instances
|
||||||
|
fixParentDbs(oldDb, newDb)
|
||||||
|
}
|
||||||
|
*scope.db = *newDb
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
//wait for interval and try to reconnect again
|
||||||
|
<-time.After(reconnectInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixParentDbs(current, newDb *DB) {
|
||||||
|
iterator := current
|
||||||
|
parent := current.parent
|
||||||
|
|
||||||
|
for {
|
||||||
|
oldParent := parent
|
||||||
|
*parent = *newDb
|
||||||
|
parent = oldParent.parent
|
||||||
|
iterator = oldParent
|
||||||
|
if iterator == parent {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -18,6 +18,8 @@ type RowsQueryResult struct {
|
|||||||
|
|
||||||
// queryCallback used to query data from database
|
// queryCallback used to query data from database
|
||||||
func rowQueryCallback(scope *Scope) {
|
func rowQueryCallback(scope *Scope) {
|
||||||
|
scope.db.reconnectGuard.Wait()
|
||||||
|
|
||||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||||
scope.prepareQuerySQL()
|
scope.prepareQuerySQL()
|
||||||
|
|
||||||
|
@ -56,6 +56,8 @@ func updateTimeStampForUpdateCallback(scope *Scope) {
|
|||||||
// updateCallback the callback used to update data to database
|
// updateCallback the callback used to update data to database
|
||||||
func updateCallback(scope *Scope) {
|
func updateCallback(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
|
scope.db.reconnectGuard.Wait()
|
||||||
|
|
||||||
var sqls []string
|
var sqls []string
|
||||||
|
|
||||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
|
@ -46,6 +46,9 @@ type Dialect interface {
|
|||||||
|
|
||||||
// CurrentDatabase return current database name
|
// CurrentDatabase return current database name
|
||||||
CurrentDatabase() string
|
CurrentDatabase() string
|
||||||
|
|
||||||
|
//
|
||||||
|
IsDisconnectError(err error) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialectsMap = map[string]Dialect{}
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -21,7 +22,9 @@ type commonDialect struct {
|
|||||||
func init() {
|
func init() {
|
||||||
RegisterDialect("common", &commonDialect{})
|
RegisterDialect("common", &commonDialect{})
|
||||||
}
|
}
|
||||||
|
func (commonDialect) IsDisconnectError(err error) bool {
|
||||||
|
return err == driver.ErrBadConn
|
||||||
|
}
|
||||||
func (commonDialect) GetName() string {
|
func (commonDialect) GetName() string {
|
||||||
return "common"
|
return "common"
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,9 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -11,24 +13,28 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysql struct {
|
type mysqlDialect struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterDialect("mysql", &mysql{})
|
RegisterDialect("mysql", &mysqlDialect{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mysql) GetName() string {
|
func (mysqlDialect) IsDisconnectError(err error) bool {
|
||||||
|
return err == mysql.ErrInvalidConn || err == driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysqlDialect) GetName() string {
|
||||||
return "mysql"
|
return "mysql"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mysql) Quote(key string) string {
|
func (mysqlDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf("`%s`", key)
|
return fmt.Sprintf("`%s`", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get Data Type for MySQL Dialect
|
// Get Data Type for MySQL Dialect
|
||||||
func (s *mysql) DataTypeOf(field *StructField) string {
|
func (s *mysqlDialect) DataTypeOf(field *StructField) string {
|
||||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||||
|
|
||||||
// MySQL allows only one auto increment column per table, and it must
|
// MySQL allows only one auto increment column per table, and it must
|
||||||
@ -122,12 +128,12 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
func (s mysqlDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
func (s mysqlDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||||
if limit != nil {
|
if limit != nil {
|
||||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||||
@ -142,22 +148,22 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
func (s mysqlDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
var count int
|
var count int
|
||||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) CurrentDatabase() (name string) {
|
func (s mysqlDialect) CurrentDatabase() (name string) {
|
||||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mysql) SelectFromDummyTable() string {
|
func (mysqlDialect) SelectFromDummyTable() string {
|
||||||
return "FROM DUAL"
|
return "FROM DUAL"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
func (s mysqlDialect) BuildForeignKeyName(tableName, field, dest string) string {
|
||||||
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
||||||
if utf8.RuneCountInString(keyName) <= 64 {
|
if utf8.RuneCountInString(keyName) <= 64 {
|
||||||
return keyName
|
return keyName
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -19,6 +20,9 @@ func init() {
|
|||||||
func (postgres) GetName() string {
|
func (postgres) GetName() string {
|
||||||
return "postgres"
|
return "postgres"
|
||||||
}
|
}
|
||||||
|
func (postgres) IsDisconnectError(err error) bool {
|
||||||
|
return err == driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
func (postgres) BindVar(i int) string {
|
func (postgres) BindVar(i int) string {
|
||||||
return fmt.Sprintf("$%v", i)
|
return fmt.Sprintf("$%v", i)
|
||||||
|
@ -14,7 +14,10 @@ type sqlite3 struct {
|
|||||||
func init() {
|
func init() {
|
||||||
RegisterDialect("sqlite3", &sqlite3{})
|
RegisterDialect("sqlite3", &sqlite3{})
|
||||||
}
|
}
|
||||||
|
func (sqlite3) IsDisconnectError(err error) bool {
|
||||||
|
//sqlite is a file db so i think no reconnect needed or it is possible to treat file errors as disconnect
|
||||||
|
return false
|
||||||
|
}
|
||||||
func (sqlite3) GetName() string {
|
func (sqlite3) GetName() string {
|
||||||
return "sqlite3"
|
return "sqlite3"
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,11 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/lib/pq/hstore"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/lib/pq/hstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Hstore map[string]*string
|
type Hstore map[string]*string
|
||||||
|
34
main.go
34
main.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,7 +23,11 @@ type DB struct {
|
|||||||
logger logger
|
logger logger
|
||||||
search *search
|
search *search
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
//Mutex to block all operations during reconnection
|
||||||
|
reconnectGuard *sync.WaitGroup
|
||||||
|
//connection string passed to Open saved for reconnect
|
||||||
|
dialectName string
|
||||||
|
dialectArgs []interface{}
|
||||||
// global db
|
// global db
|
||||||
parent *DB
|
parent *DB
|
||||||
callbacks *Callback
|
callbacks *Callback
|
||||||
@ -64,11 +69,14 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db = &DB{
|
db = &DB{
|
||||||
db: dbSQL,
|
db: dbSQL,
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
values: map[string]interface{}{},
|
values: map[string]interface{}{},
|
||||||
callbacks: DefaultCallback,
|
callbacks: DefaultCallback,
|
||||||
dialect: newDialect(dialect, dbSQL),
|
dialect: newDialect(dialect, dbSQL),
|
||||||
|
reconnectGuard: &sync.WaitGroup{},
|
||||||
|
dialectArgs: args,
|
||||||
|
dialectName: dialect,
|
||||||
}
|
}
|
||||||
db.parent = db
|
db.parent = db
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -712,11 +720,15 @@ func (s *DB) GetErrors() []error {
|
|||||||
|
|
||||||
func (s *DB) clone() *DB {
|
func (s *DB) clone() *DB {
|
||||||
db := &DB{
|
db := &DB{
|
||||||
db: s.db,
|
db: s.db,
|
||||||
parent: s.parent,
|
parent: s.parent,
|
||||||
logger: s.logger,
|
logger: s.logger,
|
||||||
logMode: s.logMode,
|
logMode: s.logMode,
|
||||||
values: map[string]interface{}{},
|
values: map[string]interface{}{},
|
||||||
|
//share reconnection info with new copy
|
||||||
|
reconnectGuard: s.reconnectGuard,
|
||||||
|
dialectName: s.dialectName,
|
||||||
|
dialectArgs: s.dialectArgs,
|
||||||
Value: s.Value,
|
Value: s.Value,
|
||||||
Error: s.Error,
|
Error: s.Error,
|
||||||
blockGlobalUpdate: s.blockGlobalUpdate,
|
blockGlobalUpdate: s.blockGlobalUpdate,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user