diff --git a/prepare_stmt.go b/prepare_stmt.go index 26599b81..fef7e234 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -24,21 +24,26 @@ type PreparedStmtDB struct { ConnPool } +func newPrepareStmtCache(prepareStmtLruConfig *PrepareStmtLruConfig) *StmtStore { + var stmts StmtStore + if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open { + if prepareStmtLruConfig.Size <= 0 { + panic("LRU prepareStmtLruConfig.Size must > 0") + } + lru := &LruStmtStore{} + lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL) + stmts = lru + } else { + defaultStmtStore := &DefaultStmtStore{} + stmts = defaultStmtStore.init() + } + return &stmts +} func NewPreparedStmtDB(connPool ConnPool, prepareStmtLruConfig *PrepareStmtLruConfig) *PreparedStmtDB { return &PreparedStmtDB{ ConnPool: connPool, - Stmts: func() StmtStore { - var stmts StmtStore - if prepareStmtLruConfig != nil && prepareStmtLruConfig.Open { - lru := &LruStmtStore{} - lru.NewLru(prepareStmtLruConfig.Size, prepareStmtLruConfig.TTL) - stmts = lru - } else { - stmts = &DefaultStmtStore{} - } - return stmts - }(), - Mux: &sync.RWMutex{}, + Stmts: *newPrepareStmtCache(prepareStmtLruConfig), + Mux: &sync.RWMutex{}, } } @@ -57,6 +62,9 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) Close() { db.Mux.Lock() defer db.Mux.Unlock() + if db.Stmts == nil { + return + } for _, stmt := range db.Stmts.AllMap() { go func(s *Stmt) { @@ -74,7 +82,9 @@ func (db *PreparedStmtDB) Close() { func (sdb *PreparedStmtDB) Reset() { sdb.Mux.Lock() defer sdb.Mux.Unlock() - + if sdb.Stmts == nil { + return + } for _, stmt := range sdb.Stmts.AllMap() { go func(s *Stmt) { // make sure the stmt must finish preparation first @@ -84,34 +94,40 @@ func (sdb *PreparedStmtDB) Reset() { } }(stmt) } - sdb.Stmts = &DefaultStmtStore{} + defaultStmt := &DefaultStmtStore{} + defaultStmt.init() + sdb.Stmts = defaultStmt } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() - if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { - db.Mux.RUnlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr - } + if db.Stmts != nil { + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { + db.Mux.RUnlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } - return *stmt, nil + return *stmt, nil + } } db.Mux.RUnlock() db.Mux.Lock() - // double check - if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { - db.Mux.Unlock() - // wait for other goroutines prepared - <-stmt.prepared - if stmt.prepareErr != nil { - return Stmt{}, stmt.prepareErr - } + if db.Stmts != nil { + // double check + if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } - return *stmt, nil + return *stmt, nil + } } // check db.Stmts first to avoid Segmentation Fault(setting value to nil map) // which cause by calling Close and executing SQL concurrently @@ -295,11 +311,15 @@ type StmtStore interface { AllMap() map[string]*Stmt } -// 默认的 map 实现 type DefaultStmtStore struct { defaultStmt map[string]*Stmt } +func (s *DefaultStmtStore) init() *DefaultStmtStore { + s.defaultStmt = make(map[string]*Stmt) + return s +} + func (s *DefaultStmtStore) AllMap() map[string]*Stmt { return s.defaultStmt } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 20a4f730..566378e0 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -92,6 +92,48 @@ func TestPreparedStmtFromTransaction(t *testing.T) { tx2.Commit() } +func TestPreparedStmtLruFromTransaction(t *testing.T) { + db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtLruConfig: &gorm.PrepareStmtLruConfig{10, 20 * time.Second, true}}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() + time.Sleep(time.Second * 40) + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts.AllMap()), 0) +} + func TestPreparedStmtDeadlock(t *testing.T) { tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) @@ -117,8 +159,8 @@ func TestPreparedStmtDeadlock(t *testing.T) { conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts), 2) - for _, stmt := range conn.Stmts { + AssertEqual(t, len(conn.Stmts.AllMap()), 2) + for _, stmt := range conn.Stmts.AllMap() { if stmt == nil { t.Fatalf("stmt cannot bee nil") } @@ -155,7 +197,7 @@ func TestPreparedStmtReset(t *testing.T) { } pdb.Mux.Lock() - if len(pdb.Stmts) == 0 { + if len(pdb.Stmts.AllMap()) == 0 { pdb.Mux.Unlock() t.Fatalf("prepared stmt can not be empty") } @@ -164,7 +206,7 @@ func TestPreparedStmtReset(t *testing.T) { pdb.Reset() pdb.Mux.Lock() defer pdb.Mux.Unlock() - if len(pdb.Stmts) != 0 { + if len(pdb.Stmts.AllMap()) != 0 { t.Fatalf("prepared stmt should be empty") } }