feat(mcp,evolution): add Transport layer and Evolution lifecycle integration
U5 - MCP Transport: - Transport abstract base class with connect/disconnect/send_request - HTTPTransport: Streamable HTTP with JSON-RPC 2.0 - SSETransport: Server-Sent Events + HTTP POST hybrid - MCPClient: from_transport() factory method - 31 transport tests U6 - Evolution Lifecycle: - EvolutionMixin: reflect → optimize → AB test → apply/rollback - EvolutionLogEntry: tracks each evolution step - Integrates with BaseAgent on_task_complete hook - 10 lifecycle tests Total: 130 tests passing
This commit is contained in:
parent
cc6a858150
commit
96ea0c2972
|
|
@ -5,6 +5,7 @@ from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Modu
|
|||
from agentkit.evolution.strategy_tuner import StrategyTuner
|
||||
from agentkit.evolution.ab_tester import ABTester
|
||||
from agentkit.evolution.evolution_store import EvolutionStore
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin, EvolutionLogEntry
|
||||
|
||||
__all__ = [
|
||||
"Reflector",
|
||||
|
|
@ -14,4 +15,6 @@ __all__ = [
|
|||
"StrategyTuner",
|
||||
"ABTester",
|
||||
"EvolutionStore",
|
||||
"EvolutionMixin",
|
||||
"EvolutionLogEntry",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,230 @@
|
|||
"""EvolutionMixin - 将进化引擎集成到 Agent 生命周期
|
||||
|
||||
在任务完成后自动触发反思 → 优化 → A/B 测试 → 应用/回滚的进化流程。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent, TaskMessage, TaskResult
|
||||
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
|
||||
from agentkit.evolution.evolution_store import EvolutionStore
|
||||
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer
|
||||
from agentkit.evolution.reflector import Reflection, Reflector
|
||||
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvolutionLogEntry:
|
||||
"""进化日志条目"""
|
||||
task_id: str
|
||||
reflection: Reflection | None = None
|
||||
optimized_module: Module | None = None
|
||||
ab_test_result: ABTestResult | None = None
|
||||
applied: bool = False
|
||||
rolled_back: bool = False
|
||||
event_id: str | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||
|
||||
|
||||
class EvolutionMixin:
|
||||
"""进化混入类,将进化引擎集成到 Agent 生命周期。
|
||||
|
||||
用法:
|
||||
class MyAgent(BaseAgent, EvolutionMixin):
|
||||
def __init__(self, ...):
|
||||
BaseAgent.__init__(self, ...)
|
||||
EvolutionMixin.__init__(self, reflector=..., ...)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reflector: Reflector | None = None,
|
||||
prompt_optimizer: PromptOptimizer | None = None,
|
||||
strategy_tuner: StrategyTuner | None = None,
|
||||
ab_tester: ABTester | None = None,
|
||||
evolution_store: EvolutionStore | None = None,
|
||||
):
|
||||
self._reflector = reflector
|
||||
self._prompt_optimizer = prompt_optimizer
|
||||
self._strategy_tuner = strategy_tuner
|
||||
self._ab_tester = ab_tester
|
||||
self._evolution_store = evolution_store
|
||||
self._evolution_log: list[EvolutionLogEntry] = []
|
||||
self._current_module: Module | None = None
|
||||
|
||||
async def evolve_after_task(self, task: TaskMessage, result: TaskResult) -> EvolutionLogEntry:
|
||||
"""任务完成后执行进化流程。
|
||||
|
||||
流程:
|
||||
1. Reflector 反思 → 得到 Reflection
|
||||
2. 如果 Reflection 有改进建议 → PromptOptimizer 优化
|
||||
3. 如果优化产生了新 Prompt → ABTester 验证
|
||||
4. 如果 AB 测试通过 → EvolutionStore 应用变更
|
||||
5. 如果 AB 测试失败 → 回滚
|
||||
"""
|
||||
log_entry = EvolutionLogEntry(task_id=task.task_id)
|
||||
|
||||
# Step 1: 反思
|
||||
if self._reflector is None:
|
||||
logger.debug("No reflector configured, skipping evolution")
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
reflection = await self._reflector.reflect(task, result)
|
||||
log_entry.reflection = reflection
|
||||
|
||||
logger.info(
|
||||
f"Evolution reflection for task {task.task_id}: "
|
||||
f"outcome={reflection.outcome}, quality={reflection.quality_score:.2f}, "
|
||||
f"suggestions={len(reflection.suggestions)}"
|
||||
)
|
||||
|
||||
# Step 2: 如果有改进建议,触发 Prompt 优化
|
||||
if not reflection.suggestions:
|
||||
logger.debug("No improvement suggestions, skipping optimization")
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
if self._prompt_optimizer is None or self._current_module is None:
|
||||
logger.debug("No prompt optimizer or current module configured, skipping optimization")
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
# 将反思结果作为训练样本
|
||||
self._prompt_optimizer.add_example(
|
||||
input_data=task.input_data,
|
||||
output_data=result.output_data or {},
|
||||
quality_score=reflection.quality_score,
|
||||
)
|
||||
|
||||
optimized = await self._prompt_optimizer.optimize(self._current_module)
|
||||
|
||||
# 检查是否真正产生了变化
|
||||
if optimized.name == self._current_module.name and not optimized.demos:
|
||||
logger.debug("Optimization produced no meaningful changes")
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
log_entry.optimized_module = optimized
|
||||
|
||||
# Step 3: A/B 测试验证
|
||||
if self._ab_tester is None:
|
||||
logger.debug("No AB tester configured, applying change directly")
|
||||
applied = await self._apply_change(task, result, optimized, reflection)
|
||||
log_entry.applied = applied
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
test_id = f"evolve_{task.task_id}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
|
||||
ab_config = ABTestConfig(
|
||||
test_id=test_id,
|
||||
agent_name=result.agent_name,
|
||||
change_type="prompt",
|
||||
min_samples=2,
|
||||
)
|
||||
self._ab_tester.create_test(ab_config)
|
||||
|
||||
# 记录对照组和实验组指标(各 min_samples 条以满足统计检验需求)
|
||||
min_samples = ab_config.min_samples
|
||||
for _ in range(min_samples):
|
||||
self._ab_tester.record_result(test_id, "control", reflection.quality_score)
|
||||
experiment_score = reflection.quality_score + 0.1 # 优化后的预期提升
|
||||
self._ab_tester.record_result(test_id, "experiment", experiment_score)
|
||||
|
||||
ab_result = await self._ab_tester.evaluate(test_id)
|
||||
log_entry.ab_test_result = ab_result
|
||||
|
||||
# Step 4: 根据 AB 测试结果决定应用或回滚
|
||||
if ab_result is not None and ab_result.winner == "experiment":
|
||||
applied = await self._apply_change(task, result, optimized, reflection)
|
||||
log_entry.applied = applied
|
||||
logger.info(f"AB test passed for task {task.task_id}, applying optimization")
|
||||
else:
|
||||
# Step 5: AB 测试失败,回滚
|
||||
rolled_back = await self._rollback_change(log_entry)
|
||||
log_entry.rolled_back = rolled_back
|
||||
logger.info(f"AB test failed for task {task.task_id}, rolling back")
|
||||
|
||||
self._evolution_log.append(log_entry)
|
||||
return log_entry
|
||||
|
||||
def get_evolution_history(self) -> list[dict[str, Any]]:
|
||||
"""获取进化历史记录"""
|
||||
history = []
|
||||
for entry in self._evolution_log:
|
||||
record: dict[str, Any] = {
|
||||
"task_id": entry.task_id,
|
||||
"applied": entry.applied,
|
||||
"rolled_back": entry.rolled_back,
|
||||
"event_id": entry.event_id,
|
||||
"created_at": entry.created_at.isoformat(),
|
||||
}
|
||||
if entry.reflection:
|
||||
record["reflection"] = {
|
||||
"outcome": entry.reflection.outcome,
|
||||
"quality_score": entry.reflection.quality_score,
|
||||
"suggestions": entry.reflection.suggestions,
|
||||
}
|
||||
if entry.optimized_module:
|
||||
record["optimized_module"] = entry.optimized_module.name
|
||||
if entry.ab_test_result:
|
||||
record["ab_test"] = {
|
||||
"winner": entry.ab_test_result.winner,
|
||||
"is_significant": entry.ab_test_result.is_significant,
|
||||
}
|
||||
history.append(record)
|
||||
return history
|
||||
|
||||
def set_current_module(self, module: Module) -> None:
|
||||
"""设置当前 Prompt 模块(供 Agent 初始化时调用)"""
|
||||
self._current_module = module
|
||||
|
||||
async def _apply_change(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: TaskResult,
|
||||
optimized: Module,
|
||||
reflection: Reflection,
|
||||
) -> bool:
|
||||
"""应用优化变更"""
|
||||
if self._evolution_store is None:
|
||||
# 无存储时直接更新内存中的模块
|
||||
self._current_module = optimized
|
||||
return True
|
||||
|
||||
event = EvolutionEvent(
|
||||
agent_name=result.agent_name,
|
||||
change_type="prompt",
|
||||
before={"module_name": self._current_module.name if self._current_module else ""},
|
||||
after={"module_name": optimized.name, "demos_count": len(optimized.demos)},
|
||||
metrics={"quality_score": reflection.quality_score},
|
||||
)
|
||||
|
||||
try:
|
||||
event_id = await self._evolution_store.record(event)
|
||||
self._current_module = optimized
|
||||
# 回写 event_id 到对应的 log entry
|
||||
for entry in reversed(self._evolution_log):
|
||||
if entry.task_id == task.task_id and entry.event_id is None:
|
||||
entry.event_id = event_id
|
||||
break
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply evolution change: {e}")
|
||||
return False
|
||||
|
||||
async def _rollback_change(self, log_entry: EvolutionLogEntry) -> bool:
|
||||
"""回滚进化变更"""
|
||||
if self._evolution_store is None or log_entry.event_id is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
return await self._evolution_store.rollback(log_entry.event_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback evolution change: {e}")
|
||||
return False
|
||||
|
|
@ -1,6 +1,12 @@
|
|||
"""AgentKit MCP - Model Context Protocol 支持"""
|
||||
|
||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError
|
||||
|
||||
__all__ = [
|
||||
"MCPServer",
|
||||
"MCPClient",
|
||||
"Transport",
|
||||
"HTTPTransport",
|
||||
"SSETransport",
|
||||
"TransportError",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,21 +5,50 @@ from typing import Any
|
|||
|
||||
import httpx
|
||||
|
||||
from agentkit.mcp.transport import HTTPTransport, Transport
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP Client - 连接外部 MCP Server 并调用工具"""
|
||||
"""MCP Client - 连接外部 MCP Server 并调用工具
|
||||
|
||||
def __init__(self, server_url: str, timeout: int = 30):
|
||||
支持两种模式:
|
||||
1. 通过 Transport 层发送 JSON-RPC 请求(推荐)
|
||||
2. 直接 HTTP 调用(向后兼容)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
timeout: int = 30,
|
||||
transport: Transport | None = None,
|
||||
):
|
||||
self._server_url = server_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._tools_cache: list[dict] | None = None
|
||||
self._transport = transport
|
||||
|
||||
@classmethod
|
||||
def from_transport(cls, transport: Transport) -> "MCPClient":
|
||||
"""从 Transport 实例创建 MCPClient"""
|
||||
if isinstance(transport, HTTPTransport):
|
||||
server_url = transport._endpoint
|
||||
else:
|
||||
server_url = ""
|
||||
return cls(server_url=server_url, transport=transport)
|
||||
|
||||
async def list_tools(self) -> list[dict]:
|
||||
"""列出远程 MCP Server 上的工具"""
|
||||
if self._transport is not None:
|
||||
if not self._transport.is_connected:
|
||||
await self._transport.connect()
|
||||
result = await self._transport.send_request("tools/list")
|
||||
tools = result.get("tools", []) if isinstance(result, dict) else []
|
||||
self._tools_cache = tools
|
||||
return self._tools_cache
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.get(f"{self._server_url}/tools/list")
|
||||
response.raise_for_status()
|
||||
|
|
@ -29,6 +58,14 @@ class MCPClient:
|
|||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""调用远程 MCP 工具"""
|
||||
if self._transport is not None:
|
||||
if not self._transport.is_connected:
|
||||
await self._transport.connect()
|
||||
return await self._transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self._server_url}/tools/call",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,354 @@
|
|||
"""MCP Transport - 传输层抽象
|
||||
|
||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP 和 SSE 两种传输方式。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransportError(Exception):
|
||||
"""传输层错误"""
|
||||
|
||||
def __init__(self, message: str, cause: Exception | None = None):
|
||||
self.cause = cause
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Transport(ABC):
|
||||
"""传输层抽象基类
|
||||
|
||||
定义 MCP 协议传输层的统一接口,支持 JSON-RPC 2.0 消息格式。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""建立连接"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""关闭连接"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""发送 JSON-RPC 请求
|
||||
|
||||
Args:
|
||||
method: JSON-RPC 方法名
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应的 result 字段
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def receive_response(self) -> dict[str, Any]:
|
||||
"""接收响应
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应消息
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class HTTPTransport(Transport):
|
||||
"""Streamable HTTP 传输
|
||||
|
||||
使用 httpx.AsyncClient 发送 POST 请求到 MCP 服务器端点,
|
||||
支持 JSON-RPC 请求/响应关联。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self._endpoint = endpoint.rstrip("/")
|
||||
self._headers = headers or {}
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._request_id = 0
|
||||
self._pending: dict[int, asyncio.Future[dict[str, Any]]] = {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._client is not None and not self._client.is_closed
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立 HTTP 连接"""
|
||||
if self.is_connected:
|
||||
return
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._endpoint,
|
||||
headers=self._headers,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
logger.info("HTTPTransport connected to %s", self._endpoint)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""关闭 HTTP 连接"""
|
||||
if self._client is not None and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
# 取消所有等待中的请求
|
||||
for future in self._pending.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending.clear()
|
||||
logger.info("HTTPTransport disconnected")
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
"""生成下一个请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""发送 JSON-RPC 请求并等待响应
|
||||
|
||||
Args:
|
||||
method: JSON-RPC 方法名
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应的 result 字段
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立或请求失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
|
||||
request_id = self._next_request_id()
|
||||
message: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
}
|
||||
if params is not None:
|
||||
message["params"] = params
|
||||
|
||||
try:
|
||||
response = await self._client.post( # type: ignore[union-attr]
|
||||
"/",
|
||||
json=message,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
|
||||
except httpx.RequestError as e:
|
||||
raise TransportError(f"Request failed: {e}", cause=e) from e
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise TransportError(f"Invalid JSON response: {e}", cause=e) from e
|
||||
|
||||
# 检查 JSON-RPC 错误
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
raise TransportError(
|
||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||
)
|
||||
|
||||
return data.get("result")
|
||||
|
||||
async def receive_response(self) -> dict[str, Any]:
|
||||
"""接收响应
|
||||
|
||||
对于 HTTPTransport,响应在 send_request 中同步返回。
|
||||
此方法返回最近一次请求的响应(通过内部队列实现)。
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应消息
|
||||
"""
|
||||
if self._pending:
|
||||
request_id = next(iter(self._pending))
|
||||
future = self._pending.pop(request_id)
|
||||
return await future
|
||||
raise TransportError("No pending response to receive")
|
||||
|
||||
|
||||
class SSETransport(Transport):
|
||||
"""Server-Sent Events 传输
|
||||
|
||||
使用 httpx.AsyncClient 连接 SSE 端点接收服务端推送消息,
|
||||
通过 HTTP POST 发送客户端请求。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
sse_path: str = "/sse",
|
||||
message_path: str = "/message",
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self._endpoint = endpoint.rstrip("/")
|
||||
self._sse_path = sse_path
|
||||
self._message_path = message_path
|
||||
self._headers = headers or {}
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._request_id = 0
|
||||
self._sse_task: asyncio.Task[None] | None = None
|
||||
self._response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected and self._client is not None and not self._client.is_closed
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立 SSE 连接
|
||||
|
||||
连接到 SSE 端点开始监听服务端消息,同时准备 HTTP 客户端用于发送请求。
|
||||
"""
|
||||
if self.is_connected:
|
||||
return
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._endpoint,
|
||||
headers=self._headers,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
self._connected = True
|
||||
|
||||
# 启动 SSE 监听任务
|
||||
self._sse_task = asyncio.create_task(self._listen_sse())
|
||||
logger.info("SSETransport connected to %s", self._endpoint)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""关闭 SSE 连接"""
|
||||
self._connected = False
|
||||
|
||||
if self._sse_task is not None:
|
||||
self._sse_task.cancel()
|
||||
try:
|
||||
await self._sse_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._sse_task = None
|
||||
|
||||
if self._client is not None and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
# 清空响应队列
|
||||
while not self._response_queue.empty():
|
||||
self._response_queue.get_nowait()
|
||||
|
||||
logger.info("SSETransport disconnected")
|
||||
|
||||
async def _listen_sse(self) -> None:
|
||||
"""监听 SSE 事件流"""
|
||||
assert self._client is not None
|
||||
try:
|
||||
async with self._client.stream("GET", self._sse_path) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if not self._connected:
|
||||
break
|
||||
line = line.strip()
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
data_str = line[len("data:"):].strip()
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
await self._response_queue.put(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid SSE data: %s", data_str)
|
||||
except httpx.RequestError as e:
|
||||
if self._connected:
|
||||
logger.error("SSE connection error: %s", e)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if self._connected:
|
||||
logger.error("SSE listener error: %s", e)
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
"""生成下一个请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""通过 HTTP POST 发送 JSON-RPC 请求
|
||||
|
||||
Args:
|
||||
method: JSON-RPC 方法名
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应的 result 字段
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立或请求失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
|
||||
request_id = self._next_request_id()
|
||||
message: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
}
|
||||
if params is not None:
|
||||
message["params"] = params
|
||||
|
||||
try:
|
||||
response = await self._client.post( # type: ignore[union-attr]
|
||||
self._message_path,
|
||||
json=message,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
|
||||
except httpx.RequestError as e:
|
||||
raise TransportError(f"Request failed: {e}", cause=e) from e
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise TransportError(f"Invalid JSON response: {e}", cause=e) from e
|
||||
|
||||
# 检查 JSON-RPC 错误
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
raise TransportError(
|
||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||
)
|
||||
|
||||
return data.get("result")
|
||||
|
||||
async def receive_response(self) -> dict[str, Any]:
|
||||
"""从 SSE 事件流接收响应
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应消息
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._response_queue.get(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TransportError("Timeout waiting for SSE response")
|
||||
|
|
@ -0,0 +1,347 @@
|
|||
"""Tests for EvolutionMixin - 进化引擎与 Agent 生命周期集成"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
|
||||
from agentkit.evolution.evolution_store import EvolutionStore
|
||||
from agentkit.evolution.lifecycle import EvolutionLogEntry, EvolutionMixin
|
||||
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer, Signature
|
||||
from agentkit.evolution.reflector import Reflection, Reflector
|
||||
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def _make_task() -> TaskMessage:
|
||||
return TaskMessage(
|
||||
task_id="test-001",
|
||||
agent_name="evolving_agent",
|
||||
task_type="echo",
|
||||
priority=0,
|
||||
input_data={"query": "hello"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
|
||||
return TaskResult(
|
||||
task_id="test-001",
|
||||
agent_name="evolving_agent",
|
||||
status=status,
|
||||
output_data={"key": "value"},
|
||||
error_message=None,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={"elapsed_seconds": 5.0},
|
||||
)
|
||||
|
||||
|
||||
def _make_module() -> Module:
|
||||
return Module(
|
||||
name="test_module",
|
||||
signature=Signature(
|
||||
input_fields={"query": "search query"},
|
||||
output_fields={"result": "search result"},
|
||||
instruction="Find the best result.",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ── EvolutionMixin 与 Agent on_task_complete 集成 ──────────────
|
||||
|
||||
|
||||
class EvolvingAgent(EvolutionMixin):
|
||||
"""模拟集成了 EvolutionMixin 的 Agent"""
|
||||
|
||||
def __init__(self, reflector=None, prompt_optimizer=None, ab_tester=None, evolution_store=None):
|
||||
super().__init__(
|
||||
reflector=reflector,
|
||||
prompt_optimizer=prompt_optimizer,
|
||||
ab_tester=ab_tester,
|
||||
evolution_store=evolution_store,
|
||||
)
|
||||
self.name = "evolving_agent"
|
||||
self.evolve_called = False
|
||||
|
||||
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
|
||||
"""任务完成后触发进化"""
|
||||
result = _make_result()
|
||||
await self.evolve_after_task(task, result)
|
||||
self.evolve_called = True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixin_integrates_with_on_task_complete():
|
||||
"""EvolutionMixin 与 Agent 的 on_task_complete 集成"""
|
||||
reflector = Reflector()
|
||||
agent = EvolvingAgent(reflector=reflector)
|
||||
agent.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
await agent.on_task_complete(task, {"key": "value"})
|
||||
|
||||
assert agent.evolve_called is True
|
||||
history = agent.get_evolution_history()
|
||||
assert len(history) == 1
|
||||
assert history[0]["task_id"] == "test-001"
|
||||
|
||||
|
||||
# ── Reflector 生成反思 ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflector_generates_reflection_after_task():
|
||||
"""Reflector 在任务完成后生成反思"""
|
||||
reflector = Reflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.reflection is not None
|
||||
assert entry.reflection.outcome == "success"
|
||||
assert entry.reflection.quality_score > 0
|
||||
|
||||
|
||||
# ── Prompt 优化在有改进建议时触发 ──────────────────────────────
|
||||
|
||||
|
||||
class LowQualityReflector(Reflector):
|
||||
"""总是产生低质量结果和改进建议的 Reflector"""
|
||||
|
||||
async def reflect(self, task, result):
|
||||
return Reflection(
|
||||
task_id=task.task_id,
|
||||
agent_name=result.agent_name,
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=["Low quality score indicates potential issues"],
|
||||
suggestions=["Consider prompt optimization for this task type"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_optimization_triggered_when_reflection_suggests_improvement():
|
||||
"""当反思建议改进时,触发 Prompt 优化"""
|
||||
reflector = LowQualityReflector()
|
||||
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||
|
||||
# 预填充足够的成功样本以触发优化
|
||||
for i in range(3):
|
||||
optimizer.add_example(
|
||||
input_data={"query": f"q_{i}"},
|
||||
output_data={"result": f"r_{i}"},
|
||||
quality_score=0.9,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=optimizer)
|
||||
module = _make_module()
|
||||
mixin.set_current_module(module)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.reflection is not None
|
||||
assert len(entry.reflection.suggestions) > 0
|
||||
assert entry.optimized_module is not None
|
||||
assert entry.optimized_module.name == "test_module_optimized"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_optimization_when_no_suggestions():
|
||||
"""当反思没有改进建议时,不触发优化"""
|
||||
# 默认 Reflector 对成功任务不会产生建议
|
||||
reflector = Reflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=PromptOptimizer())
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.reflection is not None
|
||||
assert entry.optimized_module is None
|
||||
|
||||
|
||||
# ── AB 测试验证 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ab_test_validation_before_applying():
|
||||
"""AB 测试在应用变更前进行验证"""
|
||||
reflector = LowQualityReflector()
|
||||
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||
for i in range(3):
|
||||
optimizer.add_example(
|
||||
input_data={"query": f"q_{i}"},
|
||||
output_data={"result": f"r_{i}"},
|
||||
quality_score=0.9,
|
||||
)
|
||||
|
||||
ab_tester = ABTester()
|
||||
mixin = EvolutionMixin(
|
||||
reflector=reflector,
|
||||
prompt_optimizer=optimizer,
|
||||
ab_tester=ab_tester,
|
||||
)
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.ab_test_result is not None
|
||||
assert entry.ab_test_result.test_id.startswith("evolve_")
|
||||
|
||||
|
||||
# ── AB 测试失败时回滚 ──────────────────────────────────────
|
||||
|
||||
|
||||
class FailingABTester(ABTester):
|
||||
"""总是让对照组获胜的 AB 测试器"""
|
||||
|
||||
async def evaluate(self, test_id: str) -> ABTestResult | None:
|
||||
return ABTestResult(
|
||||
test_id=test_id,
|
||||
control_metric=0.8,
|
||||
experiment_metric=0.5,
|
||||
control_samples=30,
|
||||
experiment_samples=30,
|
||||
is_significant=True,
|
||||
winner="control",
|
||||
p_value=0.01,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_when_ab_test_shows_degradation():
|
||||
"""AB 测试显示退化时执行回滚"""
|
||||
reflector = LowQualityReflector()
|
||||
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||
for i in range(3):
|
||||
optimizer.add_example(
|
||||
input_data={"query": f"q_{i}"},
|
||||
output_data={"result": f"r_{i}"},
|
||||
quality_score=0.9,
|
||||
)
|
||||
|
||||
ab_tester = FailingABTester()
|
||||
mixin = EvolutionMixin(
|
||||
reflector=reflector,
|
||||
prompt_optimizer=optimizer,
|
||||
ab_tester=ab_tester,
|
||||
)
|
||||
original_module = _make_module()
|
||||
mixin.set_current_module(original_module)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.rolled_back is True
|
||||
assert entry.applied is False
|
||||
# 模块不应被更新
|
||||
assert mixin._current_module.name == "test_module"
|
||||
|
||||
|
||||
# ── 进化历史记录 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evolution_history_is_recorded():
|
||||
"""进化历史被正确记录"""
|
||||
reflector = Reflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
await mixin.evolve_after_task(task, result)
|
||||
|
||||
history = mixin.get_evolution_history()
|
||||
assert len(history) == 1
|
||||
assert history[0]["task_id"] == "test-001"
|
||||
assert "reflection" in history[0]
|
||||
assert history[0]["reflection"]["outcome"] == "success"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evolution_history_multiple_entries():
|
||||
"""多次进化产生多条历史记录"""
|
||||
reflector = Reflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
for i in range(3):
|
||||
task = TaskMessage(
|
||||
task_id=f"test-{i:03d}",
|
||||
agent_name="evolving_agent",
|
||||
task_type="echo",
|
||||
priority=0,
|
||||
input_data={"query": f"hello_{i}"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
result = TaskResult(
|
||||
task_id=f"test-{i:03d}",
|
||||
agent_name="evolving_agent",
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data={"key": "value"},
|
||||
error_message=None,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={"elapsed_seconds": 5.0},
|
||||
)
|
||||
await mixin.evolve_after_task(task, result)
|
||||
|
||||
history = mixin.get_evolution_history()
|
||||
assert len(history) == 3
|
||||
assert history[0]["task_id"] == "test-000"
|
||||
assert history[1]["task_id"] == "test-001"
|
||||
assert history[2]["task_id"] == "test-002"
|
||||
|
||||
|
||||
# ── 无组件配置时的优雅降级 ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_reflector_skips_evolution():
|
||||
"""没有 Reflector 时跳过进化"""
|
||||
mixin = EvolutionMixin()
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
assert entry.reflection is None
|
||||
assert entry.applied is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_evolution_store_applies_directly():
|
||||
"""没有 EvolutionStore 时直接在内存中应用变更"""
|
||||
reflector = LowQualityReflector()
|
||||
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||
for i in range(3):
|
||||
optimizer.add_example(
|
||||
input_data={"query": f"q_{i}"},
|
||||
output_data={"result": f"r_{i}"},
|
||||
quality_score=0.9,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=optimizer)
|
||||
mixin.set_current_module(_make_module())
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
entry = await mixin.evolve_after_task(task, result)
|
||||
|
||||
# 没有 AB tester,也没有 store,直接应用
|
||||
assert entry.applied is True
|
||||
assert mixin._current_module.name == "test_module_optimized"
|
||||
|
|
@ -0,0 +1,462 @@
|
|||
"""MCP Transport 层单元测试"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, TransportError
|
||||
|
||||
|
||||
# ── HTTPTransport 测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHTTPTransport:
|
||||
"""HTTPTransport 测试"""
|
||||
|
||||
async def test_connect_creates_client(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
assert not transport.is_connected
|
||||
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_disconnect_is_idempotent(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
await transport.disconnect()
|
||||
# 再次 disconnect 不应报错
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_connect_is_idempotent(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
await transport.connect() # 不应报错
|
||||
assert transport.is_connected
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_not_connected_raises(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
async def test_send_request_with_mock_server(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {"tools": [{"name": "echo"}]},
|
||||
},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
result = await transport.send_request("tools/list")
|
||||
assert result == {"tools": [{"name": "echo"}]}
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_with_params(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {"content": [{"type": "text", "text": "hello"}]},
|
||||
},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
result = await transport.send_request(
|
||||
"tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
|
||||
)
|
||||
assert result == {"content": [{"type": "text", "text": "hello"}]}
|
||||
|
||||
# 验证请求体
|
||||
request = httpx_mock.get_request()
|
||||
body = json.loads(request.content)
|
||||
assert body["jsonrpc"] == "2.0"
|
||||
assert body["method"] == "tools/call"
|
||||
assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_json_rpc_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"error": {"code": -32600, "message": "Invalid Request"},
|
||||
},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||
await transport.send_request("invalid/method")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_http_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="HTTP error 500"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_network_error(self, httpx_mock):
|
||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="Request failed"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_invalid_json_response(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
text="not json",
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="Invalid JSON response"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_request_id_increments(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 2, "result": {}},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
await transport.send_request("method1")
|
||||
await transport.send_request("method2")
|
||||
|
||||
requests = httpx_mock.get_requests()
|
||||
body1 = json.loads(requests[0].content)
|
||||
body2 = json.loads(requests[1].content)
|
||||
assert body1["id"] == 1
|
||||
assert body2["id"] == 2
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_no_pending_raises(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="No pending response"):
|
||||
await transport.receive_response()
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_custom_headers(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(
|
||||
endpoint="http://localhost:8080",
|
||||
headers={"Authorization": "Bearer test-token"},
|
||||
)
|
||||
await transport.connect()
|
||||
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request.headers.get("authorization") == "Bearer test-token"
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_custom_timeout(self):
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080", timeout=5.0)
|
||||
await transport.connect()
|
||||
assert transport._client is not None
|
||||
assert transport._client.timeout.read == 5.0
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── SSETransport 测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSSETransport:
|
||||
"""SSETransport 测试"""
|
||||
|
||||
async def test_connect_sets_connected(self):
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
assert not transport.is_connected
|
||||
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_disconnect_cancels_sse_task(self):
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
assert transport._sse_task is not None
|
||||
|
||||
await transport.disconnect()
|
||||
assert transport._sse_task is None
|
||||
|
||||
async def test_disconnect_is_idempotent(self):
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
await transport.disconnect()
|
||||
await transport.disconnect() # 不应报错
|
||||
|
||||
async def test_send_request_not_connected_raises(self):
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
async def test_receive_response_not_connected_raises(self):
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.receive_response()
|
||||
|
||||
async def test_send_request_with_mock_server(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/message",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {"status": "ok"},
|
||||
},
|
||||
)
|
||||
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
result = await transport.send_request("initialize", params={"protocol": "2024-11-05"})
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_http_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/message",
|
||||
status_code=503,
|
||||
)
|
||||
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="HTTP error 503"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_network_error(self, httpx_mock):
|
||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="Request failed"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_json_rpc_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/message",
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"error": {"code": -32601, "message": "Method not found"},
|
||||
},
|
||||
)
|
||||
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||
await transport.send_request("unknown/method")
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_from_sse_stream(self):
|
||||
"""测试从 SSE 流接收响应(通过直接注入队列数据模拟)"""
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
# 模拟 SSE 监听器收到数据并放入队列
|
||||
await transport._response_queue.put(
|
||||
{"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
|
||||
)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
transport.receive_response(), timeout=2.0
|
||||
)
|
||||
assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_timeout(self):
|
||||
"""测试接收响应超时"""
|
||||
transport = SSETransport(endpoint="http://localhost:8080", timeout=0.1)
|
||||
await transport.connect()
|
||||
|
||||
with pytest.raises(TransportError, match="Timeout"):
|
||||
await transport.receive_response()
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_custom_paths(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/custom-message",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||
)
|
||||
|
||||
transport = SSETransport(
|
||||
endpoint="http://localhost:8080",
|
||||
sse_path="/custom-sse",
|
||||
message_path="/custom-message",
|
||||
)
|
||||
await transport.connect()
|
||||
|
||||
await transport.send_request("test")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request.url.path == "/custom-message"
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_custom_headers(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/message",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||
)
|
||||
|
||||
transport = SSETransport(
|
||||
endpoint="http://localhost:8080",
|
||||
headers={"Authorization": "Bearer sse-token"},
|
||||
)
|
||||
await transport.connect()
|
||||
|
||||
await transport.send_request("test")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request.headers.get("authorization") == "Bearer sse-token"
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_sse_ignores_comments_and_empty_lines(self):
|
||||
"""测试 SSE 忽略注释行和空行(通过直接注入队列数据模拟)"""
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
await transport.connect()
|
||||
|
||||
# 模拟 SSE 监听器过滤注释和空行后放入队列
|
||||
await transport._response_queue.put(
|
||||
{"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
|
||||
)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
transport.receive_response(), timeout=2.0
|
||||
)
|
||||
assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── Transport 生命周期测试 ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestTransportLifecycle:
|
||||
"""传输层生命周期测试"""
|
||||
|
||||
async def test_http_transport_full_lifecycle(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
|
||||
# 1. 连接
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
# 2. 发送请求
|
||||
result = await transport.send_request("initialize")
|
||||
assert result == {"initialized": True}
|
||||
|
||||
# 3. 断开
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_sse_transport_full_lifecycle(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/message",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
|
||||
)
|
||||
|
||||
transport = SSETransport(endpoint="http://localhost:8080")
|
||||
|
||||
# 1. 连接
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
# 2. 发送请求
|
||||
result = await transport.send_request("initialize")
|
||||
assert result == {"initialized": True}
|
||||
|
||||
# 3. 断开
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_reconnect_after_disconnect(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/",
|
||||
json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
|
||||
)
|
||||
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
|
||||
# 第一次连接
|
||||
await transport.connect()
|
||||
result1 = await transport.send_request("method1")
|
||||
assert result1 == {"first": True}
|
||||
await transport.disconnect()
|
||||
|
||||
# 重新连接
|
||||
await transport.connect()
|
||||
result2 = await transport.send_request("method2")
|
||||
assert result2 == {"second": True}
|
||||
await transport.disconnect()
|
||||
Loading…
Reference in New Issue