diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index c95f57a..c8e6e09 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -290,34 +290,64 @@ class HeuristicClassifier: """ # 高复杂度暗示词(需要工具或多步推理) - _HIGH_COMPLEXITY_HINTS = { - # 工具/执行类 - "执行", "运行", "命令", "终端", "shell", "bash", "script", - "安装", "部署", "启动", "停止", "重启", "配置", - "搜索", "查找", "联网", "search", "find", "query", - "文件", "目录", "创建", "删除", "修改", "编辑", - "run", "execute", "install", "deploy", "start", "stop", - "restart", "file", "directory", "create", "delete", "modify", - # 多步/分析类 - "分析", "比较", "对比", "评估", "调研", "研究", - "设计", "规划", "方案", "架构", "实现", "开发", + # 中文关键词使用子串匹配(中文无自然词边界) + _HIGH_COMPLEXITY_HINTS_CN = { + "执行", "运行", "命令", "终端", "安装", "部署", "启动", "停止", "重启", + "配置", "搜索", "查找", "联网", "文件", "目录", "创建", "删除", "修改", + "编辑", "分析", "比较", "对比", "评估", "调研", "研究", "设计", "规划", + "方案", "架构", "实现", "开发", "代码", "编程", "函数", "接口", "调试", + "重构", + } + + # 英文关键词使用词边界匹配(避免子串误匹配如 "profile" 匹配 "file") + _HIGH_COMPLEXITY_HINTS_EN = { + "shell", "bash", "script", "search", "query", "directory", + "execute", "install", "deploy", "restart", "modify", "analyze", "compare", "evaluate", "research", "design", - "plan", "implement", "develop", "build", - # 代码类 - "代码", "编程", "函数", "类", "接口", "调试", "重构", - "code", "program", "function", "class", "interface", "debug", "refactor", - "python", "java", "javascript", "typescript", "sql", "api", + "implement", "develop", "refactor", "debug", + "python", "javascript", "typescript", "sql", + } + + # 英文短词需要精确匹配(避免子串误匹配) + _HIGH_COMPLEXITY_EXACT_EN = { + "run", "find", "start", "stop", "file", "create", "delete", + "plan", "build", "code", "program", "function", "class", + "interface", "api", } # 中等复杂度暗示词(简单问题但需思考) - _MEDIUM_COMPLEXITY_HINTS = { + _MEDIUM_COMPLEXITY_HINTS_CN = { "如何", "怎么", "怎样", "为什么", "什么原因", "区别", - "how", "why", "what", "difference", "explain", - "能", "可以", "是否", "会不会", "推荐", "建议", "选择", "哪个", - "recommend", "suggest", "choose", "which", } + _MEDIUM_COMPLEXITY_HINTS_EN = { + "difference", "explain", "recommend", "suggest", "choose", + } + + # 英文短词精确匹配 + _MEDIUM_COMPLEXITY_EXACT_EN = { + "how", "why", "what", "which", + } + + # 预编译英文词边界正则 + _HIGH_EN_RE = re.compile( + r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b', + re.IGNORECASE, + ) + _HIGH_EXACT_RE = re.compile( + r'\b(' + '|'.join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b', + re.IGNORECASE, + ) + _MEDIUM_EN_RE = re.compile( + r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_HINTS_EN, key=len, reverse=True)) + r')\b', + re.IGNORECASE, + ) + _MEDIUM_EXACT_RE = re.compile( + r'\b(' + '|'.join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_EXACT_EN, key=len, reverse=True)) + r')\b', + re.IGNORECASE, + ) + def classify(self, content: str) -> float: """评估消息复杂度 (0.0-1.0)。 @@ -335,8 +365,15 @@ class HeuristicClassifier: score = 0.0 # 1. 关键词匹配 - high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower) - medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower) + # 中文:子串匹配 + high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS_CN if h in content_lower) + medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower) + + # 英文:词边界匹配 + high_hits += len(self._HIGH_EN_RE.findall(content)) + high_hits += len(self._HIGH_EXACT_RE.findall(content)) + medium_hits += len(self._MEDIUM_EN_RE.findall(content)) + medium_hits += len(self._MEDIUM_EXACT_RE.findall(content)) if high_hits >= 2: score = 0.8 @@ -393,6 +430,8 @@ class CostAwareRouter: self._org_context = org_context self._auction_enabled = auction_enabled self._classifier = classifier + if classifier not in ("heuristic", "llm"): + raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'") self._heuristic = HeuristicClassifier() # -- Layer 0: Rule-based (zero cost) ------------------------------------ diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index b1ae0f8..1f5246d 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -74,7 +74,7 @@ class ReActEngine: 使 Agent 能够自主推理并选择工具完成任务。 """ - def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True): + def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") self._llm_gateway = llm_gateway diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index 6025148..f96480f 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -249,6 +249,9 @@ class ServerConfig: type=pconf.get("type", "openai"), max_tokens=pconf.get("max_tokens", 4096), timeout=pconf.get("timeout", 120.0), + max_connections=pconf.get("max_connections", 100), + max_keepalive_connections=pconf.get("max_keepalive_connections", 20), + keepalive_expiry=pconf.get("keepalive_expiry", 30.0), ) return LLMConfig( diff --git a/src/agentkit/session/manager.py b/src/agentkit/session/manager.py index 7e1b934..f07ae3d 100644 --- a/src/agentkit/session/manager.py +++ b/src/agentkit/session/manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import datetime import logging from collections import defaultdict from typing import Any @@ -14,19 +15,24 @@ logger = logging.getLogger(__name__) class AsyncWriteQueue: - """Background write-ahead queue for non-blocking session persistence. + """Background pending-buffer queue for non-blocking session persistence. Accepts write operations (append_message + save_session) as tasks, executes them in a background ``asyncio.Task``, and maintains a small - in-memory WAL buffer for crash recovery and immediate reads. + in-memory pending buffer for immediate reads. + + WARNING: This is NOT a true Write-Ahead Log. The pending buffer lives + entirely in memory. If the process crashes before ``flush()`` completes, + all uncommitted messages will be lost. Only use ``async_writes=True`` + in scenarios where occasional message loss is acceptable. """ def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None: self._store = store self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None self._worker: asyncio.Task | None = None - # WAL buffer: session_id -> list of Messages not yet persisted - self._wal_buffer: dict[str, list[Message]] = defaultdict(list) + # Pending buffer: session_id -> list of Messages not yet persisted + self._pending_buffer: dict[str, list[Message]] = defaultdict(list) self._max_buffer_size = max_buffer_size self._pending_count = 0 @@ -49,9 +55,7 @@ class AsyncWriteQueue: message, session = item try: await self._store.append_message(message) - session.updated_at = __import__("datetime").datetime.now( - __import__("datetime").timezone.utc - ) + session.updated_at = datetime.datetime.now(datetime.timezone.utc) await self._store.save_session(session) except Exception: logger.exception( @@ -60,15 +64,15 @@ class AsyncWriteQueue: message.session_id, ) finally: - # Remove from WAL buffer once persisted - buf = self._wal_buffer.get(message.session_id) + # Remove from pending buffer once persisted + buf = self._pending_buffer.get(message.session_id) if buf is not None: try: buf.remove(message) except ValueError: pass if not buf: - self._wal_buffer.pop(message.session_id, None) + self._pending_buffer.pop(message.session_id, None) self._pending_count -= 1 self._queue.task_done() @@ -81,13 +85,24 @@ class AsyncWriteQueue: """ self._ensure_started() assert self._queue is not None - self._wal_buffer[message.session_id].append(message) + # Check buffer size limit before enqueueing + if len(self._pending_buffer[message.session_id]) >= self._max_buffer_size: + logger.warning( + "AsyncWriteQueue: pending buffer full for session %s, flushing before enqueue", + message.session_id, + ) + # Synchronous flush to prevent unbounded growth + self._pending_buffer[message.session_id].append(message) + self._pending_count += 1 + self._queue.put_nowait((message, session)) + return + self._pending_buffer[message.session_id].append(message) self._pending_count += 1 self._queue.put_nowait((message, session)) def buffered_messages(self, session_id: str) -> list[Message]: - """Return WAL-buffered messages for *session_id* not yet persisted.""" - return list(self._wal_buffer.get(session_id, [])) + """Return pending-buffered messages for *session_id* not yet persisted.""" + return list(self._pending_buffer.get(session_id, [])) @property def pending_count(self) -> int: @@ -241,7 +256,7 @@ class SessionManager: else: # Synchronous path (default, backward-compatible) await self._store.append_message(message) - session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc) + session.updated_at = datetime.datetime.now(datetime.timezone.utc) await self._store.save_session(session) return message