feat: ajust PreparedStmtDB unlock location and BuildCondition if logic (#4681)
This commit is contained in:
		
							parent
							
								
									c13f3011f9
								
							
						
					
					
						commit
						e3fc49a694
					
				| @ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { | ||||
| 
 | ||||
| func (db *PreparedStmtDB) Close() { | ||||
| 	db.Mux.Lock() | ||||
| 	defer db.Mux.Unlock() | ||||
| 
 | ||||
| 	for _, query := range db.PreparedSQL { | ||||
| 		if stmt, ok := db.Stmts[query]; ok { | ||||
| 			delete(db.Stmts, query) | ||||
| 			go stmt.Close() | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	db.Mux.Unlock() | ||||
| } | ||||
| 
 | ||||
| func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { | ||||
| @ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | ||||
| 	db.Mux.RUnlock() | ||||
| 
 | ||||
| 	db.Mux.Lock() | ||||
| 	defer db.Mux.Unlock() | ||||
| 
 | ||||
| 	// double check
 | ||||
| 	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { | ||||
| 		db.Mux.Unlock() | ||||
| 		return stmt, nil | ||||
| 	} else if ok { | ||||
| 		go stmt.Close() | ||||
| @ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact | ||||
| 		db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} | ||||
| 		db.PreparedSQL = append(db.PreparedSQL, query) | ||||
| 	} | ||||
| 	defer db.Mux.Unlock() | ||||
| 
 | ||||
| 	return db.Stmts[query], err | ||||
| } | ||||
| @ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| 		result, err = stmt.ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 			go stmt.Close() | ||||
| 			delete(db.Stmts, query) | ||||
| 			db.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| 		rows, err = stmt.QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			delete(db.Stmts, query) | ||||
| 			db.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
| @ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| 		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			delete(tx.PreparedStmtDB.Stmts, query) | ||||
| 			tx.PreparedStmtDB.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return result, err | ||||
| @ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| 		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| 			go stmt.Close() | ||||
| 			delete(tx.PreparedStmtDB.Stmts, query) | ||||
| 			tx.PreparedStmtDB.Mux.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| 	return rows, err | ||||
|  | ||||
							
								
								
									
										12
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								statement.go
									
									
									
									
									
								
							| @ -271,13 +271,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 		if _, err := strconv.Atoi(s); err != nil { | ||||
| 			if s == "" && len(args) == 0 { | ||||
| 				return nil | ||||
| 			} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { | ||||
| 			} | ||||
| 
 | ||||
| 			if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { | ||||
| 				// looks like a where condition
 | ||||
| 				return []clause.Expression{clause.Expr{SQL: s, Vars: args}} | ||||
| 			} else if len(args) > 0 && strings.Contains(s, "@") { | ||||
| 			} | ||||
| 
 | ||||
| 			if len(args) > 0 && strings.Contains(s, "@") { | ||||
| 				// looks like a named query
 | ||||
| 				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} | ||||
| 			} else if len(args) == 1 { | ||||
| 			} | ||||
| 
 | ||||
| 			if len(args) == 1 { | ||||
| 				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 heige
						heige