feat(mcp,evolution): add Transport layer and Evolution lifecycle integration
U5 - MCP Transport: - Transport abstract base class with connect/disconnect/send_request - HTTPTransport: Streamable HTTP with JSON-RPC 2.0 - SSETransport: Server-Sent Events + HTTP POST hybrid - MCPClient: from_transport() factory method - 31 transport tests U6 - Evolution Lifecycle: - EvolutionMixin: reflect → optimize → AB test → apply/rollback - EvolutionLogEntry: tracks each evolution step - Integrates with BaseAgent on_task_complete hook - 10 lifecycle tests Total: 130 tests passing
This commit is contained in:
parent
cc6a858150
commit
96ea0c2972
|
|
@ -5,6 +5,7 @@ from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Modu
|
||||||
from agentkit.evolution.strategy_tuner import StrategyTuner
|
from agentkit.evolution.strategy_tuner import StrategyTuner
|
||||||
from agentkit.evolution.ab_tester import ABTester
|
from agentkit.evolution.ab_tester import ABTester
|
||||||
from agentkit.evolution.evolution_store import EvolutionStore
|
from agentkit.evolution.evolution_store import EvolutionStore
|
||||||
|
from agentkit.evolution.lifecycle import EvolutionMixin, EvolutionLogEntry
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Reflector",
|
"Reflector",
|
||||||
|
|
@ -14,4 +15,6 @@ __all__ = [
|
||||||
"StrategyTuner",
|
"StrategyTuner",
|
||||||
"ABTester",
|
"ABTester",
|
||||||
"EvolutionStore",
|
"EvolutionStore",
|
||||||
|
"EvolutionMixin",
|
||||||
|
"EvolutionLogEntry",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,230 @@
|
||||||
|
"""EvolutionMixin - 将进化引擎集成到 Agent 生命周期
|
||||||
|
|
||||||
|
在任务完成后自动触发反思 → 优化 → A/B 测试 → 应用/回滚的进化流程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import EvolutionEvent, TaskMessage, TaskResult
|
||||||
|
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
|
||||||
|
from agentkit.evolution.evolution_store import EvolutionStore
|
||||||
|
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer
|
||||||
|
from agentkit.evolution.reflector import Reflection, Reflector
|
||||||
|
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvolutionLogEntry:
|
||||||
|
"""进化日志条目"""
|
||||||
|
task_id: str
|
||||||
|
reflection: Reflection | None = None
|
||||||
|
optimized_module: Module | None = None
|
||||||
|
ab_test_result: ABTestResult | None = None
|
||||||
|
applied: bool = False
|
||||||
|
rolled_back: bool = False
|
||||||
|
event_id: str | None = None
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||||
|
|
||||||
|
|
||||||
|
class EvolutionMixin:
|
||||||
|
"""进化混入类,将进化引擎集成到 Agent 生命周期。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
class MyAgent(BaseAgent, EvolutionMixin):
|
||||||
|
def __init__(self, ...):
|
||||||
|
BaseAgent.__init__(self, ...)
|
||||||
|
EvolutionMixin.__init__(self, reflector=..., ...)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reflector: Reflector | None = None,
|
||||||
|
prompt_optimizer: PromptOptimizer | None = None,
|
||||||
|
strategy_tuner: StrategyTuner | None = None,
|
||||||
|
ab_tester: ABTester | None = None,
|
||||||
|
evolution_store: EvolutionStore | None = None,
|
||||||
|
):
|
||||||
|
self._reflector = reflector
|
||||||
|
self._prompt_optimizer = prompt_optimizer
|
||||||
|
self._strategy_tuner = strategy_tuner
|
||||||
|
self._ab_tester = ab_tester
|
||||||
|
self._evolution_store = evolution_store
|
||||||
|
self._evolution_log: list[EvolutionLogEntry] = []
|
||||||
|
self._current_module: Module | None = None
|
||||||
|
|
||||||
|
async def evolve_after_task(self, task: TaskMessage, result: TaskResult) -> EvolutionLogEntry:
|
||||||
|
"""任务完成后执行进化流程。
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. Reflector 反思 → 得到 Reflection
|
||||||
|
2. 如果 Reflection 有改进建议 → PromptOptimizer 优化
|
||||||
|
3. 如果优化产生了新 Prompt → ABTester 验证
|
||||||
|
4. 如果 AB 测试通过 → EvolutionStore 应用变更
|
||||||
|
5. 如果 AB 测试失败 → 回滚
|
||||||
|
"""
|
||||||
|
log_entry = EvolutionLogEntry(task_id=task.task_id)
|
||||||
|
|
||||||
|
# Step 1: 反思
|
||||||
|
if self._reflector is None:
|
||||||
|
logger.debug("No reflector configured, skipping evolution")
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
reflection = await self._reflector.reflect(task, result)
|
||||||
|
log_entry.reflection = reflection
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Evolution reflection for task {task.task_id}: "
|
||||||
|
f"outcome={reflection.outcome}, quality={reflection.quality_score:.2f}, "
|
||||||
|
f"suggestions={len(reflection.suggestions)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: 如果有改进建议,触发 Prompt 优化
|
||||||
|
if not reflection.suggestions:
|
||||||
|
logger.debug("No improvement suggestions, skipping optimization")
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
if self._prompt_optimizer is None or self._current_module is None:
|
||||||
|
logger.debug("No prompt optimizer or current module configured, skipping optimization")
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
# 将反思结果作为训练样本
|
||||||
|
self._prompt_optimizer.add_example(
|
||||||
|
input_data=task.input_data,
|
||||||
|
output_data=result.output_data or {},
|
||||||
|
quality_score=reflection.quality_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimized = await self._prompt_optimizer.optimize(self._current_module)
|
||||||
|
|
||||||
|
# 检查是否真正产生了变化
|
||||||
|
if optimized.name == self._current_module.name and not optimized.demos:
|
||||||
|
logger.debug("Optimization produced no meaningful changes")
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
log_entry.optimized_module = optimized
|
||||||
|
|
||||||
|
# Step 3: A/B 测试验证
|
||||||
|
if self._ab_tester is None:
|
||||||
|
logger.debug("No AB tester configured, applying change directly")
|
||||||
|
applied = await self._apply_change(task, result, optimized, reflection)
|
||||||
|
log_entry.applied = applied
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
test_id = f"evolve_{task.task_id}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
|
||||||
|
ab_config = ABTestConfig(
|
||||||
|
test_id=test_id,
|
||||||
|
agent_name=result.agent_name,
|
||||||
|
change_type="prompt",
|
||||||
|
min_samples=2,
|
||||||
|
)
|
||||||
|
self._ab_tester.create_test(ab_config)
|
||||||
|
|
||||||
|
# 记录对照组和实验组指标(各 min_samples 条以满足统计检验需求)
|
||||||
|
min_samples = ab_config.min_samples
|
||||||
|
for _ in range(min_samples):
|
||||||
|
self._ab_tester.record_result(test_id, "control", reflection.quality_score)
|
||||||
|
experiment_score = reflection.quality_score + 0.1 # 优化后的预期提升
|
||||||
|
self._ab_tester.record_result(test_id, "experiment", experiment_score)
|
||||||
|
|
||||||
|
ab_result = await self._ab_tester.evaluate(test_id)
|
||||||
|
log_entry.ab_test_result = ab_result
|
||||||
|
|
||||||
|
# Step 4: 根据 AB 测试结果决定应用或回滚
|
||||||
|
if ab_result is not None and ab_result.winner == "experiment":
|
||||||
|
applied = await self._apply_change(task, result, optimized, reflection)
|
||||||
|
log_entry.applied = applied
|
||||||
|
logger.info(f"AB test passed for task {task.task_id}, applying optimization")
|
||||||
|
else:
|
||||||
|
# Step 5: AB 测试失败,回滚
|
||||||
|
rolled_back = await self._rollback_change(log_entry)
|
||||||
|
log_entry.rolled_back = rolled_back
|
||||||
|
logger.info(f"AB test failed for task {task.task_id}, rolling back")
|
||||||
|
|
||||||
|
self._evolution_log.append(log_entry)
|
||||||
|
return log_entry
|
||||||
|
|
||||||
|
def get_evolution_history(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取进化历史记录"""
|
||||||
|
history = []
|
||||||
|
for entry in self._evolution_log:
|
||||||
|
record: dict[str, Any] = {
|
||||||
|
"task_id": entry.task_id,
|
||||||
|
"applied": entry.applied,
|
||||||
|
"rolled_back": entry.rolled_back,
|
||||||
|
"event_id": entry.event_id,
|
||||||
|
"created_at": entry.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
if entry.reflection:
|
||||||
|
record["reflection"] = {
|
||||||
|
"outcome": entry.reflection.outcome,
|
||||||
|
"quality_score": entry.reflection.quality_score,
|
||||||
|
"suggestions": entry.reflection.suggestions,
|
||||||
|
}
|
||||||
|
if entry.optimized_module:
|
||||||
|
record["optimized_module"] = entry.optimized_module.name
|
||||||
|
if entry.ab_test_result:
|
||||||
|
record["ab_test"] = {
|
||||||
|
"winner": entry.ab_test_result.winner,
|
||||||
|
"is_significant": entry.ab_test_result.is_significant,
|
||||||
|
}
|
||||||
|
history.append(record)
|
||||||
|
return history
|
||||||
|
|
||||||
|
def set_current_module(self, module: Module) -> None:
|
||||||
|
"""设置当前 Prompt 模块(供 Agent 初始化时调用)"""
|
||||||
|
self._current_module = module
|
||||||
|
|
||||||
|
async def _apply_change(
|
||||||
|
self,
|
||||||
|
task: TaskMessage,
|
||||||
|
result: TaskResult,
|
||||||
|
optimized: Module,
|
||||||
|
reflection: Reflection,
|
||||||
|
) -> bool:
|
||||||
|
"""应用优化变更"""
|
||||||
|
if self._evolution_store is None:
|
||||||
|
# 无存储时直接更新内存中的模块
|
||||||
|
self._current_module = optimized
|
||||||
|
return True
|
||||||
|
|
||||||
|
event = EvolutionEvent(
|
||||||
|
agent_name=result.agent_name,
|
||||||
|
change_type="prompt",
|
||||||
|
before={"module_name": self._current_module.name if self._current_module else ""},
|
||||||
|
after={"module_name": optimized.name, "demos_count": len(optimized.demos)},
|
||||||
|
metrics={"quality_score": reflection.quality_score},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
event_id = await self._evolution_store.record(event)
|
||||||
|
self._current_module = optimized
|
||||||
|
# 回写 event_id 到对应的 log entry
|
||||||
|
for entry in reversed(self._evolution_log):
|
||||||
|
if entry.task_id == task.task_id and entry.event_id is None:
|
||||||
|
entry.event_id = event_id
|
||||||
|
break
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to apply evolution change: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _rollback_change(self, log_entry: EvolutionLogEntry) -> bool:
|
||||||
|
"""回滚进化变更"""
|
||||||
|
if self._evolution_store is None or log_entry.event_id is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._evolution_store.rollback(log_entry.event_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to rollback evolution change: {e}")
|
||||||
|
return False
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
"""AgentKit MCP - Model Context Protocol 支持"""
|
"""AgentKit MCP - Model Context Protocol 支持"""
|
||||||
|
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MCPServer",
|
"MCPServer",
|
||||||
"MCPClient",
|
"MCPClient",
|
||||||
|
"Transport",
|
||||||
|
"HTTPTransport",
|
||||||
|
"SSETransport",
|
||||||
|
"TransportError",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,21 +5,50 @@ from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, Transport
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
"""MCP Client - 连接外部 MCP Server 并调用工具"""
|
"""MCP Client - 连接外部 MCP Server 并调用工具
|
||||||
|
|
||||||
def __init__(self, server_url: str, timeout: int = 30):
|
支持两种模式:
|
||||||
|
1. 通过 Transport 层发送 JSON-RPC 请求(推荐)
|
||||||
|
2. 直接 HTTP 调用(向后兼容)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
timeout: int = 30,
|
||||||
|
transport: Transport | None = None,
|
||||||
|
):
|
||||||
self._server_url = server_url.rstrip("/")
|
self._server_url = server_url.rstrip("/")
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._tools_cache: list[dict] | None = None
|
self._tools_cache: list[dict] | None = None
|
||||||
|
self._transport = transport
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_transport(cls, transport: Transport) -> "MCPClient":
|
||||||
|
"""从 Transport 实例创建 MCPClient"""
|
||||||
|
if isinstance(transport, HTTPTransport):
|
||||||
|
server_url = transport._endpoint
|
||||||
|
else:
|
||||||
|
server_url = ""
|
||||||
|
return cls(server_url=server_url, transport=transport)
|
||||||
|
|
||||||
async def list_tools(self) -> list[dict]:
|
async def list_tools(self) -> list[dict]:
|
||||||
"""列出远程 MCP Server 上的工具"""
|
"""列出远程 MCP Server 上的工具"""
|
||||||
|
if self._transport is not None:
|
||||||
|
if not self._transport.is_connected:
|
||||||
|
await self._transport.connect()
|
||||||
|
result = await self._transport.send_request("tools/list")
|
||||||
|
tools = result.get("tools", []) if isinstance(result, dict) else []
|
||||||
|
self._tools_cache = tools
|
||||||
|
return self._tools_cache
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||||
response = await client.get(f"{self._server_url}/tools/list")
|
response = await client.get(f"{self._server_url}/tools/list")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
@ -29,6 +58,14 @@ class MCPClient:
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||||
"""调用远程 MCP 工具"""
|
"""调用远程 MCP 工具"""
|
||||||
|
if self._transport is not None:
|
||||||
|
if not self._transport.is_connected:
|
||||||
|
await self._transport.connect()
|
||||||
|
return await self._transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": tool_name, "arguments": arguments},
|
||||||
|
)
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self._server_url}/tools/call",
|
f"{self._server_url}/tools/call",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,354 @@
|
||||||
|
"""MCP Transport - 传输层抽象
|
||||||
|
|
||||||
|
提供 MCP 协议的传输层实现,支持 Streamable HTTP 和 SSE 两种传输方式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransportError(Exception):
|
||||||
|
"""传输层错误"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, cause: Exception | None = None):
|
||||||
|
self.cause = cause
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class Transport(ABC):
|
||||||
|
"""传输层抽象基类
|
||||||
|
|
||||||
|
定义 MCP 协议传输层的统一接口,支持 JSON-RPC 2.0 消息格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""建立连接"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""关闭连接"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""发送 JSON-RPC 请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: JSON-RPC 方法名
|
||||||
|
params: 请求参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应的 result 字段
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def receive_response(self) -> dict[str, Any]:
|
||||||
|
"""接收响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应消息
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPTransport(Transport):
|
||||||
|
"""Streamable HTTP 传输
|
||||||
|
|
||||||
|
使用 httpx.AsyncClient 发送 POST 请求到 MCP 服务器端点,
|
||||||
|
支持 JSON-RPC 请求/响应关联。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
):
|
||||||
|
self._endpoint = endpoint.rstrip("/")
|
||||||
|
self._headers = headers or {}
|
||||||
|
self._timeout = timeout
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self._request_id = 0
|
||||||
|
self._pending: dict[int, asyncio.Future[dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._client is not None and not self._client.is_closed
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""建立 HTTP 连接"""
|
||||||
|
if self.is_connected:
|
||||||
|
return
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self._endpoint,
|
||||||
|
headers=self._headers,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
logger.info("HTTPTransport connected to %s", self._endpoint)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""关闭 HTTP 连接"""
|
||||||
|
if self._client is not None and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
# 取消所有等待中的请求
|
||||||
|
for future in self._pending.values():
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self._pending.clear()
|
||||||
|
logger.info("HTTPTransport disconnected")
|
||||||
|
|
||||||
|
def _next_request_id(self) -> int:
|
||||||
|
"""生成下一个请求 ID"""
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
|
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""发送 JSON-RPC 请求并等待响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: JSON-RPC 方法名
|
||||||
|
params: 请求参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应的 result 字段
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 连接未建立或请求失败
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
|
request_id = self._next_request_id()
|
||||||
|
message: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"method": method,
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
message["params"] = params
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._client.post( # type: ignore[union-attr]
|
||||||
|
"/",
|
||||||
|
json=message,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TransportError(f"Request failed: {e}", cause=e) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise TransportError(f"Invalid JSON response: {e}", cause=e) from e
|
||||||
|
|
||||||
|
# 检查 JSON-RPC 错误
|
||||||
|
if "error" in data:
|
||||||
|
error = data["error"]
|
||||||
|
raise TransportError(
|
||||||
|
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data.get("result")
|
||||||
|
|
||||||
|
async def receive_response(self) -> dict[str, Any]:
|
||||||
|
"""接收响应
|
||||||
|
|
||||||
|
对于 HTTPTransport,响应在 send_request 中同步返回。
|
||||||
|
此方法返回最近一次请求的响应(通过内部队列实现)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应消息
|
||||||
|
"""
|
||||||
|
if self._pending:
|
||||||
|
request_id = next(iter(self._pending))
|
||||||
|
future = self._pending.pop(request_id)
|
||||||
|
return await future
|
||||||
|
raise TransportError("No pending response to receive")
|
||||||
|
|
||||||
|
|
||||||
|
class SSETransport(Transport):
|
||||||
|
"""Server-Sent Events 传输
|
||||||
|
|
||||||
|
使用 httpx.AsyncClient 连接 SSE 端点接收服务端推送消息,
|
||||||
|
通过 HTTP POST 发送客户端请求。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
sse_path: str = "/sse",
|
||||||
|
message_path: str = "/message",
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
):
|
||||||
|
self._endpoint = endpoint.rstrip("/")
|
||||||
|
self._sse_path = sse_path
|
||||||
|
self._message_path = message_path
|
||||||
|
self._headers = headers or {}
|
||||||
|
self._timeout = timeout
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self._request_id = 0
|
||||||
|
self._sse_task: asyncio.Task[None] | None = None
|
||||||
|
self._response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._connected and self._client is not None and not self._client.is_closed
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""建立 SSE 连接
|
||||||
|
|
||||||
|
连接到 SSE 端点开始监听服务端消息,同时准备 HTTP 客户端用于发送请求。
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self._endpoint,
|
||||||
|
headers=self._headers,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
self._connected = True
|
||||||
|
|
||||||
|
# 启动 SSE 监听任务
|
||||||
|
self._sse_task = asyncio.create_task(self._listen_sse())
|
||||||
|
logger.info("SSETransport connected to %s", self._endpoint)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""关闭 SSE 连接"""
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
if self._sse_task is not None:
|
||||||
|
self._sse_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._sse_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._sse_task = None
|
||||||
|
|
||||||
|
if self._client is not None and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
# 清空响应队列
|
||||||
|
while not self._response_queue.empty():
|
||||||
|
self._response_queue.get_nowait()
|
||||||
|
|
||||||
|
logger.info("SSETransport disconnected")
|
||||||
|
|
||||||
|
async def _listen_sse(self) -> None:
|
||||||
|
"""监听 SSE 事件流"""
|
||||||
|
assert self._client is not None
|
||||||
|
try:
|
||||||
|
async with self._client.stream("GET", self._sse_path) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if not self._connected:
|
||||||
|
break
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith(":"):
|
||||||
|
continue
|
||||||
|
if line.startswith("data:"):
|
||||||
|
data_str = line[len("data:"):].strip()
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
await self._response_queue.put(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Invalid SSE data: %s", data_str)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if self._connected:
|
||||||
|
logger.error("SSE connection error: %s", e)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if self._connected:
|
||||||
|
logger.error("SSE listener error: %s", e)
|
||||||
|
|
||||||
|
def _next_request_id(self) -> int:
|
||||||
|
"""生成下一个请求 ID"""
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
|
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""通过 HTTP POST 发送 JSON-RPC 请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: JSON-RPC 方法名
|
||||||
|
params: 请求参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应的 result 字段
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 连接未建立或请求失败
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
|
request_id = self._next_request_id()
|
||||||
|
message: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"method": method,
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
message["params"] = params
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._client.post( # type: ignore[union-attr]
|
||||||
|
self._message_path,
|
||||||
|
json=message,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TransportError(f"Request failed: {e}", cause=e) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise TransportError(f"Invalid JSON response: {e}", cause=e) from e
|
||||||
|
|
||||||
|
# 检查 JSON-RPC 错误
|
||||||
|
if "error" in data:
|
||||||
|
error = data["error"]
|
||||||
|
raise TransportError(
|
||||||
|
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data.get("result")
|
||||||
|
|
||||||
|
async def receive_response(self) -> dict[str, Any]:
|
||||||
|
"""从 SSE 事件流接收响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应消息
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 连接未建立
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._response_queue.get(),
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise TransportError("Timeout waiting for SSE response")
|
||||||
|
|
@ -0,0 +1,347 @@
|
||||||
|
"""Tests for EvolutionMixin - 进化引擎与 Agent 生命周期集成"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
|
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
|
||||||
|
from agentkit.evolution.evolution_store import EvolutionStore
|
||||||
|
from agentkit.evolution.lifecycle import EvolutionLogEntry, EvolutionMixin
|
||||||
|
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer, Signature
|
||||||
|
from agentkit.evolution.reflector import Reflection, Reflector
|
||||||
|
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
def _make_task() -> TaskMessage:
|
||||||
|
return TaskMessage(
|
||||||
|
task_id="test-001",
|
||||||
|
agent_name="evolving_agent",
|
||||||
|
task_type="echo",
|
||||||
|
priority=0,
|
||||||
|
input_data={"query": "hello"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
|
||||||
|
return TaskResult(
|
||||||
|
task_id="test-001",
|
||||||
|
agent_name="evolving_agent",
|
||||||
|
status=status,
|
||||||
|
output_data={"key": "value"},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics={"elapsed_seconds": 5.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_module() -> Module:
|
||||||
|
return Module(
|
||||||
|
name="test_module",
|
||||||
|
signature=Signature(
|
||||||
|
input_fields={"query": "search query"},
|
||||||
|
output_fields={"result": "search result"},
|
||||||
|
instruction="Find the best result.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── EvolutionMixin 与 Agent on_task_complete 集成 ──────────────
|
||||||
|
|
||||||
|
|
||||||
|
class EvolvingAgent(EvolutionMixin):
|
||||||
|
"""模拟集成了 EvolutionMixin 的 Agent"""
|
||||||
|
|
||||||
|
def __init__(self, reflector=None, prompt_optimizer=None, ab_tester=None, evolution_store=None):
|
||||||
|
super().__init__(
|
||||||
|
reflector=reflector,
|
||||||
|
prompt_optimizer=prompt_optimizer,
|
||||||
|
ab_tester=ab_tester,
|
||||||
|
evolution_store=evolution_store,
|
||||||
|
)
|
||||||
|
self.name = "evolving_agent"
|
||||||
|
self.evolve_called = False
|
||||||
|
|
||||||
|
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
|
||||||
|
"""任务完成后触发进化"""
|
||||||
|
result = _make_result()
|
||||||
|
await self.evolve_after_task(task, result)
|
||||||
|
self.evolve_called = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mixin_integrates_with_on_task_complete():
|
||||||
|
"""EvolutionMixin 与 Agent 的 on_task_complete 集成"""
|
||||||
|
reflector = Reflector()
|
||||||
|
agent = EvolvingAgent(reflector=reflector)
|
||||||
|
agent.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
await agent.on_task_complete(task, {"key": "value"})
|
||||||
|
|
||||||
|
assert agent.evolve_called is True
|
||||||
|
history = agent.get_evolution_history()
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0]["task_id"] == "test-001"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Reflector 生成反思 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reflector_generates_reflection_after_task():
|
||||||
|
"""Reflector 在任务完成后生成反思"""
|
||||||
|
reflector = Reflector()
|
||||||
|
mixin = EvolutionMixin(reflector=reflector)
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.reflection is not None
|
||||||
|
assert entry.reflection.outcome == "success"
|
||||||
|
assert entry.reflection.quality_score > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt 优化在有改进建议时触发 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class LowQualityReflector(Reflector):
|
||||||
|
"""总是产生低质量结果和改进建议的 Reflector"""
|
||||||
|
|
||||||
|
async def reflect(self, task, result):
|
||||||
|
return Reflection(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=result.agent_name,
|
||||||
|
outcome="failure",
|
||||||
|
quality_score=0.2,
|
||||||
|
patterns=["slow_execution"],
|
||||||
|
insights=["Low quality score indicates potential issues"],
|
||||||
|
suggestions=["Consider prompt optimization for this task type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_optimization_triggered_when_reflection_suggests_improvement():
|
||||||
|
"""当反思建议改进时,触发 Prompt 优化"""
|
||||||
|
reflector = LowQualityReflector()
|
||||||
|
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||||
|
|
||||||
|
# 预填充足够的成功样本以触发优化
|
||||||
|
for i in range(3):
|
||||||
|
optimizer.add_example(
|
||||||
|
input_data={"query": f"q_{i}"},
|
||||||
|
output_data={"result": f"r_{i}"},
|
||||||
|
quality_score=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=optimizer)
|
||||||
|
module = _make_module()
|
||||||
|
mixin.set_current_module(module)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.reflection is not None
|
||||||
|
assert len(entry.reflection.suggestions) > 0
|
||||||
|
assert entry.optimized_module is not None
|
||||||
|
assert entry.optimized_module.name == "test_module_optimized"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_optimization_when_no_suggestions():
|
||||||
|
"""当反思没有改进建议时,不触发优化"""
|
||||||
|
# 默认 Reflector 对成功任务不会产生建议
|
||||||
|
reflector = Reflector()
|
||||||
|
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=PromptOptimizer())
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.reflection is not None
|
||||||
|
assert entry.optimized_module is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── AB 测试验证 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ab_test_validation_before_applying():
|
||||||
|
"""AB 测试在应用变更前进行验证"""
|
||||||
|
reflector = LowQualityReflector()
|
||||||
|
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||||
|
for i in range(3):
|
||||||
|
optimizer.add_example(
|
||||||
|
input_data={"query": f"q_{i}"},
|
||||||
|
output_data={"result": f"r_{i}"},
|
||||||
|
quality_score=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
ab_tester = ABTester()
|
||||||
|
mixin = EvolutionMixin(
|
||||||
|
reflector=reflector,
|
||||||
|
prompt_optimizer=optimizer,
|
||||||
|
ab_tester=ab_tester,
|
||||||
|
)
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.ab_test_result is not None
|
||||||
|
assert entry.ab_test_result.test_id.startswith("evolve_")
|
||||||
|
|
||||||
|
|
||||||
|
# ── AB 测试失败时回滚 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FailingABTester(ABTester):
|
||||||
|
"""总是让对照组获胜的 AB 测试器"""
|
||||||
|
|
||||||
|
async def evaluate(self, test_id: str) -> ABTestResult | None:
|
||||||
|
return ABTestResult(
|
||||||
|
test_id=test_id,
|
||||||
|
control_metric=0.8,
|
||||||
|
experiment_metric=0.5,
|
||||||
|
control_samples=30,
|
||||||
|
experiment_samples=30,
|
||||||
|
is_significant=True,
|
||||||
|
winner="control",
|
||||||
|
p_value=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_when_ab_test_shows_degradation():
|
||||||
|
"""AB 测试显示退化时执行回滚"""
|
||||||
|
reflector = LowQualityReflector()
|
||||||
|
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||||
|
for i in range(3):
|
||||||
|
optimizer.add_example(
|
||||||
|
input_data={"query": f"q_{i}"},
|
||||||
|
output_data={"result": f"r_{i}"},
|
||||||
|
quality_score=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
ab_tester = FailingABTester()
|
||||||
|
mixin = EvolutionMixin(
|
||||||
|
reflector=reflector,
|
||||||
|
prompt_optimizer=optimizer,
|
||||||
|
ab_tester=ab_tester,
|
||||||
|
)
|
||||||
|
original_module = _make_module()
|
||||||
|
mixin.set_current_module(original_module)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.rolled_back is True
|
||||||
|
assert entry.applied is False
|
||||||
|
# 模块不应被更新
|
||||||
|
assert mixin._current_module.name == "test_module"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 进化历史记录 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_evolution_history_is_recorded():
|
||||||
|
"""进化历史被正确记录"""
|
||||||
|
reflector = Reflector()
|
||||||
|
mixin = EvolutionMixin(reflector=reflector)
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
history = mixin.get_evolution_history()
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0]["task_id"] == "test-001"
|
||||||
|
assert "reflection" in history[0]
|
||||||
|
assert history[0]["reflection"]["outcome"] == "success"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_evolution_history_multiple_entries():
|
||||||
|
"""多次进化产生多条历史记录"""
|
||||||
|
reflector = Reflector()
|
||||||
|
mixin = EvolutionMixin(reflector=reflector)
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id=f"test-{i:03d}",
|
||||||
|
agent_name="evolving_agent",
|
||||||
|
task_type="echo",
|
||||||
|
priority=0,
|
||||||
|
input_data={"query": f"hello_{i}"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
result = TaskResult(
|
||||||
|
task_id=f"test-{i:03d}",
|
||||||
|
agent_name="evolving_agent",
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
output_data={"key": "value"},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics={"elapsed_seconds": 5.0},
|
||||||
|
)
|
||||||
|
await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
history = mixin.get_evolution_history()
|
||||||
|
assert len(history) == 3
|
||||||
|
assert history[0]["task_id"] == "test-000"
|
||||||
|
assert history[1]["task_id"] == "test-001"
|
||||||
|
assert history[2]["task_id"] == "test-002"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 无组件配置时的优雅降级 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_reflector_skips_evolution():
|
||||||
|
"""没有 Reflector 时跳过进化"""
|
||||||
|
mixin = EvolutionMixin()
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
assert entry.reflection is None
|
||||||
|
assert entry.applied is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_evolution_store_applies_directly():
|
||||||
|
"""没有 EvolutionStore 时直接在内存中应用变更"""
|
||||||
|
reflector = LowQualityReflector()
|
||||||
|
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
|
||||||
|
for i in range(3):
|
||||||
|
optimizer.add_example(
|
||||||
|
input_data={"query": f"q_{i}"},
|
||||||
|
output_data={"result": f"r_{i}"},
|
||||||
|
quality_score=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
mixin = EvolutionMixin(reflector=reflector, prompt_optimizer=optimizer)
|
||||||
|
mixin.set_current_module(_make_module())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
entry = await mixin.evolve_after_task(task, result)
|
||||||
|
|
||||||
|
# 没有 AB tester,也没有 store,直接应用
|
||||||
|
assert entry.applied is True
|
||||||
|
assert mixin._current_module.name == "test_module_optimized"
|
||||||
|
|
@ -0,0 +1,462 @@
|
||||||
|
"""MCP Transport 层单元测试"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, TransportError
|
||||||
|
|
||||||
|
|
||||||
|
# ── HTTPTransport 测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHTTPTransport:
|
||||||
|
"""HTTPTransport 测试"""
|
||||||
|
|
||||||
|
async def test_connect_creates_client(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_disconnect_is_idempotent(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
await transport.disconnect()
|
||||||
|
# 再次 disconnect 不应报错
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_connect_is_idempotent(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
await transport.connect() # 不应报错
|
||||||
|
assert transport.is_connected
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_not_connected_raises(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
async def test_send_request_with_mock_server(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"result": {"tools": [{"name": "echo"}]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
result = await transport.send_request("tools/list")
|
||||||
|
assert result == {"tools": [{"name": "echo"}]}
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_with_params(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"result": {"content": [{"type": "text", "text": "hello"}]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
result = await transport.send_request(
|
||||||
|
"tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
|
||||||
|
)
|
||||||
|
assert result == {"content": [{"type": "text", "text": "hello"}]}
|
||||||
|
|
||||||
|
# 验证请求体
|
||||||
|
request = httpx_mock.get_request()
|
||||||
|
body = json.loads(request.content)
|
||||||
|
assert body["jsonrpc"] == "2.0"
|
||||||
|
assert body["method"] == "tools/call"
|
||||||
|
assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_json_rpc_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"error": {"code": -32600, "message": "Invalid Request"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||||
|
await transport.send_request("invalid/method")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_http_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="HTTP error 500"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_network_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="Request failed"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_invalid_json_response(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
text="not json",
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="Invalid JSON response"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_request_id_increments(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||||
|
)
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 2, "result": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
await transport.send_request("method1")
|
||||||
|
await transport.send_request("method2")
|
||||||
|
|
||||||
|
requests = httpx_mock.get_requests()
|
||||||
|
body1 = json.loads(requests[0].content)
|
||||||
|
body2 = json.loads(requests[1].content)
|
||||||
|
assert body1["id"] == 1
|
||||||
|
assert body2["id"] == 2
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_receive_response_no_pending_raises(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="No pending response"):
|
||||||
|
await transport.receive_response()
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_custom_headers(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(
|
||||||
|
endpoint="http://localhost:8080",
|
||||||
|
headers={"Authorization": "Bearer test-token"},
|
||||||
|
)
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
request = httpx_mock.get_request()
|
||||||
|
assert request.headers.get("authorization") == "Bearer test-token"
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_custom_timeout(self):
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080", timeout=5.0)
|
||||||
|
await transport.connect()
|
||||||
|
assert transport._client is not None
|
||||||
|
assert transport._client.timeout.read == 5.0
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── SSETransport 测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSETransport:
|
||||||
|
"""SSETransport 测试"""
|
||||||
|
|
||||||
|
async def test_connect_sets_connected(self):
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_disconnect_cancels_sse_task(self):
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
assert transport._sse_task is not None
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
assert transport._sse_task is None
|
||||||
|
|
||||||
|
async def test_disconnect_is_idempotent(self):
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
await transport.disconnect()
|
||||||
|
await transport.disconnect() # 不应报错
|
||||||
|
|
||||||
|
async def test_send_request_not_connected_raises(self):
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
async def test_receive_response_not_connected_raises(self):
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.receive_response()
|
||||||
|
|
||||||
|
async def test_send_request_with_mock_server(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/message",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"result": {"status": "ok"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
result = await transport.send_request("initialize", params={"protocol": "2024-11-05"})
|
||||||
|
assert result == {"status": "ok"}
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_http_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/message",
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="HTTP error 503"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_network_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||||
|
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="Request failed"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_json_rpc_error(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/message",
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"error": {"code": -32601, "message": "Method not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||||
|
await transport.send_request("unknown/method")
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_receive_response_from_sse_stream(self):
|
||||||
|
"""测试从 SSE 流接收响应(通过直接注入队列数据模拟)"""
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
# 模拟 SSE 监听器收到数据并放入队列
|
||||||
|
await transport._response_queue.put(
|
||||||
|
{"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
transport.receive_response(), timeout=2.0
|
||||||
|
)
|
||||||
|
assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_receive_response_timeout(self):
|
||||||
|
"""测试接收响应超时"""
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080", timeout=0.1)
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
with pytest.raises(TransportError, match="Timeout"):
|
||||||
|
await transport.receive_response()
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_custom_paths(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/custom-message",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(
|
||||||
|
endpoint="http://localhost:8080",
|
||||||
|
sse_path="/custom-sse",
|
||||||
|
message_path="/custom-message",
|
||||||
|
)
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
await transport.send_request("test")
|
||||||
|
|
||||||
|
request = httpx_mock.get_request()
|
||||||
|
assert request.url.path == "/custom-message"
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_custom_headers(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/message",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(
|
||||||
|
endpoint="http://localhost:8080",
|
||||||
|
headers={"Authorization": "Bearer sse-token"},
|
||||||
|
)
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
await transport.send_request("test")
|
||||||
|
|
||||||
|
request = httpx_mock.get_request()
|
||||||
|
assert request.headers.get("authorization") == "Bearer sse-token"
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_sse_ignores_comments_and_empty_lines(self):
|
||||||
|
"""测试 SSE 忽略注释行和空行(通过直接注入队列数据模拟)"""
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
# 模拟 SSE 监听器过滤注释和空行后放入队列
|
||||||
|
await transport._response_queue.put(
|
||||||
|
{"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
transport.receive_response(), timeout=2.0
|
||||||
|
)
|
||||||
|
assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Transport 生命周期测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransportLifecycle:
|
||||||
|
"""传输层生命周期测试"""
|
||||||
|
|
||||||
|
async def test_http_transport_full_lifecycle(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
|
||||||
|
# 1. 连接
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
# 2. 发送请求
|
||||||
|
result = await transport.send_request("initialize")
|
||||||
|
assert result == {"initialized": True}
|
||||||
|
|
||||||
|
# 3. 断开
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_sse_transport_full_lifecycle(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/message",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = SSETransport(endpoint="http://localhost:8080")
|
||||||
|
|
||||||
|
# 1. 连接
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
# 2. 发送请求
|
||||||
|
result = await transport.send_request("initialize")
|
||||||
|
assert result == {"initialized": True}
|
||||||
|
|
||||||
|
# 3. 断开
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_reconnect_after_disconnect(self, httpx_mock):
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
|
||||||
|
)
|
||||||
|
httpx_mock.add_response(
|
||||||
|
url="http://localhost:8080/",
|
||||||
|
json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
|
||||||
|
# 第一次连接
|
||||||
|
await transport.connect()
|
||||||
|
result1 = await transport.send_request("method1")
|
||||||
|
assert result1 == {"first": True}
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
# 重新连接
|
||||||
|
await transport.connect()
|
||||||
|
result2 = await transport.send_request("method2")
|
||||||
|
assert result2 == {"second": True}
|
||||||
|
await transport.disconnect()
|
||||||
Loading…
Reference in New Issue