feat(mcp,evolution): add Transport layer and Evolution lifecycle integration

U5 - MCP Transport:
- Transport abstract base class with connect/disconnect/send_request
- HTTPTransport: Streamable HTTP with JSON-RPC 2.0
- SSETransport: Server-Sent Events + HTTP POST hybrid
- MCPClient: from_transport() factory method
- 31 transport tests

U6 - Evolution Lifecycle:
- EvolutionMixin: reflect → optimize → AB test → apply/rollback
- EvolutionLogEntry: tracks each evolution step
- Integrates with BaseAgent on_task_complete hook
- 10 lifecycle tests

Total: 130 tests passing
This commit is contained in:
chiguyong 2026-06-04 22:55:23 +08:00
parent cc6a858150
commit 96ea0c2972
7 changed files with 1441 additions and 2 deletions

View File

@ -5,6 +5,7 @@ from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Modu
from agentkit.evolution.strategy_tuner import StrategyTuner from agentkit.evolution.strategy_tuner import StrategyTuner
from agentkit.evolution.ab_tester import ABTester from agentkit.evolution.ab_tester import ABTester
from agentkit.evolution.evolution_store import EvolutionStore from agentkit.evolution.evolution_store import EvolutionStore
from agentkit.evolution.lifecycle import EvolutionMixin, EvolutionLogEntry
__all__ = [ __all__ = [
"Reflector", "Reflector",
@ -14,4 +15,6 @@ __all__ = [
"StrategyTuner", "StrategyTuner",
"ABTester", "ABTester",
"EvolutionStore", "EvolutionStore",
"EvolutionMixin",
"EvolutionLogEntry",
] ]

View File

@ -0,0 +1,230 @@
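"""EvolutionMixin - integrates the evolution engine into the Agent lifecycle.

After a task completes, automatically triggers the evolution flow:
reflect → optimize → A/B test → apply / roll back.
"""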
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from agentkit.core.protocol import EvolutionEvent, TaskMessage, TaskResult
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
from agentkit.evolution.evolution_store import EvolutionStore
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer
from agentkit.evolution.reflector import Reflection, Reflector
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
logger = logging.getLogger(__name__)
def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value ``datetime.utcnow()`` used to (UTC wall clock,
    ``tzinfo`` is ``None``) without calling that API, which is deprecated
    since Python 3.12.
    """
    # Local import: the module-level import only brings in ``datetime``.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class EvolutionLogEntry:
    """Record of one pass through the evolution pipeline for a single task.

    Every field except ``task_id`` starts out empty and is filled in as the
    corresponding pipeline stage runs.
    """

    # Id of the task this pass was run for.
    task_id: str
    # Result of the reflection stage, if it ran.
    reflection: Reflection | None = None
    # Module produced by the optimizer, if it produced a meaningful change.
    optimized_module: Module | None = None
    # Outcome of the A/B test stage, if it ran.
    ab_test_result: ABTestResult | None = None
    # True once the optimized module has been applied.
    applied: bool = False
    # True when the change was rejected and the rollback step succeeded.
    rolled_back: bool = False
    # Id returned by the evolution store when the change was recorded.
    event_id: str | None = None
    # Creation time: naive datetime expressed in UTC.
    created_at: datetime = field(default_factory=_utc_now_naive)
class EvolutionMixin:
    """Mixin that plugs the evolution engine into an agent's task lifecycle.

    After a task completes, :meth:`evolve_after_task` runs the pipeline
    reflect -> optimize -> A/B test -> apply / roll back and appends one
    :class:`EvolutionLogEntry` per call to an in-memory log.

    Every collaborator is optional; the pipeline stops at the first stage
    whose collaborator is missing.

    Usage::

        class MyAgent(BaseAgent, EvolutionMixin):
            def __init__(self, ...):
                BaseAgent.__init__(self, ...)
                EvolutionMixin.__init__(self, reflector=..., ...)
    """

    # Number of samples recorded per A/B arm (also passed as ``min_samples``).
    _AB_MIN_SAMPLES: int = 2
    # Uplift added to the reflection score to form the experiment-arm samples.
    _AB_EXPECTED_UPLIFT: float = 0.1

    def __init__(
        self,
        reflector: Reflector | None = None,
        prompt_optimizer: PromptOptimizer | None = None,
        strategy_tuner: StrategyTuner | None = None,
        ab_tester: ABTester | None = None,
        evolution_store: EvolutionStore | None = None,
    ):
        """Store the optional evolution collaborators and reset the log.

        Args:
            reflector: produces a ``Reflection`` for a finished task.
            prompt_optimizer: turns the current module into an optimized one.
            strategy_tuner: stored but not used by any method of this mixin.
            ab_tester: validates an optimized module before it is applied.
            evolution_store: persists applied changes and performs rollbacks.
        """
        self._reflector = reflector
        self._prompt_optimizer = prompt_optimizer
        self._strategy_tuner = strategy_tuner
        self._ab_tester = ab_tester
        self._evolution_store = evolution_store
        self._evolution_log: list[EvolutionLogEntry] = []
        self._current_module: Module | None = None

    async def evolve_after_task(self, task: TaskMessage, result: TaskResult) -> EvolutionLogEntry:
        """Run the evolution pipeline for a finished task.

        Steps:
            1. The reflector produces a ``Reflection``.
            2. If it carries suggestions, the prompt optimizer is run.
            3. If optimization changed the module, the A/B tester validates it.
            4. If the experiment arm wins, the change is applied.
            5. Otherwise the rollback step runs.

        Args:
            task: the task that just completed.
            result: the result produced for that task.

        Returns:
            The log entry describing how far the pipeline got. It is also
            appended to the in-memory log; if a stage raises, the exception
            propagates and nothing is appended.
        """
        log_entry = EvolutionLogEntry(task_id=task.task_id)
        await self._run_pipeline(task, result, log_entry)
        self._evolution_log.append(log_entry)
        return log_entry

    async def _run_pipeline(
        self,
        task: TaskMessage,
        result: TaskResult,
        log_entry: EvolutionLogEntry,
    ) -> None:
        """Fill *log_entry* stage by stage, returning at the first stage that cannot run."""
        # Step 1: reflection.
        if self._reflector is None:
            logger.debug("No reflector configured, skipping evolution")
            return

        reflection = await self._reflector.reflect(task, result)
        log_entry.reflection = reflection
        logger.info(
            "Evolution reflection for task %s: outcome=%s, quality=%.2f, suggestions=%d",
            task.task_id,
            reflection.outcome,
            reflection.quality_score,
            len(reflection.suggestions),
        )

        # Step 2: prompt optimization, only when the reflection asks for it.
        if not reflection.suggestions:
            logger.debug("No improvement suggestions, skipping optimization")
            return
        if self._prompt_optimizer is None or self._current_module is None:
            logger.debug("No prompt optimizer or current module configured, skipping optimization")
            return

        # Feed this task back to the optimizer as a training example.
        self._prompt_optimizer.add_example(
            input_data=task.input_data,
            output_data=result.output_data or {},
            quality_score=reflection.quality_score,
        )
        optimized = await self._prompt_optimizer.optimize(self._current_module)

        # Same name and no demos is treated as "nothing changed".
        if optimized.name == self._current_module.name and not optimized.demos:
            logger.debug("Optimization produced no meaningful changes")
            return
        log_entry.optimized_module = optimized

        # Step 3: without an A/B tester the change is applied unvalidated.
        if self._ab_tester is None:
            logger.debug("No AB tester configured, applying change directly")
            log_entry.applied = await self._apply_change(
                task, result, optimized, reflection, log_entry
            )
            return

        ab_result = await self._run_ab_test(task, result, reflection)
        log_entry.ab_test_result = ab_result

        # Steps 4/5: apply when the experiment arm wins, otherwise roll back.
        if ab_result is not None and ab_result.winner == "experiment":
            log_entry.applied = await self._apply_change(
                task, result, optimized, reflection, log_entry
            )
            logger.info("AB test passed for task %s, applying optimization", task.task_id)
        else:
            log_entry.rolled_back = await self._rollback_change(log_entry)
            logger.info("AB test failed for task %s, rolling back", task.task_id)

    async def _run_ab_test(
        self,
        task: TaskMessage,
        result: TaskResult,
        reflection: Reflection,
    ) -> ABTestResult | None:
        """Create an A/B test, record samples for both arms and evaluate it."""
        # Local import: the module-level import only brings in ``datetime``.
        from datetime import timezone

        stamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
        test_id = f"evolve_{task.task_id}_{stamp}"
        ab_config = ABTestConfig(
            test_id=test_id,
            agent_name=result.agent_name,
            change_type="prompt",
            min_samples=self._AB_MIN_SAMPLES,
        )
        self._ab_tester.create_test(ab_config)

        # NOTE(review): the experiment samples are not measured -- they are the
        # reflection score plus a fixed expected uplift, so the outcome depends
        # only on how the tester treats that constant gap. Confirm this
        # placeholder is intended.
        control_score = reflection.quality_score
        experiment_score = control_score + self._AB_EXPECTED_UPLIFT
        for _ in range(ab_config.min_samples):
            self._ab_tester.record_result(test_id, "control", control_score)
            self._ab_tester.record_result(test_id, "experiment", experiment_score)

        return await self._ab_tester.evaluate(test_id)

    def get_evolution_history(self) -> list[dict[str, Any]]:
        """Return the evolution log as a list of plain dictionaries.

        Each record always has ``task_id``, ``applied``, ``rolled_back``,
        ``event_id`` and ``created_at`` (ISO-8601 string); ``reflection``,
        ``optimized_module`` and ``ab_test`` are present only when the
        corresponding stage produced a value.
        """
        history: list[dict[str, Any]] = []
        for entry in self._evolution_log:
            record: dict[str, Any] = {
                "task_id": entry.task_id,
                "applied": entry.applied,
                "rolled_back": entry.rolled_back,
                "event_id": entry.event_id,
                "created_at": entry.created_at.isoformat(),
            }
            if entry.reflection:
                record["reflection"] = {
                    "outcome": entry.reflection.outcome,
                    "quality_score": entry.reflection.quality_score,
                    "suggestions": entry.reflection.suggestions,
                }
            if entry.optimized_module:
                record["optimized_module"] = entry.optimized_module.name
            if entry.ab_test_result:
                record["ab_test"] = {
                    "winner": entry.ab_test_result.winner,
                    "is_significant": entry.ab_test_result.is_significant,
                }
            history.append(record)
        return history

    def set_current_module(self, module: Module) -> None:
        """Set the prompt module that later optimizations start from."""
        self._current_module = module

    async def _apply_change(
        self,
        task: TaskMessage,
        result: TaskResult,
        optimized: Module,
        reflection: Reflection,
        log_entry: EvolutionLogEntry | None = None,
    ) -> bool:
        """Make *optimized* the current module, recording the change if a store exists.

        Args:
            task: the task that triggered the change.
            result: the task result (supplies the agent name).
            optimized: the module to switch to.
            reflection: the reflection whose score is stored as a metric.
            log_entry: entry that receives the store's event id. When
                omitted, the most recent logged entry for the same task that
                has no event id yet is used instead.

        Returns:
            True when the module was switched, False when recording failed.
        """
        if self._evolution_store is None:
            # No store: the change only lives in memory.
            self._current_module = optimized
            return True

        event = EvolutionEvent(
            agent_name=result.agent_name,
            change_type="prompt",
            before={"module_name": self._current_module.name if self._current_module else ""},
            after={"module_name": optimized.name, "demos_count": len(optimized.demos)},
            metrics={"quality_score": reflection.quality_score},
        )
        try:
            event_id = await self._evolution_store.record(event)
        except Exception as e:
            # Best effort: a failing store must not break the agent.
            logger.exception("Failed to apply evolution change: %s", e)
            return False

        self._current_module = optimized

        # The entry for the running pass is not in the log yet, so it must be
        # tagged directly; searching the log is only a fallback for callers
        # that do not pass it.
        target = log_entry
        if target is None:
            for entry in reversed(self._evolution_log):
                if entry.task_id == task.task_id and entry.event_id is None:
                    target = entry
                    break
        if target is not None:
            target.event_id = event_id
        return True

    async def _rollback_change(self, log_entry: EvolutionLogEntry) -> bool:
        """Roll back the change recorded for *log_entry*.

        Returns True when there is nothing to roll back (no store or no
        event id), otherwise whatever the store's ``rollback`` returns, or
        False if it raises.
        """
        if self._evolution_store is None or log_entry.event_id is None:
            return True
        try:
            return await self._evolution_store.rollback(log_entry.event_id)
        except Exception as e:
            logger.exception("Failed to rollback evolution change: %s", e)
            return False

View File

@ -1,6 +1,12 @@
"""AgentKit MCP - Model Context Protocol 支持""" """AgentKit MCP - Model Context Protocol 支持"""
from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError
__all__ = [ __all__ = [
"MCPServer", "MCPServer",
"MCPClient", "MCPClient",
"Transport",
"HTTPTransport",
"SSETransport",
"TransportError",
] ]

View File

@ -5,21 +5,50 @@ from typing import Any
import httpx import httpx
from agentkit.mcp.transport import HTTPTransport, Transport
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MCPClient: class MCPClient:
"""MCP Client - 连接外部 MCP Server 并调用工具""" """MCP Client - 连接外部 MCP Server 并调用工具
def __init__(self, server_url: str, timeout: int = 30): 支持两种模式
1. 通过 Transport 层发送 JSON-RPC 请求推荐
2. 直接 HTTP 调用向后兼容
"""
def __init__(
self,
server_url: str,
timeout: int = 30,
transport: Transport | None = None,
):
self._server_url = server_url.rstrip("/") self._server_url = server_url.rstrip("/")
self._timeout = timeout self._timeout = timeout
self._tools_cache: list[dict] | None = None self._tools_cache: list[dict] | None = None
self._transport = transport
@classmethod
def from_transport(cls, transport: Transport) -> "MCPClient":
    """Create an MCPClient that talks through an existing Transport.

    The client's server URL is the transport's endpoint when *transport* is
    an ``HTTPTransport`` and the empty string for any other transport.
    """
    endpoint = transport._endpoint if isinstance(transport, HTTPTransport) else ""
    return cls(server_url=endpoint, transport=transport)
async def list_tools(self) -> list[dict]: async def list_tools(self) -> list[dict]:
"""列出远程 MCP Server 上的工具""" """列出远程 MCP Server 上的工具"""
if self._transport is not None:
if not self._transport.is_connected:
await self._transport.connect()
result = await self._transport.send_request("tools/list")
tools = result.get("tools", []) if isinstance(result, dict) else []
self._tools_cache = tools
return self._tools_cache
async with httpx.AsyncClient(timeout=self._timeout) as client: async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.get(f"{self._server_url}/tools/list") response = await client.get(f"{self._server_url}/tools/list")
response.raise_for_status() response.raise_for_status()
@ -29,6 +58,14 @@ class MCPClient:
async def call_tool(self, tool_name: str, arguments: dict) -> dict: async def call_tool(self, tool_name: str, arguments: dict) -> dict:
"""调用远程 MCP 工具""" """调用远程 MCP 工具"""
if self._transport is not None:
if not self._transport.is_connected:
await self._transport.connect()
return await self._transport.send_request(
"tools/call",
params={"name": tool_name, "arguments": arguments},
)
async with httpx.AsyncClient(timeout=self._timeout) as client: async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.post( response = await client.post(
f"{self._server_url}/tools/call", f"{self._server_url}/tools/call",

View File

@ -0,0 +1,354 @@
"""MCP Transport - transport-layer abstraction.

Provides transport-layer implementations for the MCP protocol; two transports
are supported: Streamable HTTP and SSE.
"""
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class TransportError(Exception):
"""传输层错误"""
def __init__(self, message: str, cause: Exception | None = None):
self.cause = cause
super().__init__(message)
class Transport(ABC):
"""传输层抽象基类
定义 MCP 协议传输层的统一接口支持 JSON-RPC 2.0 消息格式
"""
@abstractmethod
async def connect(self) -> None:
"""建立连接"""
...
@abstractmethod
async def disconnect(self) -> None:
"""关闭连接"""
...
@abstractmethod
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
"""发送 JSON-RPC 请求
Args:
method: JSON-RPC 方法名
params: 请求参数
Returns:
JSON-RPC 响应的 result 字段
"""
...
@abstractmethod
async def receive_response(self) -> dict[str, Any]:
"""接收响应
Returns:
JSON-RPC 响应消息
"""
...
class HTTPTransport(Transport):
    """Streamable HTTP transport.

    Sends each JSON-RPC 2.0 request as an HTTP POST through an
    ``httpx.AsyncClient`` and reads the response from the body of that POST.
    """

    def __init__(
        self,
        endpoint: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30.0,
    ):
        """Store the connection settings; no network activity happens here.

        Args:
            endpoint: base URL of the MCP server; trailing slashes are removed.
            headers: extra HTTP headers sent with every request.
            timeout: timeout handed to the HTTP client.
        """
        self._endpoint = endpoint.rstrip("/")
        self._headers = headers or {}
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        self._request_id = 0
        # NOTE(review): nothing in this class ever adds to this mapping, so
        # receive_response() always raises. Confirm whether it is a placeholder.
        self._pending: dict[int, asyncio.Future[dict[str, Any]]] = {}

    @property
    def is_connected(self) -> bool:
        """True while an open HTTP client exists."""
        return self._client is not None and not self._client.is_closed

    async def connect(self) -> None:
        """Create the HTTP client; does nothing when already connected."""
        if self.is_connected:
            return

        self._client = httpx.AsyncClient(
            base_url=self._endpoint,
            headers=self._headers,
            timeout=self._timeout,
        )
        logger.info("HTTPTransport connected to %s", self._endpoint)

    async def disconnect(self) -> None:
        """Close the HTTP client and cancel any pending futures.

        Local state is cleared even when closing the client raises.
        """
        try:
            if self._client is not None and not self._client.is_closed:
                await self._client.aclose()
        finally:
            self._client = None
            # Cancel every request still waiting for a response.
            for future in self._pending.values():
                if not future.done():
                    future.cancel()
            self._pending.clear()
        logger.info("HTTPTransport disconnected")

    def _next_request_id(self) -> int:
        """Return the next request id (1, 2, 3, ...)."""
        self._request_id += 1
        return self._request_id

    @staticmethod
    def _unwrap_result(payload: Any) -> Any:
        """Validate a decoded JSON-RPC response and return its ``result`` member.

        Raises:
            TransportError: the payload is not a JSON object, or it carries
                a non-null ``error`` member.
        """
        if not isinstance(payload, dict):
            raise TransportError(
                f"Invalid JSON-RPC response: expected an object, got {type(payload).__name__}"
            )
        error = payload.get("error")
        if error is not None:
            # Tolerate servers that send a bare value instead of an error object.
            if isinstance(error, dict):
                code, detail = error.get("code"), error.get("message")
            else:
                code, detail = None, error
            raise TransportError(f"JSON-RPC error {code}: {detail}")
        return payload.get("result")

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request and return the result of its response.

        Args:
            method: JSON-RPC method name.
            params: request parameters; omitted from the message when None.

        Returns:
            The ``result`` member of the JSON-RPC response.

        Raises:
            TransportError: not connected, HTTP or network failure, the body
                is not valid JSON, or the response carries a JSON-RPC error.
        """
        if not self.is_connected:
            raise TransportError("Transport not connected")

        message: dict[str, Any] = {
            "jsonrpc": "2.0",
            "id": self._next_request_id(),
            "method": method,
        }
        if params is not None:
            message["params"] = params

        try:
            response = await self._client.post(  # type: ignore[union-attr]
                "/",
                json=message,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
        except httpx.RequestError as e:
            raise TransportError(f"Request failed: {e}", cause=e) from e

        try:
            data = response.json()
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            raise TransportError(f"Invalid JSON response: {e}", cause=e) from e

        return self._unwrap_result(data)

    async def receive_response(self) -> dict[str, Any]:
        """Await the oldest pending response future.

        For this transport the response is returned directly by
        ``send_request``; this method only serves futures stored in the
        internal pending mapping.

        Returns:
            A JSON-RPC response message.

        Raises:
            TransportError: there is no pending response.
        """
        if self._pending:
            request_id = next(iter(self._pending))
            future = self._pending.pop(request_id)
            return await future
        raise TransportError("No pending response to receive")
class SSETransport(Transport):
    """Server-Sent Events transport.

    Listens on an SSE endpoint for server-pushed messages and sends client
    requests with HTTP POST. The result returned by ``send_request`` comes
    from the POST response body; messages pushed over SSE are queued and
    read with ``receive_response``.
    """

    def __init__(
        self,
        endpoint: str,
        sse_path: str = "/sse",
        message_path: str = "/message",
        headers: dict[str, str] | None = None,
        timeout: float = 30.0,
    ):
        """Store the connection settings; no network activity happens here.

        Args:
            endpoint: base URL of the MCP server; trailing slashes are removed.
            sse_path: path of the SSE stream, relative to *endpoint*.
            message_path: path that receives POSTed requests.
            headers: extra HTTP headers sent with every request.
            timeout: HTTP client timeout, also used as the wait limit in
                ``receive_response``.
        """
        self._endpoint = endpoint.rstrip("/")
        self._sse_path = sse_path
        self._message_path = message_path
        self._headers = headers or {}
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        self._request_id = 0
        self._sse_task: asyncio.Task[None] | None = None
        self._response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._connected = False

    @property
    def is_connected(self) -> bool:
        """True after connect() and while the HTTP client is open."""
        return self._connected and self._client is not None and not self._client.is_closed

    async def connect(self) -> None:
        """Create the HTTP client and start the SSE listener task.

        Does nothing when already connected. The listener runs in the
        background; this method does not wait for the stream to open.
        """
        if self.is_connected:
            return

        self._client = httpx.AsyncClient(
            base_url=self._endpoint,
            headers=self._headers,
            timeout=self._timeout,
        )
        self._connected = True

        # Start listening for server-pushed events.
        self._sse_task = asyncio.create_task(self._listen_sse())
        logger.info("SSETransport connected to %s", self._endpoint)

    async def disconnect(self) -> None:
        """Stop the listener, close the HTTP client and drop queued messages."""
        self._connected = False

        if self._sse_task is not None:
            self._sse_task.cancel()
            try:
                await self._sse_task
            except asyncio.CancelledError:
                pass
            self._sse_task = None

        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

        # Discard any messages that were never consumed.
        while not self._response_queue.empty():
            self._response_queue.get_nowait()

        logger.info("SSETransport disconnected")

    async def _listen_sse(self) -> None:
        """Read the SSE stream and queue the JSON payload of every ``data:`` line."""
        client = self._client
        if client is None:
            return

        # An SSE stream is long-lived and may be idle for a long time, so the
        # read timeout is disabled for it; the other timeouts still apply.
        stream_timeout = httpx.Timeout(self._timeout, read=None)
        try:
            async with client.stream("GET", self._sse_path, timeout=stream_timeout) as response:
                response.raise_for_status()
                async for raw_line in response.aiter_lines():
                    if not self._connected:
                        break
                    line = raw_line.strip()
                    # Skip blank separators and ":" comment lines.
                    if not line or line.startswith(":"):
                        continue
                    if not line.startswith("data:"):
                        continue
                    data_str = line.removeprefix("data:").strip()
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        logger.warning("Invalid SSE data: %s", data_str)
                    else:
                        await self._response_queue.put(data)
        except httpx.RequestError as e:
            if self._connected:
                logger.error("SSE connection error: %s", e)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Keep the background task from dying with an unhandled error.
            if self._connected:
                logger.error("SSE listener error: %s", e)

    def _next_request_id(self) -> int:
        """Return the next request id (1, 2, 3, ...)."""
        self._request_id += 1
        return self._request_id

    @staticmethod
    def _unwrap_result(payload: Any) -> Any:
        """Validate a decoded JSON-RPC response and return its ``result`` member.

        Raises:
            TransportError: the payload is not a JSON object, or it carries
                a non-null ``error`` member.
        """
        if not isinstance(payload, dict):
            raise TransportError(
                f"Invalid JSON-RPC response: expected an object, got {type(payload).__name__}"
            )
        error = payload.get("error")
        if error is not None:
            # Tolerate servers that send a bare value instead of an error object.
            if isinstance(error, dict):
                code, detail = error.get("code"), error.get("message")
            else:
                code, detail = None, error
            raise TransportError(f"JSON-RPC error {code}: {detail}")
        return payload.get("result")

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request with HTTP POST and return its result.

        Args:
            method: JSON-RPC method name.
            params: request parameters; omitted from the message when None.

        Returns:
            The ``result`` member of the JSON-RPC response.

        Raises:
            TransportError: not connected, HTTP or network failure, the body
                is not valid JSON, or the response carries a JSON-RPC error.
        """
        if not self.is_connected:
            raise TransportError("Transport not connected")

        message: dict[str, Any] = {
            "jsonrpc": "2.0",
            "id": self._next_request_id(),
            "method": method,
        }
        if params is not None:
            message["params"] = params

        try:
            response = await self._client.post(  # type: ignore[union-attr]
                self._message_path,
                json=message,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise TransportError(f"HTTP error {e.response.status_code}", cause=e) from e
        except httpx.RequestError as e:
            raise TransportError(f"Request failed: {e}", cause=e) from e

        try:
            data = response.json()
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            raise TransportError(f"Invalid JSON response: {e}", cause=e) from e

        return self._unwrap_result(data)

    async def receive_response(self) -> dict[str, Any]:
        """Return the next message received over the SSE stream.

        Returns:
            A JSON-RPC response message.

        Raises:
            TransportError: not connected, or no message arrived within the
                configured timeout.
        """
        if not self.is_connected:
            raise TransportError("Transport not connected")

        try:
            return await asyncio.wait_for(
                self._response_queue.get(),
                timeout=self._timeout,
            )
        except asyncio.TimeoutError as e:
            raise TransportError("Timeout waiting for SSE response", cause=e) from e

View File

@ -0,0 +1,347 @@
"""Tests for EvolutionMixin - 进化引擎与 Agent 生命周期集成"""
import pytest
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
from agentkit.evolution.evolution_store import EvolutionStore
from agentkit.evolution.lifecycle import EvolutionLogEntry, EvolutionMixin
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer, Signature
from agentkit.evolution.reflector import Reflection, Reflector
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
from datetime import datetime, timezone
def _make_task() -> TaskMessage:
    """Build the fixed ``test-001`` echo task shared by these tests."""
    task_fields = {
        "task_id": "test-001",
        "agent_name": "evolving_agent",
        "task_type": "echo",
        "priority": 0,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": datetime.now(timezone.utc),
    }
    return TaskMessage(**task_fields)
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
    """Build a result for task ``test-001`` with the given status."""
    started = datetime.now(timezone.utc)
    finished = datetime.now(timezone.utc)
    result_fields = {
        "task_id": "test-001",
        "agent_name": "evolving_agent",
        "status": status,
        "output_data": {"key": "value"},
        "error_message": None,
        "started_at": started,
        "completed_at": finished,
        "metrics": {"elapsed_seconds": 5.0},
    }
    return TaskResult(**result_fields)
def _make_module() -> Module:
    """Build the baseline ``test_module`` prompt module used by these tests."""
    signature = Signature(
        input_fields={"query": "search query"},
        output_fields={"result": "search result"},
        instruction="Find the best result.",
    )
    return Module(name="test_module", signature=signature)
# ── EvolutionMixin 与 Agent on_task_complete 集成 ──────────────
class EvolvingAgent(EvolutionMixin):
    """Minimal stand-in for an agent that mixes in EvolutionMixin."""

    def __init__(self, reflector=None, prompt_optimizer=None, ab_tester=None, evolution_store=None):
        collaborators = {
            "reflector": reflector,
            "prompt_optimizer": prompt_optimizer,
            "ab_tester": ab_tester,
            "evolution_store": evolution_store,
        }
        super().__init__(**collaborators)
        self.name = "evolving_agent"
        # Set to True once on_task_complete has run the evolution pipeline.
        self.evolve_called = False

    async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
        """Trigger the evolution pipeline once a task has completed."""
        await self.evolve_after_task(task, _make_result())
        self.evolve_called = True
@pytest.mark.asyncio
async def test_mixin_integrates_with_on_task_complete():
    """EvolutionMixin is driven through the agent's on_task_complete hook."""
    agent = EvolvingAgent(reflector=Reflector())
    agent.set_current_module(_make_module())

    await agent.on_task_complete(_make_task(), {"key": "value"})

    assert agent.evolve_called is True
    records = agent.get_evolution_history()
    assert len(records) == 1
    assert records[0]["task_id"] == "test-001"
# ── Reflector 生成反思 ──────────────────────────────────────
@pytest.mark.asyncio
async def test_reflector_generates_reflection_after_task():
    """The reflector produces a reflection once the task has completed."""
    mixin = EvolutionMixin(reflector=Reflector())
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    reflection = entry.reflection
    assert reflection is not None
    assert reflection.outcome == "success"
    assert reflection.quality_score > 0
# ── Prompt 优化在有改进建议时触发 ──────────────────────────────
class LowQualityReflector(Reflector):
    """Reflector that always reports a low-quality failure with a suggestion."""

    async def reflect(self, task, result):
        canned = {
            "task_id": task.task_id,
            "agent_name": result.agent_name,
            "outcome": "failure",
            "quality_score": 0.2,
            "patterns": ["slow_execution"],
            "insights": ["Low quality score indicates potential issues"],
            "suggestions": ["Consider prompt optimization for this task type"],
        }
        return Reflection(**canned)
@pytest.mark.asyncio
async def test_prompt_optimization_triggered_when_reflection_suggests_improvement():
    """Suggestions in the reflection trigger prompt optimization."""
    optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
    # Seed enough successful examples for the optimizer to act on.
    for idx in (0, 1, 2):
        optimizer.add_example(
            input_data={"query": f"q_{idx}"},
            output_data={"result": f"r_{idx}"},
            quality_score=0.9,
        )
    mixin = EvolutionMixin(reflector=LowQualityReflector(), prompt_optimizer=optimizer)
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    assert entry.reflection is not None
    assert len(entry.reflection.suggestions) > 0
    assert entry.optimized_module is not None
    assert entry.optimized_module.name == "test_module_optimized"
@pytest.mark.asyncio
async def test_no_optimization_when_no_suggestions():
    """Without suggestions in the reflection, no optimization is triggered."""
    # This test relies on the stock Reflector giving no suggestions for a
    # successful task.
    mixin = EvolutionMixin(reflector=Reflector(), prompt_optimizer=PromptOptimizer())
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    assert entry.reflection is not None
    assert entry.optimized_module is None
# ── AB 测试验证 ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_ab_test_validation_before_applying():
    """An A/B test is run to validate a change before it is applied."""
    optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
    for idx in (0, 1, 2):
        optimizer.add_example(
            input_data={"query": f"q_{idx}"},
            output_data={"result": f"r_{idx}"},
            quality_score=0.9,
        )
    mixin = EvolutionMixin(
        reflector=LowQualityReflector(),
        prompt_optimizer=optimizer,
        ab_tester=ABTester(),
    )
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    ab_result = entry.ab_test_result
    assert ab_result is not None
    assert ab_result.test_id.startswith("evolve_")
# ── AB 测试失败时回滚 ──────────────────────────────────────
class FailingABTester(ABTester):
    """A/B tester whose evaluation always declares the control arm the winner."""

    async def evaluate(self, test_id: str) -> ABTestResult | None:
        verdict = {
            "test_id": test_id,
            "control_metric": 0.8,
            "experiment_metric": 0.5,
            "control_samples": 30,
            "experiment_samples": 30,
            "is_significant": True,
            "winner": "control",
            "p_value": 0.01,
        }
        return ABTestResult(**verdict)
@pytest.mark.asyncio
async def test_rollback_when_ab_test_shows_degradation():
    """A rollback is performed when the A/B test shows a degradation."""
    optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
    for idx in (0, 1, 2):
        optimizer.add_example(
            input_data={"query": f"q_{idx}"},
            output_data={"result": f"r_{idx}"},
            quality_score=0.9,
        )
    mixin = EvolutionMixin(
        reflector=LowQualityReflector(),
        prompt_optimizer=optimizer,
        ab_tester=FailingABTester(),
    )
    baseline = _make_module()
    mixin.set_current_module(baseline)

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    assert entry.rolled_back is True
    assert entry.applied is False
    # The current module must not have been replaced.
    assert mixin._current_module.name == "test_module"
# ── 进化历史记录 ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_evolution_history_is_recorded():
    """A pipeline run is recorded in the evolution history."""
    mixin = EvolutionMixin(reflector=Reflector())
    mixin.set_current_module(_make_module())

    await mixin.evolve_after_task(_make_task(), _make_result())

    records = mixin.get_evolution_history()
    assert len(records) == 1
    only = records[0]
    assert only["task_id"] == "test-001"
    assert "reflection" in only
    assert only["reflection"]["outcome"] == "success"
@pytest.mark.asyncio
async def test_evolution_history_multiple_entries():
    """Several pipeline runs produce several history records, in order."""
    mixin = EvolutionMixin(reflector=Reflector())
    mixin.set_current_module(_make_module())

    for idx in range(3):
        task_id = f"test-{idx:03d}"
        task = TaskMessage(
            task_id=task_id,
            agent_name="evolving_agent",
            task_type="echo",
            priority=0,
            input_data={"query": f"hello_{idx}"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )
        result = TaskResult(
            task_id=task_id,
            agent_name="evolving_agent",
            status=TaskStatus.COMPLETED,
            output_data={"key": "value"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"elapsed_seconds": 5.0},
        )
        await mixin.evolve_after_task(task, result)

    records = mixin.get_evolution_history()
    assert len(records) == 3
    assert records[0]["task_id"] == "test-000"
    assert records[1]["task_id"] == "test-001"
    assert records[2]["task_id"] == "test-002"
# ── 无组件配置时的优雅降级 ──────────────────────────────────────
@pytest.mark.asyncio
async def test_no_reflector_skips_evolution():
    """Without a reflector the evolution pipeline is skipped."""
    mixin = EvolutionMixin()
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    assert entry.reflection is None
    assert entry.applied is False
@pytest.mark.asyncio
async def test_no_evolution_store_applies_directly():
    """Without an EvolutionStore the change is applied in memory."""
    optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=1)
    for idx in (0, 1, 2):
        optimizer.add_example(
            input_data={"query": f"q_{idx}"},
            output_data={"result": f"r_{idx}"},
            quality_score=0.9,
        )
    mixin = EvolutionMixin(reflector=LowQualityReflector(), prompt_optimizer=optimizer)
    mixin.set_current_module(_make_module())

    entry = await mixin.evolve_after_task(_make_task(), _make_result())

    # Neither an A/B tester nor a store is configured: applied directly.
    assert entry.applied is True
    assert mixin._current_module.name == "test_module_optimized"

View File

@ -0,0 +1,462 @@
"""MCP Transport 层单元测试"""
import asyncio
import json
import httpx
import pytest
from agentkit.mcp.transport import HTTPTransport, SSETransport, TransportError
# ── HTTPTransport 测试 ──────────────────────────────────────────
class TestHTTPTransport:
"""HTTPTransport 测试"""
async def test_connect_creates_client(self):
transport = HTTPTransport(endpoint="http://localhost:8080")
assert not transport.is_connected
await transport.connect()
assert transport.is_connected
await transport.disconnect()
assert not transport.is_connected
async def test_disconnect_is_idempotent(self):
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
await transport.disconnect()
# 再次 disconnect 不应报错
await transport.disconnect()
async def test_connect_is_idempotent(self):
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
await transport.connect() # 不应报错
assert transport.is_connected
await transport.disconnect()
async def test_send_request_not_connected_raises(self):
transport = HTTPTransport(endpoint="http://localhost:8080")
with pytest.raises(TransportError, match="not connected"):
await transport.send_request("tools/list")
async def test_send_request_with_mock_server(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
json={
"jsonrpc": "2.0",
"id": 1,
"result": {"tools": [{"name": "echo"}]},
},
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
result = await transport.send_request("tools/list")
assert result == {"tools": [{"name": "echo"}]}
await transport.disconnect()
async def test_send_request_with_params(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
json={
"jsonrpc": "2.0",
"id": 1,
"result": {"content": [{"type": "text", "text": "hello"}]},
},
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
result = await transport.send_request(
"tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
)
assert result == {"content": [{"type": "text", "text": "hello"}]}
# 验证请求体
request = httpx_mock.get_request()
body = json.loads(request.content)
assert body["jsonrpc"] == "2.0"
assert body["method"] == "tools/call"
assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}
await transport.disconnect()
async def test_send_request_json_rpc_error(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
json={
"jsonrpc": "2.0",
"id": 1,
"error": {"code": -32600, "message": "Invalid Request"},
},
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
with pytest.raises(TransportError, match="JSON-RPC error"):
await transport.send_request("invalid/method")
await transport.disconnect()
async def test_send_request_http_error(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
status_code=500,
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
with pytest.raises(TransportError, match="HTTP error 500"):
await transport.send_request("tools/list")
await transport.disconnect()
async def test_send_request_network_error(self, httpx_mock):
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
with pytest.raises(TransportError, match="Request failed"):
await transport.send_request("tools/list")
await transport.disconnect()
async def test_send_request_invalid_json_response(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
text="not json",
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
with pytest.raises(TransportError, match="Invalid JSON response"):
await transport.send_request("tools/list")
await transport.disconnect()
async def test_request_id_increments(self, httpx_mock):
httpx_mock.add_response(
url="http://localhost:8080/",
json={"jsonrpc": "2.0", "id": 1, "result": {}},
)
httpx_mock.add_response(
url="http://localhost:8080/",
json={"jsonrpc": "2.0", "id": 2, "result": {}},
)
transport = HTTPTransport(endpoint="http://localhost:8080")
await transport.connect()
await transport.send_request("method1")
await transport.send_request("method2")
requests = httpx_mock.get_requests()
body1 = json.loads(requests[0].content)
body2 = json.loads(requests[1].content)
assert body1["id"] == 1
assert body2["id"] == 2
await transport.disconnect()
async def test_receive_response_no_pending_raises(self):
    """receive_response() with nothing sent must raise TransportError.

    The error message is expected to match ``"No pending response"``.
    """
    transport = HTTPTransport(endpoint="http://localhost:8080")
    await transport.connect()
    try:
        with pytest.raises(TransportError, match="No pending response"):
            await transport.receive_response()
    finally:
        # Tear down even when the expectation above fails.
        await transport.disconnect()
async def test_custom_headers(self, httpx_mock):
    """Headers passed to the constructor must appear on the outgoing request.

    The recorded request is checked for the ``Authorization`` header value.
    """
    httpx_mock.add_response(
        url="http://localhost:8080/",
        json={"jsonrpc": "2.0", "id": 1, "result": {}},
    )
    transport = HTTPTransport(
        endpoint="http://localhost:8080",
        headers={"Authorization": "Bearer test-token"},
    )
    await transport.connect()
    try:
        await transport.send_request("tools/list")
        request = httpx_mock.get_request()
        assert request.headers.get("authorization") == "Bearer test-token"
    finally:
        # Guard the teardown so a failed header assertion does not leave
        # the transport open.
        await transport.disconnect()
async def test_custom_timeout(self):
    """The ``timeout`` constructor argument must reach the underlying client.

    Inspects the private ``_client`` attribute after connecting and checks
    that its ``timeout.read`` equals the configured 5.0.
    """
    transport = HTTPTransport(endpoint="http://localhost:8080", timeout=5.0)
    await transport.connect()
    try:
        assert transport._client is not None
        assert transport._client.timeout.read == 5.0
    finally:
        # Guard the teardown so a failed assertion does not leave the
        # client open.
        await transport.disconnect()
# ── SSETransport 测试 ──────────────────────────────────────────
class TestSSETransport:
    """Tests for SSETransport.

    Every test that connects wraps its body in ``try/finally`` so the
    transport is disconnected even when an assertion or expectation fails.
    """

    async def test_connect_sets_connected(self):
        """is_connected is False before connect, True after, False after disconnect."""
        transport = SSETransport(endpoint="http://localhost:8080")
        assert not transport.is_connected
        await transport.connect()
        try:
            assert transport.is_connected
        finally:
            await transport.disconnect()
        assert not transport.is_connected

    async def test_disconnect_cancels_sse_task(self):
        """connect() sets the private ``_sse_task``; disconnect() resets it to None."""
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            assert transport._sse_task is not None
        finally:
            await transport.disconnect()
        assert transport._sse_task is None

    async def test_disconnect_is_idempotent(self):
        """Calling disconnect() twice in a row must not raise."""
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.disconnect()
        await transport.disconnect()  # second call must not raise

    async def test_send_request_not_connected_raises(self):
        """send_request() on a never-connected transport raises TransportError."""
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_receive_response_not_connected_raises(self):
        """receive_response() on a never-connected transport raises TransportError."""
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.receive_response()

    async def test_send_request_with_mock_server(self, httpx_mock):
        """send_request() returns the ``result`` member of the mocked JSON-RPC reply."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"status": "ok"},
            },
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            result = await transport.send_request(
                "initialize", params={"protocol": "2024-11-05"}
            )
            assert result == {"status": "ok"}
        finally:
            await transport.disconnect()

    async def test_send_request_http_error(self, httpx_mock):
        """A mocked HTTP 503 reply must surface as a TransportError."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            status_code=503,
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="HTTP error 503"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_network_error(self, httpx_mock):
        """A mocked ``httpx.ConnectError`` must surface as a TransportError."""
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Request failed"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_json_rpc_error(self, httpx_mock):
        """A JSON-RPC ``error`` member in the reply must raise TransportError."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32601, "message": "Method not found"},
            },
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("unknown/method")
        finally:
            await transport.disconnect()

    async def test_receive_response_from_sse_stream(self):
        """receive_response() returns a message placed on the response queue.

        The SSE listener is simulated by putting the message directly into
        the private ``_response_queue``.
        """
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            # Stand in for the SSE listener: enqueue the message ourselves.
            await transport._response_queue.put(
                {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
            )
            response = await asyncio.wait_for(
                transport.receive_response(), timeout=2.0
            )
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
        finally:
            await transport.disconnect()

    async def test_receive_response_timeout(self):
        """receive_response() raises TransportError matching "Timeout" when nothing arrives."""
        transport = SSETransport(endpoint="http://localhost:8080", timeout=0.1)
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Timeout"):
                await transport.receive_response()
        finally:
            await transport.disconnect()

    async def test_custom_paths(self, httpx_mock):
        """A custom ``message_path`` is the path the request is sent to."""
        httpx_mock.add_response(
            url="http://localhost:8080/custom-message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint="http://localhost:8080",
            sse_path="/custom-sse",
            message_path="/custom-message",
        )
        await transport.connect()
        try:
            await transport.send_request("test")
            request = httpx_mock.get_request()
            assert request.url.path == "/custom-message"
        finally:
            await transport.disconnect()

    async def test_custom_headers(self, httpx_mock):
        """Headers passed to the constructor must appear on the outgoing request."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint="http://localhost:8080",
            headers={"Authorization": "Bearer sse-token"},
        )
        await transport.connect()
        try:
            await transport.send_request("test")
            request = httpx_mock.get_request()
            assert request.headers.get("authorization") == "Bearer sse-token"
        finally:
            await transport.disconnect()

    async def test_sse_ignores_comments_and_empty_lines(self):
        """Receive a message that was injected directly into the response queue.

        NOTE(review): despite the name, no SSE comment lines or blank lines
        are fed through the listener here; the test only round-trips one
        already-parsed message via ``_response_queue``. Consider driving the
        real stream parser to cover the filtering itself.
        """
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            # Stand in for the listener's output after it has filtered the
            # stream: enqueue the resulting message ourselves.
            await transport._response_queue.put(
                {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
            )
            response = await asyncio.wait_for(
                transport.receive_response(), timeout=2.0
            )
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
        finally:
            await transport.disconnect()
# ── Transport 生命周期测试 ──────────────────────────────────────
class TestTransportLifecycle:
    """Transport lifecycle tests: connect, send a request, disconnect.

    Each connected phase is wrapped in ``try/finally`` so the transport is
    disconnected even when an assertion fails.
    """

    async def test_http_transport_full_lifecycle(self, httpx_mock):
        """HTTPTransport: connect, send one request, disconnect."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        # 1. Connect
        await transport.connect()
        try:
            assert transport.is_connected
            # 2. Send a request
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect
            await transport.disconnect()
        assert not transport.is_connected

    async def test_sse_transport_full_lifecycle(self, httpx_mock):
        """SSETransport: connect, send one request, disconnect."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        # 1. Connect
        await transport.connect()
        try:
            assert transport.is_connected
            # 2. Send a request
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect
            await transport.disconnect()
        assert not transport.is_connected

    async def test_reconnect_after_disconnect(self, httpx_mock):
        """An HTTPTransport can be connected and used again after disconnecting."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        # First connection
        await transport.connect()
        try:
            result1 = await transport.send_request("method1")
            assert result1 == {"first": True}
        finally:
            await transport.disconnect()
        # Reconnect on the same instance
        await transport.connect()
        try:
            result2 = await transport.send_request("method2")
            assert result2 == {"second": True}
        finally:
            await transport.disconnect()