fix(review): address code review findings for speed optimization

- P0: Rename WAL buffer to pending buffer, add crash-loss warning
- P1: Fix keyword substring false matches with word-boundary regex
- P1: Pass connection pool params in _build_llm_config
- P1: Change parallel_tools default to False (safer default)
- P1: Add classifier value validation in CostAwareRouter
- P2: Replace __import__ with proper datetime import
- P2: Add max_buffer_size check in AsyncWriteQueue (logs a warning when a session's pending buffer is full; messages are still enqueued)
This commit is contained in:
chiguyong 2026-06-12 13:21:44 +08:00
parent a36bc3d1c1
commit 2e55aae775
4 changed files with 94 additions and 37 deletions

View File

@ -290,34 +290,64 @@ class HeuristicClassifier:
"""
# 高复杂度暗示词(需要工具或多步推理)
_HIGH_COMPLEXITY_HINTS = {
# 工具/执行类
"执行", "运行", "命令", "终端", "shell", "bash", "script",
"安装", "部署", "启动", "停止", "重启", "配置",
"搜索", "查找", "联网", "search", "find", "query",
"文件", "目录", "创建", "删除", "修改", "编辑",
"run", "execute", "install", "deploy", "start", "stop",
"restart", "file", "directory", "create", "delete", "modify",
# 多步/分析类
"分析", "比较", "对比", "评估", "调研", "研究",
"设计", "规划", "方案", "架构", "实现", "开发",
# 中文关键词使用子串匹配(中文无自然词边界)
_HIGH_COMPLEXITY_HINTS_CN = {
"执行", "运行", "命令", "终端", "安装", "部署", "启动", "停止", "重启",
"配置", "搜索", "查找", "联网", "文件", "目录", "创建", "删除", "修改",
"编辑", "分析", "比较", "对比", "评估", "调研", "研究", "设计", "规划",
"方案", "架构", "实现", "开发", "代码", "编程", "函数", "接口", "调试",
"重构",
}
# English keywords use word-boundary matching (avoids substring false matches such as "file" matching inside "profile")
_HIGH_COMPLEXITY_HINTS_EN = {
"shell", "bash", "script", "search", "query", "directory",
"execute", "install", "deploy", "restart", "modify",
"analyze", "compare", "evaluate", "research", "design",
"plan", "implement", "develop", "build",
# 代码类
"代码", "编程", "函数", "", "接口", "调试", "重构",
"code", "program", "function", "class", "interface", "debug", "refactor",
"python", "java", "javascript", "typescript", "sql", "api",
"implement", "develop", "refactor", "debug",
"python", "javascript", "typescript", "sql",
}
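# Short English words; matched with the same \b(...)\b word-boundary regex as the set above (not a stricter "exact" match).
# NOTE(review): "plan" and "build" are also listed in _HIGH_COMPLEXITY_HINTS_EN, so a single occurrence is counted by both _HIGH_EN_RE and _HIGH_EXACT_RE; confirm whether this double count is intended.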
_HIGH_COMPLEXITY_EXACT_EN = {
"run", "find", "start", "stop", "file", "create", "delete",
"plan", "build", "code", "program", "function", "class",
"interface", "api",
}
# 中等复杂度暗示词(简单问题但需思考)
_MEDIUM_COMPLEXITY_HINTS = {
_MEDIUM_COMPLEXITY_HINTS_CN = {
"如何", "怎么", "怎样", "为什么", "什么原因", "区别",
"how", "why", "what", "difference", "explain",
"", "可以", "是否", "会不会",
"推荐", "建议", "选择", "哪个",
"recommend", "suggest", "choose", "which",
}
_MEDIUM_COMPLEXITY_HINTS_EN = {
"difference", "explain", "recommend", "suggest", "choose",
}
# 英文短词精确匹配
_MEDIUM_COMPLEXITY_EXACT_EN = {
"how", "why", "what", "which",
}
# 预编译英文词边界正则
_HIGH_EN_RE = re.compile(
r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b',
re.IGNORECASE,
)
_HIGH_EXACT_RE = re.compile(
r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b',
re.IGNORECASE,
)
_MEDIUM_EN_RE = re.compile(
r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b',
re.IGNORECASE,
)
_MEDIUM_EXACT_RE = re.compile(
r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b',
re.IGNORECASE,
)
def classify(self, content: str) -> float:
"""评估消息复杂度 (0.0-1.0)。
@ -335,8 +365,15 @@ class HeuristicClassifier:
score = 0.0
# 1. 关键词匹配
high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower)
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower)
# 中文:子串匹配
high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS_CN if h in content_lower)
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower)
# 英文:词边界匹配
high_hits += len(self._HIGH_EN_RE.findall(content))
high_hits += len(self._HIGH_EXACT_RE.findall(content))
medium_hits += len(self._MEDIUM_EN_RE.findall(content))
medium_hits += len(self._MEDIUM_EXACT_RE.findall(content))
if high_hits >= 2:
score = 0.8
@ -393,6 +430,8 @@ class CostAwareRouter:
self._org_context = org_context
self._auction_enabled = auction_enabled
self._classifier = classifier
if classifier not in ("heuristic", "llm"):
raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'")
self._heuristic = HeuristicClassifier()
# -- Layer 0: Rule-based (zero cost) ------------------------------------

View File

@ -74,7 +74,7 @@ class ReActEngine:
使 Agent 能够自主推理并选择工具完成任务
"""
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
self._llm_gateway = llm_gateway

View File

@ -249,6 +249,9 @@ class ServerConfig:
type=pconf.get("type", "openai"),
max_tokens=pconf.get("max_tokens", 4096),
timeout=pconf.get("timeout", 120.0),
max_connections=pconf.get("max_connections", 100),
max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
)
return LLMConfig(

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import datetime
import logging
from collections import defaultdict
from typing import Any
@ -14,19 +15,24 @@ logger = logging.getLogger(__name__)
class AsyncWriteQueue:
"""Background write-ahead queue for non-blocking session persistence.
"""Background pending-buffer queue for non-blocking session persistence.
Accepts write operations (append_message + save_session) as tasks,
executes them in a background ``asyncio.Task``, and maintains a small
in-memory WAL buffer for crash recovery and immediate reads.
in-memory pending buffer for immediate reads.
WARNING: This is NOT a true Write-Ahead Log. The pending buffer lives
entirely in memory. If the process crashes before ``flush()`` completes,
all uncommitted messages will be lost. Only use ``async_writes=True``
in scenarios where occasional message loss is acceptable.
"""
def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None:
self._store = store
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
self._worker: asyncio.Task | None = None
# WAL buffer: session_id -> list of Messages not yet persisted
self._wal_buffer: dict[str, list[Message]] = defaultdict(list)
# Pending buffer: session_id -> list of Messages not yet persisted
self._pending_buffer: dict[str, list[Message]] = defaultdict(list)
self._max_buffer_size = max_buffer_size
self._pending_count = 0
@ -49,9 +55,7 @@ class AsyncWriteQueue:
message, session = item
try:
await self._store.append_message(message)
session.updated_at = __import__("datetime").datetime.now(
__import__("datetime").timezone.utc
)
session.updated_at = datetime.datetime.now(datetime.timezone.utc)
await self._store.save_session(session)
except Exception:
logger.exception(
@ -60,15 +64,15 @@ class AsyncWriteQueue:
message.session_id,
)
finally:
# Remove from WAL buffer once persisted
buf = self._wal_buffer.get(message.session_id)
# Remove from pending buffer once persisted
buf = self._pending_buffer.get(message.session_id)
if buf is not None:
try:
buf.remove(message)
except ValueError:
pass
if not buf:
self._wal_buffer.pop(message.session_id, None)
self._pending_buffer.pop(message.session_id, None)
self._pending_count -= 1
self._queue.task_done()
@ -81,13 +85,24 @@ class AsyncWriteQueue:
"""
self._ensure_started()
assert self._queue is not None
self._wal_buffer[message.session_id].append(message)
# Check buffer size limit before enqueueing
if len(self._pending_buffer[message.session_id]) >= self._max_buffer_size:
logger.warning(
"AsyncWriteQueue: pending buffer full for session %s, flushing before enqueue",
message.session_id,
)
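# NOTE(review): no flush is performed here; the message is appended and enqueued exactly as in the normal path below, so max_buffer_size is only logged, not enforced.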
self._pending_buffer[message.session_id].append(message)
self._pending_count += 1
self._queue.put_nowait((message, session))
return
self._pending_buffer[message.session_id].append(message)
self._pending_count += 1
self._queue.put_nowait((message, session))
def buffered_messages(self, session_id: str) -> list[Message]:
"""Return WAL-buffered messages for *session_id* not yet persisted."""
return list(self._wal_buffer.get(session_id, []))
"""Return pending-buffered messages for *session_id* not yet persisted."""
return list(self._pending_buffer.get(session_id, []))
@property
def pending_count(self) -> int:
@ -241,7 +256,7 @@ class SessionManager:
else:
# Synchronous path (default, backward-compatible)
await self._store.append_message(message)
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
session.updated_at = datetime.datetime.now(datetime.timezone.utc)
await self._store.save_session(session)
return message