feat: ajust PreparedStmtDB unlock location and BuildCondition if logic
This commit is contained in:
		
							parent
							
								
									04f049c1da
								
							
						
					
					
						commit
						f266765612
					
				@ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) Close() {
 | 
			
		||||
	db.Mux.Lock()
 | 
			
		||||
	defer db.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	for _, query := range db.PreparedSQL {
 | 
			
		||||
		if stmt, ok := db.Stmts[query]; ok {
 | 
			
		||||
			delete(db.Stmts, query)
 | 
			
		||||
			go stmt.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.Mux.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
 | 
			
		||||
@ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
 | 
			
		||||
	db.Mux.RUnlock()
 | 
			
		||||
 | 
			
		||||
	db.Mux.Lock()
 | 
			
		||||
	defer db.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	// double check
 | 
			
		||||
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
			
		||||
		db.Mux.Unlock()
 | 
			
		||||
		return stmt, nil
 | 
			
		||||
	} else if ok {
 | 
			
		||||
		go stmt.Close()
 | 
			
		||||
@ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
 | 
			
		||||
		db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
 | 
			
		||||
		db.PreparedSQL = append(db.PreparedSQL, query)
 | 
			
		||||
	}
 | 
			
		||||
	defer db.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	return db.Stmts[query], err
 | 
			
		||||
}
 | 
			
		||||
@ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
		result, err = stmt.ExecContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			db.Mux.Lock()
 | 
			
		||||
			defer db.Mux.Unlock()
 | 
			
		||||
			go stmt.Close()
 | 
			
		||||
			delete(db.Stmts, query)
 | 
			
		||||
			db.Mux.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return result, err
 | 
			
		||||
@ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
 | 
			
		||||
		rows, err = stmt.QueryContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			db.Mux.Lock()
 | 
			
		||||
			defer db.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
			go stmt.Close()
 | 
			
		||||
			delete(db.Stmts, query)
 | 
			
		||||
			db.Mux.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rows, err
 | 
			
		||||
@ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
			
		||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
			go stmt.Close()
 | 
			
		||||
			delete(tx.PreparedStmtDB.Stmts, query)
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return result, err
 | 
			
		||||
@ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
 | 
			
		||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
			
		||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
			go stmt.Close()
 | 
			
		||||
			delete(tx.PreparedStmtDB.Stmts, query)
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rows, err
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								statement.go
									
									
									
									
									
								
							@ -267,13 +267,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
		if _, err := strconv.Atoi(s); err != nil {
 | 
			
		||||
			if s == "" && len(args) == 0 {
 | 
			
		||||
				return nil
 | 
			
		||||
			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
 | 
			
		||||
				// looks like a where condition
 | 
			
		||||
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
 | 
			
		||||
			} else if len(args) > 0 && strings.Contains(s, "@") {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(args) > 0 && strings.Contains(s, "@") {
 | 
			
		||||
				// looks like a named query
 | 
			
		||||
				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
 | 
			
		||||
			} else if len(args) == 1 {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(args) == 1 {
 | 
			
		||||
				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user