diff --git a/callback_create.go b/callback_create.go index a4da39e8..456ebff2 100644 --- a/callback_create.go +++ b/callback_create.go @@ -50,6 +50,8 @@ func updateTimeStampForCreateCallback(scope *Scope) { // createCallback the callback used to insert data into database func createCallback(scope *Scope) { if !scope.HasError() { + scope.db.reconnectGuard.Wait() + defer scope.trace(NowFunc()) var ( diff --git a/callback_delete.go b/callback_delete.go index 73d90880..a8126cbb 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -28,6 +28,8 @@ func beforeDeleteCallback(scope *Scope) { // deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) func deleteCallback(scope *Scope) { if !scope.HasError() { + scope.db.reconnectGuard.Wait() + var extraOption string if str, ok := scope.Get("gorm:delete_option"); ok { extraOption = fmt.Sprint(str) diff --git a/callback_query.go b/callback_query.go index 20e88161..a6d98f8c 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,6 +16,7 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { defer scope.trace(NowFunc()) + scope.db.reconnectGuard.Wait() var ( isSlice, isPtr bool diff --git a/callback_reconnect.go b/callback_reconnect.go new file mode 100644 index 00000000..c3cb6fe3 --- /dev/null +++ b/callback_reconnect.go @@ -0,0 +1,62 @@ +package gorm + +import ( + "time" +) + +//define callbacks to reconnect in case of failure +func init() { + DefaultCallback.Create().After("gorm:create").Register("gorm:begin_reconnect", performReconnect) + DefaultCallback.Update().After("gorm:update").Register("gorm:begin_reconnect", performReconnect) + DefaultCallback.Delete().After("gorm:delete").Register("gorm:begin_reconnect", performReconnect) + DefaultCallback.Query().After("gorm:query").Register("gorm:begin_reconnect", performReconnect) +} + +//May be do some kind of settings? +const reconnectAttempts = 5 +const reconnectInterval = 5 * time.Second + +//performReconnect the callback used to peform some reconnect attempts in case of disconnect +func performReconnect(scope *Scope) { + if scope.HasError() { + + scope.db.reconnectGuard.Add(1) + defer scope.db.reconnectGuard.Done() + + err := scope.db.Error + + if scope.db.dialect.IsDisconnectError(err) { + for i := 0; i < reconnectAttempts; i++ { + newDb, openErr := Open(scope.db.dialectName, scope.db.dialectArgs...) + if openErr == nil { + oldDb := scope.db + if oldDb.parent != oldDb { + //In case of cloned db try to fix parents + //It is thread safe as we share mutex between instances + fixParentDbs(oldDb, newDb) + } + *scope.db = *newDb + break + } else { + //wait for interval and try to reconnect again + <-time.After(reconnectInterval) + } + } + } + } +} + +func fixParentDbs(current, newDb *DB) { + iterator := current + parent := current.parent + + for { + oldParent := parent + *parent = *newDb + parent = oldParent.parent + iterator = oldParent + if iterator == parent { + break + } + } +} diff --git a/callback_row_query.go b/callback_row_query.go index c2ff4a08..8c6390af 100644 --- a/callback_row_query.go +++ b/callback_row_query.go @@ -18,6 +18,8 @@ type RowsQueryResult struct { // queryCallback used to query data from database func rowQueryCallback(scope *Scope) { + scope.db.reconnectGuard.Wait() + if result, ok := scope.InstanceGet("row_query_result"); ok { scope.prepareQuerySQL() diff --git a/callback_update.go b/callback_update.go index 6948439f..4080195c 100644 --- a/callback_update.go +++ b/callback_update.go @@ -56,6 +56,8 @@ func updateTimeStampForUpdateCallback(scope *Scope) { // updateCallback the callback used to update data to database func updateCallback(scope *Scope) { if !scope.HasError() { + scope.db.reconnectGuard.Wait() + var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { diff --git a/dialect.go b/dialect.go index e879588b..3d7309c4 100644 --- a/dialect.go +++ b/dialect.go @@ -46,6 +46,9 @@ type Dialect interface { // CurrentDatabase return current database name CurrentDatabase() string + + // + IsDisconnectError(err error) bool } var dialectsMap = map[string]Dialect{} diff --git a/dialect_common.go b/dialect_common.go index 7d0c3ce7..fd7c663b 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql/driver" "fmt" "reflect" "regexp" @@ -21,7 +22,9 @@ type commonDialect struct { func init() { RegisterDialect("common", &commonDialect{}) } - +func (commonDialect) IsDisconnectError(err error) bool { + return err == driver.ErrBadConn +} func (commonDialect) GetName() string { return "common" } diff --git a/dialect_mysql.go b/dialect_mysql.go index 686ad1ee..c9858fd4 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,7 +2,9 @@ package gorm import ( "crypto/sha1" + "database/sql/driver" "fmt" + "github.com/go-sql-driver/mysql" "reflect" "regexp" "strconv" @@ -11,24 +13,28 @@ import ( "unicode/utf8" ) -type mysql struct { +type mysqlDialect struct { commonDialect } func init() { - RegisterDialect("mysql", &mysql{}) + RegisterDialect("mysql", &mysqlDialect{}) } -func (mysql) GetName() string { +func (mysqlDialect) IsDisconnectError(err error) bool { + return err == mysql.ErrInvalidConn || err == driver.ErrBadConn +} + +func (mysqlDialect) GetName() string { return "mysql" } -func (mysql) Quote(key string) string { +func (mysqlDialect) Quote(key string) string { return fmt.Sprintf("`%s`", key) } // Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { +func (s *mysqlDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) // MySQL allows only one auto increment column per table, and it must @@ -122,12 +128,12 @@ func (s *mysql) DataTypeOf(field *StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } -func (s mysql) RemoveIndex(tableName string, indexName string) error { +func (s mysqlDialect) RemoveIndex(tableName string, indexName string) error { _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) return err } -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { +func (s mysqlDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { sql += fmt.Sprintf(" LIMIT %d", parsedLimit) @@ -142,22 +148,22 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { return } -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { +func (s mysqlDialect) HasForeignKey(tableName string, foreignKeyName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s mysql) CurrentDatabase() (name string) { +func (s mysqlDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } -func (mysql) SelectFromDummyTable() string { +func (mysqlDialect) SelectFromDummyTable() string { return "FROM DUAL" } -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { +func (s mysqlDialect) BuildForeignKeyName(tableName, field, dest string) string { keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) if utf8.RuneCountInString(keyName) <= 64 { return keyName diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df1..9b68dc3b 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql/driver" "fmt" "reflect" "strings" @@ -19,6 +20,9 @@ func init() { func (postgres) GetName() string { return "postgres" } +func (postgres) IsDisconnectError(err error) bool { + return err == driver.ErrBadConn +} func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb..2672b3d5 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -14,7 +14,10 @@ type sqlite3 struct { func init() { RegisterDialect("sqlite3", &sqlite3{}) } - +func (sqlite3) IsDisconnectError(err error) bool { + //sqlite is a file db so i think no reconnect needed or it is possible to treat file errors as disconnect + return false +} func (sqlite3) GetName() string { return "sqlite3" } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index b8e76891..19583ae2 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -4,11 +4,11 @@ import ( "database/sql" "database/sql/driver" - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" "encoding/json" "errors" "fmt" + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" ) type Hstore map[string]*string diff --git a/main.go b/main.go index 16fa0b79..fb3e722f 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -22,7 +23,11 @@ type DB struct { logger logger search *search values map[string]interface{} - + //Mutex to block all operations during reconnection + reconnectGuard *sync.WaitGroup + //connection string passed to Open saved for reconnect + dialectName string + dialectArgs []interface{} // global db parent *DB callbacks *Callback @@ -64,11 +69,14 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } db = &DB{ - db: dbSQL, - logger: defaultLogger, - values: map[string]interface{}{}, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), + db: dbSQL, + logger: defaultLogger, + values: map[string]interface{}{}, + callbacks: DefaultCallback, + dialect: newDialect(dialect, dbSQL), + reconnectGuard: &sync.WaitGroup{}, + dialectArgs: args, + dialectName: dialect, } db.parent = db if err != nil { @@ -712,11 +720,15 @@ func (s *DB) GetErrors() []error { func (s *DB) clone() *DB { db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - values: map[string]interface{}{}, + db: s.db, + parent: s.parent, + logger: s.logger, + logMode: s.logMode, + values: map[string]interface{}{}, + //share reconnection info with new copy + reconnectGuard: s.reconnectGuard, + dialectName: s.dialectName, + dialectArgs: s.dialectArgs, Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate,