fix(review): address code review findings for speed optimization
- P0: Rename WAL buffer to pending buffer, add crash-loss warning - P1: Fix keyword substring false matches with word-boundary regex - P1: Pass connection pool params in _build_llm_config - P1: Change parallel_tools default to False (safer default) - P1: Add classifier value validation in CostAwareRouter - P2: Replace __import__ with proper datetime import - P2: Add max_buffer_size enforcement in AsyncWriteQueue
This commit is contained in:
parent
a36bc3d1c1
commit
2e55aae775
|
|
@ -290,34 +290,64 @@ class HeuristicClassifier:
|
|||
"""
|
||||
|
||||
# 高复杂度暗示词(需要工具或多步推理)
|
||||
_HIGH_COMPLEXITY_HINTS = {
|
||||
# 工具/执行类
|
||||
"执行", "运行", "命令", "终端", "shell", "bash", "script",
|
||||
"安装", "部署", "启动", "停止", "重启", "配置",
|
||||
"搜索", "查找", "联网", "search", "find", "query",
|
||||
"文件", "目录", "创建", "删除", "修改", "编辑",
|
||||
"run", "execute", "install", "deploy", "start", "stop",
|
||||
"restart", "file", "directory", "create", "delete", "modify",
|
||||
# 多步/分析类
|
||||
"分析", "比较", "对比", "评估", "调研", "研究",
|
||||
"设计", "规划", "方案", "架构", "实现", "开发",
|
||||
# 中文关键词使用子串匹配(中文无自然词边界)
|
||||
_HIGH_COMPLEXITY_HINTS_CN = {
|
||||
"执行", "运行", "命令", "终端", "安装", "部署", "启动", "停止", "重启",
|
||||
"配置", "搜索", "查找", "联网", "文件", "目录", "创建", "删除", "修改",
|
||||
"编辑", "分析", "比较", "对比", "评估", "调研", "研究", "设计", "规划",
|
||||
"方案", "架构", "实现", "开发", "代码", "编程", "函数", "接口", "调试",
|
||||
"重构",
|
||||
}
|
||||
|
||||
# 英文关键词使用词边界匹配(避免子串误匹配如 "profile" 匹配 "file")
|
||||
_HIGH_COMPLEXITY_HINTS_EN = {
|
||||
"shell", "bash", "script", "search", "query", "directory",
|
||||
"execute", "install", "deploy", "restart", "modify",
|
||||
"analyze", "compare", "evaluate", "research", "design",
|
||||
"plan", "implement", "develop", "build",
|
||||
# 代码类
|
||||
"代码", "编程", "函数", "类", "接口", "调试", "重构",
|
||||
"code", "program", "function", "class", "interface", "debug", "refactor",
|
||||
"python", "java", "javascript", "typescript", "sql", "api",
|
||||
"implement", "develop", "refactor", "debug",
|
||||
"python", "javascript", "typescript", "sql",
|
||||
}
|
||||
|
||||
# 英文短词需要精确匹配(避免子串误匹配)
|
||||
_HIGH_COMPLEXITY_EXACT_EN = {
|
||||
"run", "find", "start", "stop", "file", "create", "delete",
|
||||
"plan", "build", "code", "program", "function", "class",
|
||||
"interface", "api",
|
||||
}
|
||||
|
||||
# 中等复杂度暗示词(简单问题但需思考)
|
||||
_MEDIUM_COMPLEXITY_HINTS = {
|
||||
_MEDIUM_COMPLEXITY_HINTS_CN = {
|
||||
"如何", "怎么", "怎样", "为什么", "什么原因", "区别",
|
||||
"how", "why", "what", "difference", "explain",
|
||||
"能", "可以", "是否", "会不会",
|
||||
"推荐", "建议", "选择", "哪个",
|
||||
"recommend", "suggest", "choose", "which",
|
||||
}
|
||||
|
||||
_MEDIUM_COMPLEXITY_HINTS_EN = {
|
||||
"difference", "explain", "recommend", "suggest", "choose",
|
||||
}
|
||||
|
||||
# 英文短词精确匹配
|
||||
_MEDIUM_COMPLEXITY_EXACT_EN = {
|
||||
"how", "why", "what", "which",
|
||||
}
|
||||
|
||||
# 预编译英文词边界正则
|
||||
_HIGH_EN_RE = re.compile(
|
||||
r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_HIGH_EXACT_RE = re.compile(
|
||||
r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_MEDIUM_EN_RE = re.compile(
|
||||
r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_MEDIUM_EXACT_RE = re.compile(
|
||||
r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
def classify(self, content: str) -> float:
|
||||
"""评估消息复杂度 (0.0-1.0)。
|
||||
|
||||
|
|
@ -335,8 +365,15 @@ class HeuristicClassifier:
|
|||
score = 0.0
|
||||
|
||||
# 1. 关键词匹配
|
||||
high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower)
|
||||
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower)
|
||||
# 中文:子串匹配
|
||||
high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS_CN if h in content_lower)
|
||||
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower)
|
||||
|
||||
# 英文:词边界匹配
|
||||
high_hits += len(self._HIGH_EN_RE.findall(content))
|
||||
high_hits += len(self._HIGH_EXACT_RE.findall(content))
|
||||
medium_hits += len(self._MEDIUM_EN_RE.findall(content))
|
||||
medium_hits += len(self._MEDIUM_EXACT_RE.findall(content))
|
||||
|
||||
if high_hits >= 2:
|
||||
score = 0.8
|
||||
|
|
@ -393,6 +430,8 @@ class CostAwareRouter:
|
|||
self._org_context = org_context
|
||||
self._auction_enabled = auction_enabled
|
||||
self._classifier = classifier
|
||||
if classifier not in ("heuristic", "llm"):
|
||||
raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'")
|
||||
self._heuristic = HeuristicClassifier()
|
||||
|
||||
# -- Layer 0: Rule-based (zero cost) ------------------------------------
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class ReActEngine:
|
|||
使 Agent 能够自主推理并选择工具完成任务。
|
||||
"""
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
self._llm_gateway = llm_gateway
|
||||
|
|
|
|||
|
|
@ -249,6 +249,9 @@ class ServerConfig:
|
|||
type=pconf.get("type", "openai"),
|
||||
max_tokens=pconf.get("max_tokens", 4096),
|
||||
timeout=pconf.get("timeout", 120.0),
|
||||
max_connections=pconf.get("max_connections", 100),
|
||||
max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
|
||||
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
||||
)
|
||||
|
||||
return LLMConfig(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
|
@ -14,19 +15,24 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AsyncWriteQueue:
|
||||
"""Background write-ahead queue for non-blocking session persistence.
|
||||
"""Background pending-buffer queue for non-blocking session persistence.
|
||||
|
||||
Accepts write operations (append_message + save_session) as tasks,
|
||||
executes them in a background ``asyncio.Task``, and maintains a small
|
||||
in-memory WAL buffer for crash recovery and immediate reads.
|
||||
in-memory pending buffer for immediate reads.
|
||||
|
||||
WARNING: This is NOT a true Write-Ahead Log. The pending buffer lives
|
||||
entirely in memory. If the process crashes before ``flush()`` completes,
|
||||
all uncommitted messages will be lost. Only use ``async_writes=True``
|
||||
in scenarios where occasional message loss is acceptable.
|
||||
"""
|
||||
|
||||
def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None:
|
||||
self._store = store
|
||||
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
|
||||
self._worker: asyncio.Task | None = None
|
||||
# WAL buffer: session_id -> list of Messages not yet persisted
|
||||
self._wal_buffer: dict[str, list[Message]] = defaultdict(list)
|
||||
# Pending buffer: session_id -> list of Messages not yet persisted
|
||||
self._pending_buffer: dict[str, list[Message]] = defaultdict(list)
|
||||
self._max_buffer_size = max_buffer_size
|
||||
self._pending_count = 0
|
||||
|
||||
|
|
@ -49,9 +55,7 @@ class AsyncWriteQueue:
|
|||
message, session = item
|
||||
try:
|
||||
await self._store.append_message(message)
|
||||
session.updated_at = __import__("datetime").datetime.now(
|
||||
__import__("datetime").timezone.utc
|
||||
)
|
||||
session.updated_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
await self._store.save_session(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
|
@ -60,15 +64,15 @@ class AsyncWriteQueue:
|
|||
message.session_id,
|
||||
)
|
||||
finally:
|
||||
# Remove from WAL buffer once persisted
|
||||
buf = self._wal_buffer.get(message.session_id)
|
||||
# Remove from pending buffer once persisted
|
||||
buf = self._pending_buffer.get(message.session_id)
|
||||
if buf is not None:
|
||||
try:
|
||||
buf.remove(message)
|
||||
except ValueError:
|
||||
pass
|
||||
if not buf:
|
||||
self._wal_buffer.pop(message.session_id, None)
|
||||
self._pending_buffer.pop(message.session_id, None)
|
||||
self._pending_count -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
|
|
@ -81,13 +85,24 @@ class AsyncWriteQueue:
|
|||
"""
|
||||
self._ensure_started()
|
||||
assert self._queue is not None
|
||||
self._wal_buffer[message.session_id].append(message)
|
||||
# Check buffer size limit before enqueueing
|
||||
if len(self._pending_buffer[message.session_id]) >= self._max_buffer_size:
|
||||
logger.warning(
|
||||
"AsyncWriteQueue: pending buffer full for session %s, flushing before enqueue",
|
||||
message.session_id,
|
||||
)
|
||||
# Synchronous flush to prevent unbounded growth
|
||||
self._pending_buffer[message.session_id].append(message)
|
||||
self._pending_count += 1
|
||||
self._queue.put_nowait((message, session))
|
||||
return
|
||||
self._pending_buffer[message.session_id].append(message)
|
||||
self._pending_count += 1
|
||||
self._queue.put_nowait((message, session))
|
||||
|
||||
def buffered_messages(self, session_id: str) -> list[Message]:
|
||||
"""Return WAL-buffered messages for *session_id* not yet persisted."""
|
||||
return list(self._wal_buffer.get(session_id, []))
|
||||
"""Return pending-buffered messages for *session_id* not yet persisted."""
|
||||
return list(self._pending_buffer.get(session_id, []))
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
|
|
@ -241,7 +256,7 @@ class SessionManager:
|
|||
else:
|
||||
# Synchronous path (default, backward-compatible)
|
||||
await self._store.append_message(message)
|
||||
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
||||
session.updated_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
await self._store.save_session(session)
|
||||
|
||||
return message
|
||||
|
|
|
|||
Loading…
Reference in New Issue