Merge 884eee41282ff2633c87661ffd49af8f009614d5 into ca46ec0770003aab3c0ed7d7b336643362221c21
This commit is contained in:
		
						commit
						35fbb1ba01
					
				@ -50,6 +50,8 @@ func updateTimeStampForCreateCallback(scope *Scope) {
 | 
			
		||||
// createCallback the callback used to insert data into database
 | 
			
		||||
func createCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.reconnectGuard.Wait()
 | 
			
		||||
 | 
			
		||||
		defer scope.trace(NowFunc())
 | 
			
		||||
 | 
			
		||||
		var (
 | 
			
		||||
 | 
			
		||||
@ -28,6 +28,8 @@ func beforeDeleteCallback(scope *Scope) {
 | 
			
		||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
 | 
			
		||||
func deleteCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.reconnectGuard.Wait()
 | 
			
		||||
 | 
			
		||||
		var extraOption string
 | 
			
		||||
		if str, ok := scope.Get("gorm:delete_option"); ok {
 | 
			
		||||
			extraOption = fmt.Sprint(str)
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@ func init() {
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func queryCallback(scope *Scope) {
 | 
			
		||||
	defer scope.trace(NowFunc())
 | 
			
		||||
	scope.db.reconnectGuard.Wait()
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		isSlice, isPtr bool
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										62
									
								
								callback_reconnect.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								callback_reconnect.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//define callbacks to reconnect in case of failure
 | 
			
		||||
func init() {
 | 
			
		||||
	DefaultCallback.Create().After("gorm:create").Register("gorm:begin_reconnect", performReconnect)
 | 
			
		||||
	DefaultCallback.Update().After("gorm:update").Register("gorm:begin_reconnect", performReconnect)
 | 
			
		||||
	DefaultCallback.Delete().After("gorm:delete").Register("gorm:begin_reconnect", performReconnect)
 | 
			
		||||
	DefaultCallback.Query().After("gorm:query").Register("gorm:begin_reconnect", performReconnect)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//May be do some kind of settings?
 | 
			
		||||
const reconnectAttempts = 5
 | 
			
		||||
const reconnectInterval = 5 * time.Second
 | 
			
		||||
 | 
			
		||||
//performReconnect the callback used to peform some reconnect attempts in case of disconnect
 | 
			
		||||
func performReconnect(scope *Scope) {
 | 
			
		||||
	if scope.HasError() {
 | 
			
		||||
 | 
			
		||||
		scope.db.reconnectGuard.Add(1)
 | 
			
		||||
		defer scope.db.reconnectGuard.Done()
 | 
			
		||||
 | 
			
		||||
		err := scope.db.Error
 | 
			
		||||
 | 
			
		||||
		if scope.db.dialect.IsDisconnectError(err) {
 | 
			
		||||
			for i := 0; i < reconnectAttempts; i++ {
 | 
			
		||||
				newDb, openErr := Open(scope.db.dialectName, scope.db.dialectArgs...)
 | 
			
		||||
				if openErr == nil {
 | 
			
		||||
					oldDb := scope.db
 | 
			
		||||
					if oldDb.parent != oldDb {
 | 
			
		||||
						//In case of cloned db try to fix parents
 | 
			
		||||
						//It is thread safe as we share mutex between instances
 | 
			
		||||
						fixParentDbs(oldDb, newDb)
 | 
			
		||||
					}
 | 
			
		||||
					*scope.db = *newDb
 | 
			
		||||
					break
 | 
			
		||||
				} else {
 | 
			
		||||
					//wait for interval and try to reconnect again
 | 
			
		||||
					<-time.After(reconnectInterval)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fixParentDbs(current, newDb *DB) {
 | 
			
		||||
	iterator := current
 | 
			
		||||
	parent := current.parent
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		oldParent := parent
 | 
			
		||||
		*parent = *newDb
 | 
			
		||||
		parent = oldParent.parent
 | 
			
		||||
		iterator = oldParent
 | 
			
		||||
		if iterator == parent {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -18,6 +18,8 @@ type RowsQueryResult struct {
 | 
			
		||||
 | 
			
		||||
// queryCallback used to query data from database
 | 
			
		||||
func rowQueryCallback(scope *Scope) {
 | 
			
		||||
	scope.db.reconnectGuard.Wait()
 | 
			
		||||
 | 
			
		||||
	if result, ok := scope.InstanceGet("row_query_result"); ok {
 | 
			
		||||
		scope.prepareQuerySQL()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -56,6 +56,8 @@ func updateTimeStampForUpdateCallback(scope *Scope) {
 | 
			
		||||
// updateCallback the callback used to update data to database
 | 
			
		||||
func updateCallback(scope *Scope) {
 | 
			
		||||
	if !scope.HasError() {
 | 
			
		||||
		scope.db.reconnectGuard.Wait()
 | 
			
		||||
 | 
			
		||||
		var sqls []string
 | 
			
		||||
 | 
			
		||||
		if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
			
		||||
 | 
			
		||||
@ -46,6 +46,9 @@ type Dialect interface {
 | 
			
		||||
 | 
			
		||||
	// CurrentDatabase return current database name
 | 
			
		||||
	CurrentDatabase() string
 | 
			
		||||
 | 
			
		||||
	//
 | 
			
		||||
	IsDisconnectError(err error) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var dialectsMap = map[string]Dialect{}
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
@ -21,7 +22,9 @@ type commonDialect struct {
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("common", &commonDialect{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (commonDialect) IsDisconnectError(err error) bool {
 | 
			
		||||
	return err == driver.ErrBadConn
 | 
			
		||||
}
 | 
			
		||||
func (commonDialect) GetName() string {
 | 
			
		||||
	return "common"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,9 @@ package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-sql-driver/mysql"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@ -11,24 +13,28 @@ import (
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type mysql struct {
 | 
			
		||||
type mysqlDialect struct {
 | 
			
		||||
	commonDialect
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("mysql", &mysql{})
 | 
			
		||||
	RegisterDialect("mysql", &mysqlDialect{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) GetName() string {
 | 
			
		||||
func (mysqlDialect) IsDisconnectError(err error) bool {
 | 
			
		||||
	return err == mysql.ErrInvalidConn || err == driver.ErrBadConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysqlDialect) GetName() string {
 | 
			
		||||
	return "mysql"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) Quote(key string) string {
 | 
			
		||||
func (mysqlDialect) Quote(key string) string {
 | 
			
		||||
	return fmt.Sprintf("`%s`", key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get Data Type for MySQL Dialect
 | 
			
		||||
func (s *mysql) DataTypeOf(field *StructField) string {
 | 
			
		||||
func (s *mysqlDialect) DataTypeOf(field *StructField) string {
 | 
			
		||||
	var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
 | 
			
		||||
 | 
			
		||||
	// MySQL allows only one auto increment column per table, and it must
 | 
			
		||||
@ -122,12 +128,12 @@ func (s *mysql) DataTypeOf(field *StructField) string {
 | 
			
		||||
	return fmt.Sprintf("%v %v", sqlType, additionalType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
func (s mysqlDialect) RemoveIndex(tableName string, indexName string) error {
 | 
			
		||||
	_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
func (s mysqlDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	if limit != nil {
 | 
			
		||||
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
 | 
			
		||||
			sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
 | 
			
		||||
@ -142,22 +148,22 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
func (s mysqlDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
 | 
			
		||||
	var count int
 | 
			
		||||
	s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
 | 
			
		||||
	return count > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) CurrentDatabase() (name string) {
 | 
			
		||||
func (s mysqlDialect) CurrentDatabase() (name string) {
 | 
			
		||||
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mysql) SelectFromDummyTable() string {
 | 
			
		||||
func (mysqlDialect) SelectFromDummyTable() string {
 | 
			
		||||
	return "FROM DUAL"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
 | 
			
		||||
func (s mysqlDialect) BuildForeignKeyName(tableName, field, dest string) string {
 | 
			
		||||
	keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
 | 
			
		||||
	if utf8.RuneCountInString(keyName) <= 64 {
 | 
			
		||||
		return keyName
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -19,6 +20,9 @@ func init() {
 | 
			
		||||
func (postgres) GetName() string {
 | 
			
		||||
	return "postgres"
 | 
			
		||||
}
 | 
			
		||||
func (postgres) IsDisconnectError(err error) bool {
 | 
			
		||||
	return err == driver.ErrBadConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (postgres) BindVar(i int) string {
 | 
			
		||||
	return fmt.Sprintf("$%v", i)
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,10 @@ type sqlite3 struct {
 | 
			
		||||
func init() {
 | 
			
		||||
	RegisterDialect("sqlite3", &sqlite3{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sqlite3) IsDisconnectError(err error) bool {
 | 
			
		||||
	//sqlite is a file db so i think no reconnect needed or it is possible to treat file errors as disconnect
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
func (sqlite3) GetName() string {
 | 
			
		||||
	return "sqlite3"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,11 +4,11 @@ import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
 | 
			
		||||
	_ "github.com/lib/pq"
 | 
			
		||||
	"github.com/lib/pq/hstore"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	_ "github.com/lib/pq"
 | 
			
		||||
	"github.com/lib/pq/hstore"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Hstore map[string]*string
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								main.go
									
									
									
									
									
								
							@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -22,7 +23,11 @@ type DB struct {
 | 
			
		||||
	logger            logger
 | 
			
		||||
	search            *search
 | 
			
		||||
	values            map[string]interface{}
 | 
			
		||||
 | 
			
		||||
	//Mutex to block all operations during reconnection
 | 
			
		||||
	reconnectGuard *sync.WaitGroup
 | 
			
		||||
	//connection string passed to Open saved for reconnect
 | 
			
		||||
	dialectName string
 | 
			
		||||
	dialectArgs []interface{}
 | 
			
		||||
	// global db
 | 
			
		||||
	parent        *DB
 | 
			
		||||
	callbacks     *Callback
 | 
			
		||||
@ -64,11 +69,14 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db = &DB{
 | 
			
		||||
		db:        dbSQL,
 | 
			
		||||
		logger:    defaultLogger,
 | 
			
		||||
		values:    map[string]interface{}{},
 | 
			
		||||
		callbacks: DefaultCallback,
 | 
			
		||||
		dialect:   newDialect(dialect, dbSQL),
 | 
			
		||||
		db:             dbSQL,
 | 
			
		||||
		logger:         defaultLogger,
 | 
			
		||||
		values:         map[string]interface{}{},
 | 
			
		||||
		callbacks:      DefaultCallback,
 | 
			
		||||
		dialect:        newDialect(dialect, dbSQL),
 | 
			
		||||
		reconnectGuard: &sync.WaitGroup{},
 | 
			
		||||
		dialectArgs:    args,
 | 
			
		||||
		dialectName:    dialect,
 | 
			
		||||
	}
 | 
			
		||||
	db.parent = db
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@ -712,11 +720,15 @@ func (s *DB) GetErrors() []error {
 | 
			
		||||
 | 
			
		||||
func (s *DB) clone() *DB {
 | 
			
		||||
	db := &DB{
 | 
			
		||||
		db:                s.db,
 | 
			
		||||
		parent:            s.parent,
 | 
			
		||||
		logger:            s.logger,
 | 
			
		||||
		logMode:           s.logMode,
 | 
			
		||||
		values:            map[string]interface{}{},
 | 
			
		||||
		db:      s.db,
 | 
			
		||||
		parent:  s.parent,
 | 
			
		||||
		logger:  s.logger,
 | 
			
		||||
		logMode: s.logMode,
 | 
			
		||||
		values:  map[string]interface{}{},
 | 
			
		||||
		//share reconnection info with new copy
 | 
			
		||||
		reconnectGuard:    s.reconnectGuard,
 | 
			
		||||
		dialectName:       s.dialectName,
 | 
			
		||||
		dialectArgs:       s.dialectArgs,
 | 
			
		||||
		Value:             s.Value,
 | 
			
		||||
		Error:             s.Error,
 | 
			
		||||
		blockGlobalUpdate: s.blockGlobalUpdate,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user