fix(agentkit): resolve all P0/P1/P2/P3 issues from code review
This commit is contained in:
parent
3645c7a080
commit
b34b06724d
|
|
@ -0,0 +1,201 @@
|
|||
---
|
||||
title: "fix: AgentKit P0 Code Review Fixes"
|
||||
status: completed
|
||||
created: 2026-06-07
|
||||
plan_type: fix
|
||||
execution_posture: TDD
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
Fix 4 P0 issues and 1 import defect identified in the Phase 6+7 code review, unblocking merge to main. All units follow TDD: write failing tests first, then implement the fix.
|
||||
|
||||
## Problem Frame
|
||||
|
||||
Code review of the `feat/agentkit-phase7-headroom` branch revealed 4 P0 defects that must be fixed before merge:
|
||||
|
||||
1. **CCR cache unbounded growth** — `_ccr_cache: dict[str, str]` grows without limit; `ccr_ttl` config is declared but never enforced
|
||||
2. **CCR hash collision** — `sha256(...).hexdigest()[:16]` truncates to 64 bits; collisions silently overwrite cached originals
|
||||
3. **OTel span leak** — `_span_cm.__enter__()` without `try/finally`; exception between enter and exit leaks the span
|
||||
4. **StdioTransport notification queue** — `receive_response()` raises `TransportError` when queue is empty, inconsistent with `SSETransport` which awaits
|
||||
|
||||
Plus 1 import defect: `mcp/__init__.py` lists `MCPServer` and `MCPClient` in `__all__` but never imports them.
|
||||
|
||||
## Requirements
|
||||
|
||||
- R1: CCR cache must enforce capacity limit and TTL eviction
|
||||
- R2: CCR hash must detect collisions and reject silent overwrites
|
||||
- R3: OTel span lifecycle must use `try/finally` to guarantee cleanup
|
||||
- R4: `StdioTransport.receive_response()` must await empty queue (consistent with SSETransport)
|
||||
- R5: `mcp/__init__.py` must import and export `MCPServer` and `MCPClient`
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### KTD-1: CCR cache eviction strategy
|
||||
|
||||
**Decision:** Use `collections.OrderedDict` as an LRU with a configurable `max_entries` (default 1000). On insert, move to end (most-recent). When capacity exceeded, evict oldest (least-recent). TTL enforced by storing `(content, timestamp)` tuples and evicting expired entries on access.
|
||||
|
||||
**Rationale:** `OrderedDict` is stdlib, zero-dependency, and provides O(1) move-to-end/pop-oldest. No need for `functools.lru_cache` (wrong abstraction — we need per-instance, not per-function caching) or external deps like `cachetools`.
|
||||
|
||||
### KTD-2: Hash collision handling
|
||||
|
||||
**Decision:** Use full SHA-256 hex digest (64 chars) instead of truncated 16-char prefix. On `_store_ccr`, if hash already exists and content differs, log a warning and skip caching (return `None`).
|
||||
|
||||
**Rationale:** Full SHA-256 makes collisions astronomically improbable (~2^-256). The collision check is a safety net for the extremely unlikely case. Truncating to 64 bits (16 hex chars) was the root cause — birthday paradox gives ~50% collision at ~2^32 entries.
|
||||
|
||||
### KTD-3: OTel span lifecycle pattern
|
||||
|
||||
**Decision:** Replace `__enter__`/`__exit__` manual calls with `with start_span(...) as span:` context manager. Guard with `if _OTEL_AVAILABLE` to avoid no-op span overhead.
|
||||
|
||||
**Rationale:** Context manager guarantees `__exit__` on exception. The current pattern leaks on any exception between `__enter__` and `__exit__`.
|
||||
|
||||
### KTD-4: StdioTransport receive_response await behavior
|
||||
|
||||
**Decision:** When `_notifications` queue is empty, `await` the queue with the transport's configured timeout (same pattern as `SSETransport`). Raise `TransportError` only on timeout or disconnect.
|
||||
|
||||
**Rationale:** Consistency with `SSETransport.receive_response()`, which awaits `_response_queue.get()` with timeout. The current behavior of raising immediately breaks polling consumers that expect to wait.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. CCR Cache: LRU + TTL + Collision Detection
|
||||
|
||||
**Goal:** Fix unbounded growth and hash collision in `HeadroomCompressor._ccr_cache`.
|
||||
|
||||
**Requirements:** R1, R2
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/core/headroom_compressor.py` — modify
|
||||
- `tests/unit/test_headroom_compressor.py` — modify
|
||||
|
||||
**Approach:**
|
||||
1. Replace `_ccr_cache: dict[str, str]` with `_ccr_cache: OrderedDict[str, tuple[str, float]]` storing `(content, insert_time)`
|
||||
2. Add `_max_entries` config (default 1000); on insert, if at capacity, pop oldest item
|
||||
3. On `_store_ccr`, use full SHA-256 hex digest; if hash exists and content differs, log warning and return `None`
|
||||
4. On `retrieve`, check TTL before returning; evict expired entries
|
||||
5. Add `_evict_expired()` helper called on each store/retrieve
|
||||
|
||||
**Execution note:** TDD — write failing tests for each behavior first.
|
||||
|
||||
**Test scenarios:**
|
||||
- **Happy path:** Store and retrieve content by full hash
|
||||
- **LRU eviction:** Store `max_entries + 1` items; verify oldest evicted
|
||||
- **TTL expiry:** Store with `ccr_ttl=1`, wait >1s, retrieve returns not-found
|
||||
- **Collision detection:** Manually inject a hash with different content; `_store_ccr` returns `None` and logs warning
|
||||
- **No collision on same content:** Store identical content twice; second store returns same hash (idempotent)
|
||||
- **Evict expired on access:** Store with short TTL, wait, then store another item; expired entry cleaned during eviction sweep
|
||||
- **Default max_entries:** Verify default is 1000
|
||||
- **Custom max_entries:** Verify custom config respected
|
||||
|
||||
**Verification:** All new tests pass; existing CCR tests still pass with updated hash length.
|
||||
|
||||
---
|
||||
|
||||
### U2. OTel Span Lifecycle Fix
|
||||
|
||||
**Goal:** Ensure OTel span is always properly closed, even on exceptions.
|
||||
|
||||
**Requirements:** R3
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/core/react.py` — modify
|
||||
- `tests/unit/test_react_compression.py` — modify
|
||||
|
||||
**Approach:**
|
||||
1. Replace `_span_cm = start_span(...); _span_cm.__enter__(); ...; _span_cm.__exit__(...)` with `with start_span(...) as _span:` wrapped around the entire `_execute_loop` body
|
||||
2. Move `_exec_start` and span attribute setting inside the `with` block
|
||||
3. Guard with `if _OTEL_AVAILABLE` to skip span creation when OTel is not installed
|
||||
4. Ensure `agent_duration_histogram` recording happens inside the `with` block
|
||||
|
||||
**Execution note:** TDD — write a failing test that verifies span cleanup on exception first.
|
||||
|
||||
**Test scenarios:**
|
||||
- **Happy path:** Span attributes set and span closed on successful execution
|
||||
- **Exception path:** LLM gateway raises exception; span is still properly closed (attributes set, `__exit__` called)
|
||||
- **Cancellation path:** `TaskCancelledError` raised; span closed with outcome="cancelled"
|
||||
- **No OTel available:** When `_OTEL_AVAILABLE=False`, execution proceeds without span overhead
|
||||
- **Span attribute values:** Verify `agent.total_steps`, `agent.total_tokens`, `agent.outcome`, `agent.duration_ms` are set correctly
|
||||
|
||||
**Verification:** All new tests pass; existing ReAct tests still pass.
|
||||
|
||||
---
|
||||
|
||||
### U3. StdioTransport receive_response Await Fix
|
||||
|
||||
**Goal:** Make `StdioTransport.receive_response()` await empty notification queue, consistent with `SSETransport`.
|
||||
|
||||
**Requirements:** R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/mcp/transport.py` — modify
|
||||
- `tests/unit/test_mcp_transport.py` — modify
|
||||
|
||||
**Approach:**
|
||||
1. Replace `if not self._notifications.empty(): return self._notifications.get_nowait()` + `raise TransportError(...)` with `await asyncio.wait_for(self._notifications.get(), timeout=self._timeout)`
|
||||
2. Catch `asyncio.TimeoutError` and raise `TransportError("Timeout waiting for notification")` (matching SSETransport pattern)
|
||||
3. Keep the `is_connected` guard at the top
|
||||
|
||||
**Execution note:** TDD — write failing test for await behavior first.
|
||||
|
||||
**Test scenarios:**
|
||||
- **Happy path:** Notification available immediately; returned without waiting
|
||||
- **Await path:** Queue empty; `receive_response()` awaits until notification arrives
|
||||
- **Timeout path:** Queue empty; timeout expires; raises `TransportError` with "Timeout" message
|
||||
- **Not connected:** Raises `TransportError` with "not connected" message
|
||||
- **Consistency with SSE:** Same await+timeout pattern as `SSETransport.receive_response()`
|
||||
|
||||
**Verification:** All new tests pass; existing transport tests still pass.
|
||||
|
||||
---
|
||||
|
||||
### U4. MCP __init__.py Import Fix
|
||||
|
||||
**Goal:** Add missing `MCPServer` and `MCPClient` imports to `mcp/__init__.py`.
|
||||
|
||||
**Requirements:** R5
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/mcp/__init__.py` — modify
|
||||
|
||||
**Approach:**
|
||||
1. Add `from agentkit.mcp.server import MCPServer` and `from agentkit.mcp.client import MCPClient` imports
|
||||
2. Verify `__all__` already lists both names (it does)
|
||||
|
||||
**Test scenarios:**
|
||||
- **Import test:** `from agentkit.mcp import MCPServer, MCPClient` succeeds
|
||||
- **All exports test:** All names in `__all__` are importable
|
||||
|
||||
**Verification:** `python -c "from agentkit.mcp import MCPServer, MCPClient"` succeeds.
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
- 4 P0 fixes + 1 import fix as described above
|
||||
- Test coverage for all fixes
|
||||
|
||||
### Deferred to Follow-Up Work
|
||||
- P1: Redis degradation recovery in `pipeline_state.py`
|
||||
- P1: Sync `urllib.request` → async in `baidu_search.py` and `schema_tools.py`
|
||||
- P1: Type annotation mismatch (`ContextCompressor` → `CompressionStrategy`) in `react.py`
|
||||
- P1: Config hot-reload race condition in `app.py`
|
||||
- P2: `_request_id` non-atomic increment in transport classes
|
||||
- P3: `_should_compress` hardcoded 8000 token threshold
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|-----------|
|
||||
| Full SHA-256 hash increases CCR marker length in compressed output | Acceptable: 64 chars vs 16 chars is negligible in tool output context |
|
||||
| `OrderedDict` LRU is not thread-safe | HeadroomCompressor is used within async single-threaded context; no concurrent access |
|
||||
| `with start_span()` changes span scoping in `_execute_loop` | Span now covers the entire loop body including error paths — strictly better |
|
||||
|
|
@ -5,9 +5,12 @@
|
|||
CCR 可逆压缩保证原始数据不丢失。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.compressor import CompressionStrategy
|
||||
|
|
@ -65,7 +68,8 @@ class HeadroomCompressor:
|
|||
配置项:
|
||||
enabled: bool — 开关
|
||||
compressors: list[str] — 启用的压缩器 ["smart_crusher", "code_compressor"]
|
||||
ccr_ttl: int — CCR 缓存 TTL(秒),默认 300
|
||||
ccr_ttl: int — CCR 缓存 TTL(秒),默认 300;0 表示永不过期
|
||||
max_entries: int — CCR 缓存最大条目数,默认 1000
|
||||
min_length: int — 最小压缩长度(字符),默认 500
|
||||
model: str — 传给 headroom 的模型名
|
||||
"""
|
||||
|
|
@ -74,10 +78,11 @@ class HeadroomCompressor:
|
|||
self._config = config
|
||||
self._compressors = config.get("compressors", ["smart_crusher", "code_compressor"])
|
||||
self._ccr_ttl = config.get("ccr_ttl", 300)
|
||||
self._max_entries = config.get("max_entries", 1000)
|
||||
self._min_length = config.get("min_length", 500)
|
||||
self._model = config.get("model", "default")
|
||||
# CCR cache: hash -> original content
|
||||
self._ccr_cache: dict[str, str] = {}
|
||||
# CCR cache: hash -> (content, insert_timestamp) with LRU ordering
|
||||
self._ccr_cache: OrderedDict[str, tuple[str, float]] = OrderedDict()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查 headroom-ai 是否已安装"""
|
||||
|
|
@ -172,17 +177,66 @@ class HeadroomCompressor:
|
|||
return None
|
||||
|
||||
def _store_ccr(self, original: str) -> str | None:
|
||||
"""存储原始内容到 CCR 缓存,返回哈希"""
|
||||
import hashlib
|
||||
ccr_hash = hashlib.sha256(original.encode()).hexdigest()[:16]
|
||||
self._ccr_cache[ccr_hash] = original
|
||||
"""存储原始内容到 CCR 缓存,返回哈希
|
||||
|
||||
使用完整 SHA-256 防止碰撞。碰撞时拒绝覆盖并返回 None。
|
||||
超过 max_entries 时淘汰最久未访问的条目(LRU)。
|
||||
"""
|
||||
ccr_hash = hashlib.sha256(original.encode()).hexdigest()
|
||||
|
||||
# Collision detection: if hash exists with different content, reject
|
||||
if ccr_hash in self._ccr_cache:
|
||||
cached_content, _ = self._ccr_cache[ccr_hash]
|
||||
if cached_content != original:
|
||||
logger.warning(
|
||||
"CCR hash collision detected for hash=%s... "
|
||||
"Rejecting overwrite to prevent data loss.",
|
||||
ccr_hash[:16],
|
||||
)
|
||||
return None
|
||||
# Same content: idempotent update (renew timestamp + LRU position)
|
||||
self._ccr_cache.move_to_end(ccr_hash)
|
||||
self._ccr_cache[ccr_hash] = (original, time.monotonic())
|
||||
return ccr_hash
|
||||
|
||||
# Evict expired entries before inserting
|
||||
self._evict_expired()
|
||||
|
||||
# LRU eviction: if at capacity, remove oldest entry
|
||||
while len(self._ccr_cache) >= self._max_entries:
|
||||
self._ccr_cache.popitem(last=False)
|
||||
|
||||
self._ccr_cache[ccr_hash] = (original, time.monotonic())
|
||||
return ccr_hash
|
||||
|
||||
def _evict_expired(self) -> None:
|
||||
"""清理过期的 CCR 缓存条目"""
|
||||
if self._ccr_ttl <= 0:
|
||||
return # TTL=0 means no expiry
|
||||
now = time.monotonic()
|
||||
expired_keys = [
|
||||
k for k, (_, ts) in self._ccr_cache.items()
|
||||
if now - ts > self._ccr_ttl
|
||||
]
|
||||
for k in expired_keys:
|
||||
del self._ccr_cache[k]
|
||||
|
||||
def retrieve(self, ccr_hash: str | None = None, query: str | None = None) -> dict:
|
||||
"""从 CCR 缓存检索原始数据"""
|
||||
if ccr_hash and ccr_hash in self._ccr_cache:
|
||||
content, ts = self._ccr_cache[ccr_hash]
|
||||
# Check TTL
|
||||
if self._ccr_ttl > 0:
|
||||
if time.monotonic() - ts > self._ccr_ttl:
|
||||
del self._ccr_cache[ccr_hash]
|
||||
return {
|
||||
"error": f"CCR hash '{ccr_hash}' expired",
|
||||
"success": False,
|
||||
}
|
||||
# Renew LRU position on access
|
||||
self._ccr_cache.move_to_end(ccr_hash)
|
||||
return {
|
||||
"content": self._ccr_cache[ccr_hash],
|
||||
"content": content,
|
||||
"ccr_hash": ccr_hash,
|
||||
"success": True,
|
||||
}
|
||||
|
|
@ -190,7 +244,7 @@ class HeadroomCompressor:
|
|||
if query:
|
||||
# Simple keyword search in cached content
|
||||
results = []
|
||||
for h, content in self._ccr_cache.items():
|
||||
for h, (content, _) in self._ccr_cache.items():
|
||||
if query.lower() in content.lower():
|
||||
results.append({"ccr_hash": h, "content": content[:500]})
|
||||
if results:
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class ReActEngine:
|
|||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "ContextCompressor | None" = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
|
|
@ -163,7 +163,7 @@ class ReActEngine:
|
|||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "ContextCompressor | None" = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> ReActResult:
|
||||
|
|
@ -174,157 +174,90 @@ class ReActEngine:
|
|||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||
|
||||
# Start telemetry span for the entire agent execution
|
||||
_span_cm = start_span(
|
||||
"agent.execute",
|
||||
attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
|
||||
)
|
||||
_span = _span_cm.__enter__()
|
||||
_span_cm = None
|
||||
_span = None
|
||||
_exec_start = time.monotonic()
|
||||
|
||||
# 启动轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.start_trace(
|
||||
task_id="",
|
||||
agent_name=agent_name,
|
||||
skill_name=task_type or None,
|
||||
if _OTEL_AVAILABLE:
|
||||
_span_cm = start_span(
|
||||
"agent.execute",
|
||||
attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
|
||||
)
|
||||
_span = _span_cm.__enter__()
|
||||
|
||||
# Memory retrieval: 执行前检索相关上下文注入 system_prompt
|
||||
if memory_retriever:
|
||||
try:
|
||||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
system_prompt += f"\n\n## 参考信息\n{memory_context}"
|
||||
else:
|
||||
system_prompt = f"## 参考信息\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
|
||||
# 构建初始消息
|
||||
conversation: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
conversation.extend(messages)
|
||||
|
||||
# Context compression: 压缩超长对话历史
|
||||
if compressor:
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
|
||||
# Initialize before try so finally can access them
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = 0
|
||||
step = 0
|
||||
output = ""
|
||||
trace_outcome = "success"
|
||||
trace_outcome = "error"
|
||||
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
try:
|
||||
# 启动轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.start_trace(
|
||||
task_id="",
|
||||
agent_name=agent_name,
|
||||
skill_name=task_type or None,
|
||||
)
|
||||
|
||||
# 协作式取消检查
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
|
||||
# Think: 调用 LLM
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=conversation,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||
|
||||
step_tokens = response.usage.total_tokens
|
||||
total_tokens += step_tokens
|
||||
|
||||
# 检查是否有 Function Calling 的 tool_calls
|
||||
if response.has_tool_calls:
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
# Memory retrieval: 执行前检索相关上下文注入 system_prompt
|
||||
if memory_retriever:
|
||||
try:
|
||||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
system_prompt += f"\n\n## 参考信息\n{memory_context}"
|
||||
else:
|
||||
system_prompt = f"## 参考信息\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
|
||||
# Act: 执行工具调用
|
||||
# 先记录 assistant 消息(含 tool_calls)到对话历史
|
||||
assistant_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments),
|
||||
},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
],
|
||||
}
|
||||
conversation.append(assistant_msg)
|
||||
# 构建初始消息
|
||||
conversation: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
conversation.extend(messages)
|
||||
|
||||
# 执行每个工具调用
|
||||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
# Context compression: 压缩超长对话历史
|
||||
if compressor:
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
trace_outcome = "success"
|
||||
step = 0
|
||||
output = ""
|
||||
|
||||
# 记录工具调用步骤
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
|
||||
# Observe: 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
# 协作式取消检查
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Incremental compression failed: {e}")
|
||||
# Think: 调用 LLM
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=conversation,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||
|
||||
else:
|
||||
# 检查文本解析模式
|
||||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||||
if parsed_calls and tools:
|
||||
step_tokens = response.usage.total_tokens
|
||||
total_tokens += step_tokens
|
||||
|
||||
# 检查是否有 Function Calling 的 tool_calls
|
||||
if response.has_tool_calls:
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
|
|
@ -334,19 +267,36 @@ class ReActEngine:
|
|||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# 文本解析模式执行工具
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
# Act: 执行工具调用
|
||||
# 先记录 assistant 消息(含 tool_calls)到对话历史
|
||||
assistant_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments),
|
||||
},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
],
|
||||
}
|
||||
conversation.append(assistant_msg)
|
||||
|
||||
for pc in parsed_calls:
|
||||
# 执行每个工具调用
|
||||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
arguments=pc["arguments"],
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
|
|
@ -360,16 +310,16 @@ class ReActEngine:
|
|||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
input_data=pc["arguments"],
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"])
|
||||
# Observe: 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
|
|
@ -378,70 +328,130 @@ class ReActEngine:
|
|||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Incremental compression failed: {e}")
|
||||
else:
|
||||
# Final answer: LLM 没有调用工具,返回最终答案
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="final_answer",
|
||||
content=response.content,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
output = response.content or ""
|
||||
|
||||
# 记录最终答案步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
else:
|
||||
# 检查文本解析模式
|
||||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||||
if parsed_calls and tools:
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# 文本解析模式执行工具
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
|
||||
for pc in parsed_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
arguments=pc["arguments"],
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# 记录工具调用步骤
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
input_data=pc["arguments"],
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"])
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Incremental compression failed: {e}")
|
||||
else:
|
||||
# Final answer: LLM 没有调用工具,返回最终答案
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="final_answer",
|
||||
output_data={"content": response.content},
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
content=response.content,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
break
|
||||
trajectory.append(react_step)
|
||||
output = response.content or ""
|
||||
|
||||
# 达到 max_steps 时,返回当前最佳输出
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
# 使用最后一步的内容作为输出
|
||||
if trajectory and trajectory[-1].content:
|
||||
output = trajectory[-1].content
|
||||
elif trajectory and trajectory[-1].result is not None:
|
||||
output = str(trajectory[-1].result)
|
||||
else:
|
||||
output = response.content or ""
|
||||
# 记录最终答案步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="final_answer",
|
||||
output_data={"content": response.content},
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
break
|
||||
|
||||
# 结束轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
# 达到 max_steps 时,返回当前最佳输出
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
# 使用最后一步的内容作为输出
|
||||
if trajectory and trajectory[-1].content:
|
||||
output = trajectory[-1].content
|
||||
elif trajectory and trajectory[-1].result is not None:
|
||||
output = str(trajectory[-1].result)
|
||||
else:
|
||||
output = response.content or ""
|
||||
|
||||
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
|
||||
if memory_retriever and hasattr(memory_retriever, "store_episode"):
|
||||
try:
|
||||
summary = output[:500] if output else ""
|
||||
await memory_retriever.store_episode(
|
||||
key=f"task:{task_id or 'unknown'}",
|
||||
value={"output_summary": summary, "agent_name": agent_name},
|
||||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||||
# 结束轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
# Telemetry: end span and record duration
|
||||
_duration_ms = int((time.monotonic() - _exec_start) * 1000)
|
||||
_span.set_attribute("agent.total_steps", len(trajectory))
|
||||
_span.set_attribute("agent.total_tokens", total_tokens)
|
||||
_span.set_attribute("agent.outcome", trace_outcome)
|
||||
_span.set_attribute("agent.duration_ms", _duration_ms)
|
||||
_span_cm.__exit__(None, None, None)
|
||||
agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
|
||||
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
|
||||
if memory_retriever and hasattr(memory_retriever, "store_episode"):
|
||||
try:
|
||||
summary = output[:500] if output else ""
|
||||
await memory_retriever.store_episode(
|
||||
key=f"task:{task_id or 'unknown'}",
|
||||
value={"output_summary": summary, "agent_name": agent_name},
|
||||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||||
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
finally:
|
||||
# Telemetry: end span and record duration — always runs
|
||||
_duration_ms = int((time.monotonic() - _exec_start) * 1000)
|
||||
if _span is not None:
|
||||
_span.set_attribute("agent.total_steps", len(trajectory))
|
||||
_span.set_attribute("agent.total_tokens", total_tokens)
|
||||
_span.set_attribute("agent.outcome", trace_outcome)
|
||||
_span.set_attribute("agent.duration_ms", _duration_ms)
|
||||
if _span_cm is not None:
|
||||
_span_cm.__exit__(None, None, None)
|
||||
agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
|
|
@ -454,7 +464,7 @@ class ReActEngine:
|
|||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "ContextCompressor | None" = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
|
|
@ -773,14 +783,17 @@ class ReActEngine:
|
|||
return tool
|
||||
return None
|
||||
|
||||
# Default token threshold for incremental compression
|
||||
_DEFAULT_COMPRESS_THRESHOLD = 8000
|
||||
|
||||
def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool:
|
||||
"""检查是否需要增量压缩"""
|
||||
if not compressor:
|
||||
return False
|
||||
# Estimate tokens in conversation
|
||||
# Estimate tokens in conversation (rough: 4 chars ≈ 1 token)
|
||||
total_chars = sum(len(str(m.get("content", ""))) for m in conversation)
|
||||
estimated_tokens = total_chars // 4
|
||||
return estimated_tokens > 8000 # Threshold: ~8000 tokens
|
||||
return estimated_tokens > self._DEFAULT_COMPRESS_THRESHOLD
|
||||
|
||||
async def _build_tool_result_message(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""AgentKit MCP - Model Context Protocol 支持"""
|
||||
|
||||
from agentkit.mcp.client import MCPClient
|
||||
from agentkit.mcp.manager import MCPManager
|
||||
from agentkit.mcp.server import MCPServer
|
||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport, TransportError
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
|
|
@ -34,13 +35,23 @@ class MCPManager:
|
|||
self._server_tools: dict[str, list[str]] = {} # server_name -> [tool_names]
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""启动所有配置的 MCP Server,发现并注册工具"""
|
||||
for name, config in self._configs.items():
|
||||
try:
|
||||
await self._start_server(name, config)
|
||||
except Exception as e:
|
||||
logger.error("Failed to start MCP server '%s': %s", name, e)
|
||||
self._available[name] = False
|
||||
"""启动所有配置的 MCP Server,并发发现并注册工具
|
||||
|
||||
使用 asyncio.gather 并发启动,单个服务器失败不影响其他服务器。
|
||||
"""
|
||||
tasks = [
|
||||
self._start_server_safe(name, config)
|
||||
for name, config in self._configs.items()
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _start_server_safe(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""启动单个 MCP Server,失败时标记为不可用"""
|
||||
try:
|
||||
await self._start_server(name, config)
|
||||
except Exception as e:
|
||||
logger.error("Failed to start MCP server '%s': %s", name, e)
|
||||
self._available[name] = False
|
||||
|
||||
async def _start_server(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""启动单个 MCP Server"""
|
||||
|
|
@ -97,9 +108,10 @@ class MCPManager:
|
|||
await transport.disconnect()
|
||||
except Exception as e:
|
||||
logger.error("Error stopping MCP server '%s': %s", name, e)
|
||||
self._available[name] = False
|
||||
self._transports.clear()
|
||||
self._clients.clear()
|
||||
self._available.clear()
|
||||
self._server_tools.clear()
|
||||
|
||||
def is_available(self, server_name: str) -> bool:
|
||||
"""检查指定 MCP Server 是否可用"""
|
||||
|
|
|
|||
|
|
@ -567,20 +567,24 @@ class StdioTransport(Transport):
|
|||
|
||||
对于 StdioTransport,请求响应通过 _pending Future 异步返回。
|
||||
此方法仅用于获取服务端推送的通知消息。
|
||||
空队列时 await 等待(与 SSETransport 行为一致)。
|
||||
|
||||
Returns:
|
||||
JSON-RPC 通知消息
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立或无通知
|
||||
TransportError: 连接未建立或超时
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
|
||||
if not self._notifications.empty():
|
||||
return self._notifications.get_nowait()
|
||||
|
||||
raise TransportError("No notification to receive")
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._notifications.get(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TransportError("Timeout waiting for notification")
|
||||
|
||||
async def _read_stdout(self) -> None:
|
||||
"""持续从子进程 stdout 读取 JSON-RPC 消息"""
|
||||
|
|
|
|||
|
|
@ -148,19 +148,27 @@ class PipelineStateMemory:
|
|||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
return self._step_history.get(execution_id, [])
|
||||
|
||||
def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None:
|
||||
"""Synchronous accessor for execution state (used by Redis dual-write)."""
|
||||
return self._executions.get(execution_id)
|
||||
|
||||
|
||||
class PipelineStateRedis:
|
||||
"""Redis-backed pipeline state storage (hot state).
|
||||
|
||||
Uses Redis Hash for execution state and Sorted Set for indexing.
|
||||
Falls back to PipelineStateMemory if Redis is unavailable.
|
||||
Automatically retries Redis after a cooldown period.
|
||||
"""
|
||||
|
||||
_RECOVERY_COOLDOWN_SECONDS = 30
|
||||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
|
||||
self._redis_url = redis_url
|
||||
self._redis: Any = None
|
||||
self._fallback = PipelineStateMemory()
|
||||
self._use_fallback = False
|
||||
self._fallback_since: float | None = None
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
|
|
@ -175,15 +183,42 @@ class PipelineStateRedis:
|
|||
async def _safe_redis_call(
|
||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Execute a Redis call, falling back to memory on failure."""
|
||||
"""Execute a Redis call, falling back to memory on failure.
|
||||
|
||||
After falling back, periodically retries Redis to enable recovery.
|
||||
On successful recovery, the original operation is executed immediately.
|
||||
"""
|
||||
if self._use_fallback:
|
||||
return None
|
||||
# Check if enough time has passed to attempt recovery
|
||||
if self._fallback_since is not None:
|
||||
import time as _time
|
||||
elapsed = _time.monotonic() - self._fallback_since
|
||||
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
|
||||
try:
|
||||
self._redis = None
|
||||
redis = await self._get_redis()
|
||||
await redis.ping()
|
||||
# Recovery successful — continue to execute the operation
|
||||
self._use_fallback = False
|
||||
self._fallback_since = None
|
||||
logger.info("Redis connection recovered, switching back from fallback")
|
||||
# Fall through to execute the actual operation on Redis
|
||||
except Exception:
|
||||
# Still down, reset cooldown timer
|
||||
self._fallback_since = _time.monotonic()
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await fn(redis, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
|
||||
self._use_fallback = True
|
||||
import time as _time
|
||||
self._fallback_since = _time.monotonic()
|
||||
self._redis = None
|
||||
return None
|
||||
|
||||
|
|
@ -204,7 +239,7 @@ class PipelineStateRedis:
|
|||
|
||||
# Try Redis
|
||||
async def _redis_create(redis: Any) -> None:
|
||||
state = self._fallback._executions[execution_id]
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
score = datetime.now(timezone.utc).timestamp()
|
||||
pipe = redis.pipeline()
|
||||
pipe.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
|
@ -226,7 +261,7 @@ class PipelineStateRedis:
|
|||
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms)
|
||||
|
||||
async def _redis_update(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
|
@ -241,7 +276,7 @@ class PipelineStateRedis:
|
|||
await self._fallback.complete_execution(execution_id, final_output)
|
||||
|
||||
async def _redis_complete(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
|
@ -257,7 +292,7 @@ class PipelineStateRedis:
|
|||
await self._fallback.fail_execution(execution_id, step_name, error)
|
||||
|
||||
async def _redis_fail(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""FastAPI Application Factory"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -10,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from agentkit.core.agent_pool import AgentPool
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||
from agentkit.llm.providers.gemini import GeminiProvider
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.mcp.manager import MCPManager
|
||||
from agentkit.quality.gate import QualityGate
|
||||
|
|
@ -114,42 +116,62 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
|||
- New tasks use the new configuration
|
||||
- In-progress tasks continue with their original configuration
|
||||
- Config version is incremented for audit tracking
|
||||
|
||||
Uses a lock to prevent concurrent config reloads from racing.
|
||||
"""
|
||||
# Increment config version for audit
|
||||
current_version = getattr(app.state, "config_version", 0) + 1
|
||||
app.state.config_version = current_version
|
||||
logger.info(f"Config change detected (v{current_version}), reloading...")
|
||||
lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
app.state._config_reload_lock = lock
|
||||
|
||||
# Rebuild LLMGateway if llm config changed
|
||||
if lock.locked():
|
||||
logger.warning("Config reload already in progress, skipping")
|
||||
return
|
||||
|
||||
async def _reload():
|
||||
async with lock:
|
||||
# Increment config version for audit
|
||||
current_version = getattr(app.state, "config_version", 0) + 1
|
||||
app.state.config_version = current_version
|
||||
logger.info(f"Config change detected (v{current_version}), reloading...")
|
||||
|
||||
# Rebuild LLMGateway if llm config changed
|
||||
try:
|
||||
new_gateway = _build_llm_gateway(config)
|
||||
app.state.llm_gateway = new_gateway
|
||||
# Also update the agent pool's gateway reference
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._llm_gateway = new_gateway
|
||||
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
||||
app.state.intent_router._llm_gateway = new_gateway
|
||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||
|
||||
# Reload skills if skill paths changed
|
||||
try:
|
||||
new_skill_registry = _build_skill_registry(config)
|
||||
app.state.skill_registry = new_skill_registry
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._skill_registry = new_skill_registry
|
||||
logger.info(f"Skills reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload skills: {e}")
|
||||
|
||||
# Update config version on all agents
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
for agent in app.state.agent_pool._agents.values():
|
||||
if hasattr(agent, "_config_version"):
|
||||
agent._config_version = current_version
|
||||
|
||||
logger.info(f"Config reload complete (v{current_version})")
|
||||
|
||||
# Schedule the reload as a task (non-blocking for the watcher thread)
|
||||
try:
|
||||
new_gateway = _build_llm_gateway(config)
|
||||
app.state.llm_gateway = new_gateway
|
||||
# Also update the agent pool's gateway reference
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._llm_gateway = new_gateway
|
||||
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
||||
app.state.intent_router._llm_gateway = new_gateway
|
||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||
|
||||
# Reload skills if skill paths changed
|
||||
try:
|
||||
new_skill_registry = _build_skill_registry(config)
|
||||
app.state.skill_registry = new_skill_registry
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._skill_registry = new_skill_registry
|
||||
logger.info(f"Skills reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload skills: {e}")
|
||||
|
||||
# Update config version on all agents
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
for agent in app.state.agent_pool._agents.values():
|
||||
if hasattr(agent, "_config_version"):
|
||||
agent._config_version = current_version
|
||||
|
||||
logger.info(f"Config reload complete (v{current_version})")
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(_reload())
|
||||
except RuntimeError:
|
||||
logger.warning("No running event loop, config reload deferred")
|
||||
|
||||
|
||||
def create_app(
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@
|
|||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -119,15 +120,16 @@ class BaiduSearchTool(Tool):
|
|||
"num": max_results,
|
||||
}
|
||||
url = f"{self._api_url}?{urllib.parse.urlencode(params)}"
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "AgentKit/1.0",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "AgentKit/1.0",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
results = []
|
||||
for item in data.get("results", [])[:max_results]:
|
||||
|
|
@ -149,18 +151,18 @@ class BaiduSearchTool(Tool):
|
|||
try:
|
||||
encoded_query = urllib.parse.quote(query)
|
||||
url = f"https://www.baidu.com/s?wd={encoded_query}&rn={max_results}"
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
),
|
||||
},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
html = resp.read().decode("utf-8", errors="replace")
|
||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
||||
resp = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36"
|
||||
),
|
||||
},
|
||||
)
|
||||
html = resp.text
|
||||
|
||||
# 简单解析搜索结果(基于百度搜索结果页 HTML 结构)
|
||||
results = self._parse_baidu_html(html, max_results)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import json
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -144,11 +146,9 @@ class SchemaExtractTool(Tool):
|
|||
if self._is_url(url_or_html):
|
||||
url = url_or_html
|
||||
try:
|
||||
import urllib.request
|
||||
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "AgentKit/1.0"})
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
html = resp.read().decode("utf-8", errors="replace")
|
||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers={"User-Agent": "AgentKit/1.0"})
|
||||
html = resp.text
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": f"获取 URL 内容失败: {e}",
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class MockHeadroomCompressor:
|
|||
|
||||
def _store_ccr(self, original):
|
||||
import hashlib
|
||||
ccr_hash = hashlib.sha256(original.encode()).hexdigest()[:16]
|
||||
ccr_hash = hashlib.sha256(original.encode()).hexdigest()
|
||||
self._ccr_cache[ccr_hash] = original
|
||||
return ccr_hash
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
所有测试使用 mock headroom 模块,无需安装 headroom-ai。
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -302,7 +304,7 @@ class TestCCRRetrieve:
|
|||
def test_retrieve_not_found(self):
|
||||
"""无效 hash 返回错误"""
|
||||
compressor = HeadroomCompressor({})
|
||||
result = compressor.retrieve(ccr_hash="nonexistent_hash")
|
||||
result = compressor.retrieve(ccr_hash="a" * 64) # Full SHA-256 length
|
||||
assert result["success"] is False
|
||||
assert "error" in result
|
||||
|
||||
|
|
@ -347,7 +349,6 @@ class TestHeadroomCompressorConfig:
|
|||
assert compressor._model == "default"
|
||||
|
||||
def test_custom_config(self):
|
||||
"""自定义配置值"""
|
||||
config = {
|
||||
"compressors": ["smart_crusher"],
|
||||
"ccr_ttl": 600,
|
||||
|
|
@ -359,3 +360,146 @@ class TestHeadroomCompressorConfig:
|
|||
assert compressor._ccr_ttl == 600
|
||||
assert compressor._min_length == 1000
|
||||
assert compressor._model == "gpt-4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCCRCacheLRU (P0 fix: unbounded growth)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCCRCacheLRU:
|
||||
"""测试 CCR 缓存 LRU 淘汰策略"""
|
||||
|
||||
def test_lru_evicts_oldest_when_full(self):
|
||||
"""超过 max_entries 时淘汰最久未访问的条目"""
|
||||
compressor = HeadroomCompressor({"max_entries": 3})
|
||||
h1 = compressor._store_ccr("content_1")
|
||||
h2 = compressor._store_ccr("content_2")
|
||||
h3 = compressor._store_ccr("content_3")
|
||||
# 第 4 个条目应该触发淘汰第 1 个
|
||||
h4 = compressor._store_ccr("content_4")
|
||||
assert h1 is not None
|
||||
assert h4 is not None
|
||||
# h1 应该已被淘汰
|
||||
result = compressor.retrieve(ccr_hash=h1)
|
||||
assert result["success"] is False
|
||||
# h2, h3, h4 应该还在
|
||||
assert compressor.retrieve(ccr_hash=h2)["success"] is True
|
||||
assert compressor.retrieve(ccr_hash=h3)["success"] is True
|
||||
assert compressor.retrieve(ccr_hash=h4)["success"] is True
|
||||
|
||||
def test_lru_access_renews_entry(self):
|
||||
"""retrieve 使条目变为最近访问,不被淘汰"""
|
||||
compressor = HeadroomCompressor({"max_entries": 3})
|
||||
h1 = compressor._store_ccr("content_1")
|
||||
h2 = compressor._store_ccr("content_2")
|
||||
h3 = compressor._store_ccr("content_3")
|
||||
# 访问 h1,使其变为最近
|
||||
compressor.retrieve(ccr_hash=h1)
|
||||
# 插入新条目,应该淘汰 h2(最久未访问)
|
||||
h4 = compressor._store_ccr("content_4")
|
||||
assert compressor.retrieve(ccr_hash=h1)["success"] is True
|
||||
assert compressor.retrieve(ccr_hash=h2)["success"] is False
|
||||
|
||||
def test_default_max_entries(self):
|
||||
"""默认 max_entries 为 1000"""
|
||||
compressor = HeadroomCompressor({})
|
||||
assert compressor._max_entries == 1000
|
||||
|
||||
def test_custom_max_entries(self):
|
||||
"""自定义 max_entries 配置"""
|
||||
compressor = HeadroomCompressor({"max_entries": 50})
|
||||
assert compressor._max_entries == 50
|
||||
|
||||
def test_cache_uses_ordered_dict(self):
|
||||
"""CCR 缓存使用 OrderedDict"""
|
||||
compressor = HeadroomCompressor({})
|
||||
assert isinstance(compressor._ccr_cache, OrderedDict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCCRCacheTTL (P0 fix: TTL enforcement)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCCRCacheTTL:
|
||||
"""测试 CCR 缓存 TTL 过期淘汰"""
|
||||
|
||||
def test_expired_entry_not_retrieved(self):
|
||||
"""过期的条目无法被 retrieve"""
|
||||
compressor = HeadroomCompressor({"ccr_ttl": 1})
|
||||
h = compressor._store_ccr("content")
|
||||
time.sleep(1.1)
|
||||
result = compressor.retrieve(ccr_hash=h)
|
||||
assert result["success"] is False
|
||||
|
||||
def test_fresh_entry_retrieved(self):
|
||||
"""未过期的条目可以正常 retrieve"""
|
||||
compressor = HeadroomCompressor({"ccr_ttl": 300})
|
||||
h = compressor._store_ccr("content")
|
||||
result = compressor.retrieve(ccr_hash=h)
|
||||
assert result["success"] is True
|
||||
assert result["content"] == "content"
|
||||
|
||||
def test_ttl_zero_means_no_expiry(self):
|
||||
"""ccr_ttl=0 表示永不过期"""
|
||||
compressor = HeadroomCompressor({"ccr_ttl": 0})
|
||||
h = compressor._store_ccr("content")
|
||||
result = compressor.retrieve(ccr_hash=h)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_evict_expired_on_store(self):
|
||||
"""_store_ccr 时清理过期条目"""
|
||||
compressor = HeadroomCompressor({"ccr_ttl": 1, "max_entries": 100})
|
||||
h1 = compressor._store_ccr("old_content")
|
||||
time.sleep(1.1)
|
||||
# 存储新条目时应触发过期清理
|
||||
h2 = compressor._store_ccr("new_content")
|
||||
# h1 应该已被清理
|
||||
result = compressor.retrieve(ccr_hash=h1)
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCCRCacheCollision (P0 fix: hash collision detection)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCCRCacheCollision:
|
||||
"""测试 CCR 缓存哈希碰撞检测"""
|
||||
|
||||
def test_full_sha256_hash_length(self):
|
||||
"""_store_ccr 使用完整 SHA-256(64 字符 hex)"""
|
||||
compressor = HeadroomCompressor({})
|
||||
h = compressor._store_ccr("some content")
|
||||
assert h is not None
|
||||
assert len(h) == 64 # Full SHA-256 hex digest
|
||||
|
||||
def test_same_content_returns_same_hash(self):
|
||||
"""相同内容返回相同 hash(幂等)"""
|
||||
compressor = HeadroomCompressor({})
|
||||
h1 = compressor._store_ccr("identical content")
|
||||
h2 = compressor._store_ccr("identical content")
|
||||
assert h1 == h2
|
||||
|
||||
def test_collision_detected_returns_none(self):
|
||||
"""碰撞检测:手动注入不同内容到相同 hash 时返回 None"""
|
||||
compressor = HeadroomCompressor({})
|
||||
# 正常存储
|
||||
h1 = compressor._store_ccr("original content")
|
||||
assert h1 is not None
|
||||
# 手动修改缓存中的内容为不同值(模拟碰撞)
|
||||
# 获取内部存储的 key
|
||||
import hashlib
|
||||
collision_hash = hashlib.sha256("collision content".encode()).hexdigest()
|
||||
# 手动注入一个不同内容到同一个 hash
|
||||
compressor._ccr_cache[collision_hash] = ("different content", time.time())
|
||||
# 尝试存储 "collision content" 到已有不同内容的 hash
|
||||
result = compressor._store_ccr("collision content")
|
||||
assert result is None
|
||||
|
||||
def test_no_collision_same_content_overwrite(self):
|
||||
"""相同内容重复存储不触发碰撞(幂等更新)"""
|
||||
compressor = HeadroomCompressor({})
|
||||
h1 = compressor._store_ccr("same content")
|
||||
h2 = compressor._store_ccr("same content")
|
||||
assert h1 is not None
|
||||
assert h2 is not None
|
||||
assert h1 == h2
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
|
@ -460,3 +461,66 @@ class TestTransportLifecycle:
|
|||
result2 = await transport.send_request("method2")
|
||||
assert result2 == {"second": True}
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── StdioTransport receive_response 测试 (P0 fix) ──────────────────
|
||||
|
||||
|
||||
class TestStdioTransportReceiveResponse:
|
||||
"""测试 StdioTransport.receive_response() await 行为"""
|
||||
|
||||
async def test_awaits_empty_notification_queue(self):
|
||||
"""空队列时 receive_response 应 await 而非立即抛异常"""
|
||||
from agentkit.mcp.transport import StdioTransport
|
||||
|
||||
transport = StdioTransport(command="echo", timeout=2.0)
|
||||
# 手动设置连接状态(不实际启动子进程)
|
||||
transport._connected = True
|
||||
transport._process = MagicMock()
|
||||
transport._process.returncode = None
|
||||
|
||||
# 在后台放入一个通知来解除 await
|
||||
notification = {"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progress": 50}}
|
||||
asyncio.get_event_loop().call_later(0.1, lambda: asyncio.ensure_future(
|
||||
transport._notifications.put(notification)
|
||||
))
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
transport.receive_response(), timeout=1.0
|
||||
)
|
||||
assert result == notification
|
||||
|
||||
async def test_immediate_return_when_notification_available(self):
|
||||
"""队列中已有通知时立即返回"""
|
||||
from agentkit.mcp.transport import StdioTransport
|
||||
|
||||
transport = StdioTransport(command="echo", timeout=2.0)
|
||||
transport._connected = True
|
||||
transport._process = MagicMock()
|
||||
transport._process.returncode = None
|
||||
|
||||
notification = {"jsonrpc": "2.0", "method": "test"}
|
||||
await transport._notifications.put(notification)
|
||||
|
||||
result = await transport.receive_response()
|
||||
assert result == notification
|
||||
|
||||
async def test_timeout_raises_transport_error(self):
|
||||
"""超时时抛出 TransportError"""
|
||||
from agentkit.mcp.transport import StdioTransport, TransportError
|
||||
|
||||
transport = StdioTransport(command="echo", timeout=0.1)
|
||||
transport._connected = True
|
||||
transport._process = MagicMock()
|
||||
transport._process.returncode = None
|
||||
|
||||
with pytest.raises(TransportError, match="Timeout"):
|
||||
await transport.receive_response()
|
||||
|
||||
async def test_not_connected_raises_transport_error(self):
|
||||
"""未连接时抛出 TransportError"""
|
||||
from agentkit.mcp.transport import StdioTransport, TransportError
|
||||
|
||||
transport = StdioTransport(command="echo")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.receive_response()
|
||||
|
|
|
|||
|
|
@ -349,3 +349,77 @@ class TestReActLoopCompression:
|
|||
compressor = ContextCompressor()
|
||||
result = await compressor.compress_tool_result("search", {"key": "value"})
|
||||
assert result == "{'key': 'value'}"
|
||||
|
||||
|
||||
# ── TestOTelSpanLifecycle (P0 fix: span leak) ──────────────────
|
||||
|
||||
|
||||
class TestOTelSpanLifecycle:
|
||||
"""测试 OTel span 生命周期 — 异常时 span 必须正确关闭"""
|
||||
|
||||
async def test_span_closed_on_success(self):
|
||||
"""正常执行时 span 被正确关闭"""
|
||||
gateway = make_mock_gateway()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \
|
||||
patch("agentkit.core.react._OTEL_AVAILABLE", True):
|
||||
await engine.execute(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
# __exit__ should have been called
|
||||
mock_span_cm.__exit__.assert_called_once()
|
||||
|
||||
async def test_span_closed_on_exception(self):
|
||||
"""LLM 抛出异常时 span 仍被正确关闭"""
|
||||
gateway = make_mock_gateway()
|
||||
gateway.chat = AsyncMock(side_effect=RuntimeError("LLM error"))
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \
|
||||
patch("agentkit.core.react._OTEL_AVAILABLE", True):
|
||||
with pytest.raises(RuntimeError, match="LLM error"):
|
||||
await engine.execute(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
# __exit__ must have been called even though exception was raised
|
||||
mock_span_cm.__exit__.assert_called_once()
|
||||
|
||||
async def test_span_attributes_set_on_success(self):
|
||||
"""正常执行时 span 属性被设置"""
|
||||
gateway = make_mock_gateway()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \
|
||||
patch("agentkit.core.react._OTEL_AVAILABLE", True):
|
||||
await engine.execute(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
# Verify span attributes were set
|
||||
mock_span.set_attribute.assert_any_call("agent.total_steps", 1)
|
||||
mock_span.set_attribute.assert_any_call("agent.total_tokens", 20)
|
||||
mock_span.set_attribute.assert_any_call("agent.outcome", "success")
|
||||
|
||||
async def test_no_span_when_otel_unavailable(self):
|
||||
"""_OTEL_AVAILABLE=False 时不创建 span"""
|
||||
gateway = make_mock_gateway()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
with patch("agentkit.core.react._OTEL_AVAILABLE", False), \
|
||||
patch("agentkit.core.react.start_span") as mock_start_span:
|
||||
await engine.execute(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
# start_span should not be called when OTel is unavailable
|
||||
mock_start_span.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -516,10 +516,13 @@ class TestStdioTransportNotifications:
|
|||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_no_notification_raises(self):
|
||||
"""空通知队列时 receive_response 超时抛出 TransportError"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
with pytest.raises(TransportError, match="No notification"):
|
||||
# 临时缩短 receive_response 超时
|
||||
transport._timeout = 0.1
|
||||
with pytest.raises(TransportError, match="Timeout"):
|
||||
await transport.receive_response()
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
|
|
|||
Loading…
Reference in New Issue