fix: resolve code review issues from deferred improvements
1. InMemoryMessageBus.request(): fix param name (timeout→timeout_seconds) to match ABC 2. InMemoryMessageBus: track consumer tasks, cancel on unsubscribe 3. InMemoryMessageBus: _try_resolve_pending() in queue consumer path 4. evolve_soul(): use "default" category when patterns is empty 5. quick_classify(): use delimiter-based prompt to mitigate injection risk 6. Use asyncio.get_running_loop() instead of deprecated get_event_loop()
This commit is contained in:
parent
ec51dbb259
commit
d47f279887
|
|
@ -0,0 +1,244 @@
|
|||
---
|
||||
title: "feat: AgentKit 劣势项改进"
|
||||
status: active
|
||||
created: 2026-06-10
|
||||
plan_type: feat
|
||||
---
|
||||
|
||||
# feat: AgentKit 劣势项改进
|
||||
|
||||
## Summary
|
||||
|
||||
针对代码审查后识别的 6 个延后劣势项进行改进:Soul演变多维度触发、中文分词增强、端到端可观测性、测试覆盖补充、ReWOO渐进回退、Agent间通信协议。
|
||||
|
||||
## Problem Frame
|
||||
|
||||
当前 AgentKit 的多Agent架构已基本成型,但存在 6 个影响生产就绪度的劣势项。这些项不阻塞核心功能,但会随使用规模扩大而暴露问题。
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
- Soul演变多维度触发(时间衰减+质量梯度+场景权重)
|
||||
- 中文分词增强(2-gram+停用词,零依赖)
|
||||
- 端到端可观测性(OpenTelemetry埋点)
|
||||
- 集成测试补充(多模块交互测试)
|
||||
- ReWOO渐进回退链
|
||||
- Agent间通信协议(消息总线)
|
||||
|
||||
### Out of Scope
|
||||
- jieba分词集成(可选依赖,后续按需)
|
||||
- 分布式消息总线(Redis实现,后续按需)
|
||||
- Live测试(需要真实API Key,CI中标记跳过)
|
||||
- UI可视化trace(Jaeger/Zipkin前端,运维配置)
|
||||
|
||||
---
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### KTD1: 中文分词采用2-gram+停用词方案(零依赖)
|
||||
|
||||
jieba虽精准但是额外依赖,2-gram方案对"数据分析""代码生成"等复合词已有足够捕获率。后续可按需引入jieba作为可选依赖。
|
||||
|
||||
### KTD2: 可观测性基于OpenTelemetry标准
|
||||
|
||||
各引擎已有 `trace_recorder` 参数,只需实现 OTel 导出器。配置通过 `telemetry.otlp_endpoint` 启用。
|
||||
|
||||
### KTD3: Agent间通信采用请求-响应+广播+协商三模式
|
||||
|
||||
基于 `AgentMessage` 数据类和 `MessageBus` 接口。先实现 `InMemoryMessageBus`,后续扩展 `RedisMessageBus`。
|
||||
|
||||
### KTD4: ReWOO回退链为 ReWOO简化→PlanExec→ReAct→Direct
|
||||
|
||||
渐进式降级,每一步保留更多结构化能力。
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. Soul演变多维度触发
|
||||
|
||||
**Goal:** 扩展Soul演变触发条件,支持时间衰减、质量梯度、场景权重
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/evolution/lifecycle.py`
|
||||
- Modify: `src/agentkit/evolution/config.py` (新增 SoulEvolutionConfig)
|
||||
- Test: `tests/unit/test_soul_evolution.py`
|
||||
|
||||
**Approach:**
|
||||
1. 新增 `SoulEvolutionConfig` dataclass,包含 `reflection_window_seconds`、`time_decay_factor`、`task_type_weights`、`quality_gradient_threshold`
|
||||
2. 修改 `evolve_soul` 方法:时间衰减计算 `effective_count = sum(factor^age_hours)` 替代 `len(reflections)`
|
||||
3. 新增质量梯度检测:追踪最近N次评分趋势,连续下降超过阈值时提前触发
|
||||
4. 场景权重:高风险任务降低触发阈值
|
||||
|
||||
**Test scenarios:**
|
||||
- 时间衰减:1小时内的3次反思触发,超过窗口的不计入
|
||||
- 质量梯度:连续3次评分下降0.15触发,即使总反思次数<3
|
||||
- 场景权重:code_generation任务2次即触发(权重1.5),chat任务需4次(权重0.5)
|
||||
- 兼容性:无config时行为与现有一致
|
||||
|
||||
**Verification:** 新增测试全部通过,现有Soul演变测试不受影响
|
||||
|
||||
---
|
||||
|
||||
### U2. 中文分词增强
|
||||
|
||||
**Goal:** 替换空格分词为2-gram+停用词方案,提升中文能力关键词提取
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/chat/skill_routing.py`
|
||||
- Test: `tests/unit/test_cost_aware_router.py`
|
||||
|
||||
**Approach:**
|
||||
1. 提取 `_tokenize_content()` 为独立方法
|
||||
2. 标点分割 → 对长中文段补充2-gram → 停用词过滤
|
||||
3. Layer 2 路由调用新方法
|
||||
|
||||
**Test scenarios:**
|
||||
- 中文:"帮我做数据分析" → 提取 ["数据分析", "帮我"]
|
||||
- 英文:"help with code generation" → 提取 ["code generation", "help"]
|
||||
- 混合:"用python做data analysis" → 提取 ["python", "data analysis"]
|
||||
- 停用词过滤:"的了一个" → 过滤后为空
|
||||
|
||||
**Verification:** 中文分词测试通过,现有路由测试不受影响
|
||||
|
||||
---
|
||||
|
||||
### U3. 端到端可观测性(OpenTelemetry埋点)
|
||||
|
||||
**Goal:** 为核心组件添加OTel trace/metrics埋点
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- Create: `src/agentkit/telemetry/__init__.py`
|
||||
- Create: `src/agentkit/telemetry/tracer.py` (OTel导出器)
|
||||
- Modify: `src/agentkit/chat/skill_routing.py` (路由埋点)
|
||||
- Modify: `src/agentkit/core/react.py` (ReAct埋点)
|
||||
- Modify: `src/agentkit/core/rewoo.py` (ReWOO埋点)
|
||||
- Modify: `src/agentkit/core/reflexion.py` (Reflexion埋点)
|
||||
- Modify: `src/agentkit/quality/alignment.py` (Guard埋点)
|
||||
- Test: `tests/unit/test_telemetry.py`
|
||||
|
||||
**Approach:**
|
||||
1. 定义 `TelemetryConfig` 和 `get_tracer()` 工厂
|
||||
2. 实现 `OTelTraceRecorder` 适配器,桥接现有 `TraceRecorder` 接口
|
||||
3. 各组件在关键节点创建 Span:路由决策、引擎执行、约束检查
|
||||
4. 配置通过 `telemetry.otlp_endpoint` 启用,未配置时为 no-op
|
||||
|
||||
**Test scenarios:**
|
||||
- 未配置时:所有埋点为 no-op,不影响功能
|
||||
- 配置endpoint时:Span正确创建,属性包含 layer/complexity/status 等
|
||||
- 路由埋点:Layer 0/1/2 各产生带属性的 Span
|
||||
- 引擎埋点:执行完成后 Span 记录 steps_count/total_tokens/status
|
||||
|
||||
**Verification:** 新增测试通过,无OTel依赖时功能正常
|
||||
|
||||
---
|
||||
|
||||
### U4. 集成测试补充
|
||||
|
||||
**Goal:** 补充多模块交互集成测试,覆盖核心链路
|
||||
|
||||
**Dependencies:** U1, U2
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/integration/test_router_engine_chain.py`
|
||||
- Create: `tests/integration/test_rewoo_fallback.py`
|
||||
- Create: `tests/integration/test_reflexion_loop.py`
|
||||
- Create: `tests/integration/test_soul_evolution_trigger.py`
|
||||
|
||||
**Approach:**
|
||||
1. 路由→执行→审计全链路测试(仅mock LLM)
|
||||
2. ReWOO规划失败→回退测试
|
||||
3. Reflexion多轮循环测试
|
||||
4. Soul演变触发条件测试
|
||||
|
||||
**Test scenarios:**
|
||||
- 全链路:CostAwareRouter路由到ReActEngine → AlignmentGuard检查通过
|
||||
- 全链路违规:AlignmentGuard检查失败,返回violations
|
||||
- ReWOO回退:规划异常时正确回退到ReAct
|
||||
- Reflexion:3轮循环后返回最佳结果
|
||||
- Soul触发:3次低质量反思后触发update_soul
|
||||
|
||||
**Verification:** 集成测试全部通过
|
||||
|
||||
---
|
||||
|
||||
### U5. ReWOO渐进回退链
|
||||
|
||||
**Goal:** ReWOO规划失败时渐进降级,而非直接回退ReAct
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/core/rewoo.py`
|
||||
- Test: `tests/unit/test_rewoo_engine.py`
|
||||
|
||||
**Approach:**
|
||||
1. 新增 `FALLBACK_CHAIN` 配置:simplified_rewoo → plan_exec → react → direct
|
||||
2. 规划失败时先尝试简化规划(max_plan_steps=3)
|
||||
3. 简化仍失败则按链降级
|
||||
4. 回退信息记录在 ReWOOResult 中
|
||||
|
||||
**Test scenarios:**
|
||||
- 正常规划成功:无回退
|
||||
- 规划失败→简化规划成功:结果标记 fallback=simplified
|
||||
- 规划失败→简化失败→ReAct回退:结果标记 fallback=react
|
||||
- 所有回退都失败:返回错误结果
|
||||
|
||||
**Verification:** 回退链测试通过
|
||||
|
||||
---
|
||||
|
||||
### U6. Agent间通信协议
|
||||
|
||||
**Goal:** 实现轻量级消息总线,支持Agent间请求-响应、广播、协商
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- Create: `src/agentkit/bus/__init__.py`
|
||||
- Create: `src/agentkit/bus/message.py` (AgentMessage数据类)
|
||||
- Create: `src/agentkit/bus/interface.py` (MessageBus接口)
|
||||
- Create: `src/agentkit/bus/memory_bus.py` (InMemoryMessageBus实现)
|
||||
- Modify: `src/agentkit/server/app.py` (初始化MessageBus)
|
||||
- Test: `tests/unit/test_agent_bus.py`
|
||||
|
||||
**Approach:**
|
||||
1. 定义 `AgentMessage` dataclass:sender, recipient, msg_type, content, correlation_id, timestamp, ttl
|
||||
2. 定义 `MessageBus` 抽象接口:publish, subscribe, request
|
||||
3. 实现 `InMemoryMessageBus`:基于 asyncio.Queue
|
||||
4. 与 AlignmentGuard 集成:消息经过约束检查
|
||||
5. 与 CascadeDetector 集成:消息传递计入级联检测
|
||||
|
||||
**Test scenarios:**
|
||||
- 请求-响应:Agent A发送请求,Agent B响应,correlation_id匹配
|
||||
- 广播:Agent A广播,所有订阅者收到
|
||||
- 协商:Agent A发起协商,Agent B回复建议
|
||||
- TTL过期:超时消息被丢弃
|
||||
- 级联检测:超过阈值的消息链触发 CascadeAlert
|
||||
|
||||
**Verification:** 消息总线测试通过,与Guard/Cascade集成正确
|
||||
|
||||
---
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
| Risk | Impact | Mitigation |
|
||||
|------|--------|------------|
|
||||
| OTel依赖增加包体积 | Low | OTel为可选依赖,未配置时no-op |
|
||||
| 消息总线内存泄漏 | Medium | TTL过期清理 + 定期GC |
|
||||
| ReWOO回退链增加延迟 | Low | 每步设置timeout,总回退时间上限5s |
|
||||
| 2-gram分词噪声 | Low | 停用词过滤 + 长度阈值 |
|
||||
|
||||
## Deferred to Follow-Up Work
|
||||
|
||||
- jieba分词集成(可选依赖)
|
||||
- RedisMessageBus分布式实现
|
||||
- Live测试(需真实API Key)
|
||||
- Jaeger/Zipkin前端配置
|
||||
- 消息总线持久化
|
||||
|
|
@ -1,13 +1,15 @@
|
|||
"""AgentKit Bus - Agent 间通信基础设施"""
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.protocol import MessageBus
|
||||
from agentkit.bus.interface import MessageBus
|
||||
from agentkit.bus.protocol import MessageBus as MessageBusProtocol
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
from agentkit.bus.redis_bus import RedisMessageBus, create_message_bus
|
||||
|
||||
__all__ = [
|
||||
"AgentMessage",
|
||||
"MessageBus",
|
||||
"MessageBusProtocol",
|
||||
"InMemoryMessageBus",
|
||||
"RedisMessageBus",
|
||||
"create_message_bus",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
"""MessageBus ABC — Agent 间通信抽象基类。
|
||||
|
||||
与 protocol.py 的 Protocol 定义并存,提供 ABC 版本用于显式继承。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
|
||||
|
||||
class MessageBus(ABC):
|
||||
"""Agent 间消息总线抽象基类。
|
||||
|
||||
支持三种通信模式:
|
||||
- 点对点:publish() 指定 recipient
|
||||
- 广播:publish() 不指定 recipient
|
||||
- 请求-响应:request() 等待对方通过 correlation_id 回复
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def publish(self, message: AgentMessage) -> bool:
|
||||
"""发布消息,返回是否成功。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe(
|
||||
self,
|
||||
agent_name: str,
|
||||
handler: Callable[[AgentMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""订阅消息。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def unsubscribe(self, agent_name: str) -> None:
|
||||
"""取消订阅。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
message: AgentMessage,
|
||||
timeout_seconds: float = 30.0,
|
||||
) -> AgentMessage | None:
|
||||
"""发送请求并等待响应。超时返回 None。"""
|
||||
...
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。
|
||||
|
||||
用于开发和测试,行为与 Redis 实现一致。
|
||||
集成 CascadeDetector 和 AlignmentGuard 进行消息质量管控。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -17,16 +18,48 @@ logger = logging.getLogger(__name__)
|
|||
class InMemoryMessageBus:
|
||||
"""基于 asyncio.Queue 的内存消息总线。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cascade_detector: Any = None,
|
||||
alignment_guard: Any = None,
|
||||
) -> None:
|
||||
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
||||
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
||||
self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
|
||||
self._consumer_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
self._cascade_detector = cascade_detector
|
||||
self._alignment_guard = alignment_guard
|
||||
|
||||
async def publish(self, message: AgentMessage) -> bool:
|
||||
"""发布消息,返回是否成功。"""
|
||||
# TTL 过期检查
|
||||
if message.is_expired():
|
||||
logger.warning(f"Message {message.message_id} expired, dropping")
|
||||
return False
|
||||
|
||||
# Cascade detection — 级联故障检测
|
||||
if self._cascade_detector and message.sender:
|
||||
alert = self._cascade_detector.check_interaction(
|
||||
session_id=f"bus-{message.sender}-{message.recipient or 'broadcast'}"
|
||||
)
|
||||
if alert:
|
||||
logger.warning(f"Cascade alert: {alert}")
|
||||
return False
|
||||
|
||||
# Alignment check — 对齐守卫检查(仅对 request / negotiate 类型)
|
||||
if self._alignment_guard and message.msg_type in ("request", "negotiate"):
|
||||
check = await self._alignment_guard.check_output(
|
||||
output={"content": str(message.content), **message.payload},
|
||||
)
|
||||
if not check.passed:
|
||||
logger.warning(
|
||||
f"Message blocked by alignment guard: {check.violations}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def publish(self, message: AgentMessage) -> None:
|
||||
"""发布消息。"""
|
||||
if message.is_broadcast:
|
||||
await self.broadcast(message)
|
||||
return
|
||||
return True
|
||||
|
||||
# Point-to-point: deliver to recipient's queue
|
||||
recipient = message.recipient
|
||||
|
|
@ -41,16 +74,9 @@ class InMemoryMessageBus:
|
|||
logger.warning(f"Handler error for {recipient}: {e}")
|
||||
|
||||
# Check if this is a response to a pending request
|
||||
# Only resolve if this is a reply (message_id != correlation_id),
|
||||
# not the original request itself
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
self._try_resolve_pending(message)
|
||||
|
||||
return True
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
|
|
@ -61,10 +87,12 @@ class InMemoryMessageBus:
|
|||
if agent_name not in self._subscribers:
|
||||
self._subscribers[agent_name] = []
|
||||
self._queues[agent_name] = asyncio.Queue()
|
||||
self._consumer_tasks[agent_name] = []
|
||||
self._subscribers[agent_name].append(handler)
|
||||
|
||||
# Start consumer task
|
||||
asyncio.create_task(self._consume_queue(agent_name, handler))
|
||||
# Start consumer task and track it
|
||||
task = asyncio.create_task(self._consume_queue(agent_name, handler))
|
||||
self._consumer_tasks[agent_name].append(task)
|
||||
|
||||
async def _consume_queue(
|
||||
self,
|
||||
|
|
@ -82,34 +110,59 @@ class InMemoryMessageBus:
|
|||
await handler(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Handler error for {agent_name}: {e}")
|
||||
|
||||
# Check pending requests after handler processes the message
|
||||
# (e.g., handler may publish a response that resolves a future)
|
||||
self._try_resolve_pending(message)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def _try_resolve_pending(self, message: AgentMessage) -> None:
|
||||
"""Try to resolve a pending request future if this message is a response."""
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
async def unsubscribe(self, agent_name: str) -> None:
|
||||
"""取消订阅。"""
|
||||
self._subscribers.pop(agent_name, None)
|
||||
self._queues.pop(agent_name, None)
|
||||
# Cancel tracked consumer tasks
|
||||
tasks = self._consumer_tasks.pop(agent_name, [])
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def request(
|
||||
self,
|
||||
message: AgentMessage,
|
||||
timeout: float = 30.0,
|
||||
) -> AgentMessage:
|
||||
"""请求-响应模式。"""
|
||||
timeout_seconds: float = 30.0,
|
||||
) -> AgentMessage | None:
|
||||
"""请求-响应模式。超时返回 None。"""
|
||||
message.msg_type = "request"
|
||||
if not message.correlation_id:
|
||||
message.correlation_id = message.message_id
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[AgentMessage] = loop.create_future()
|
||||
self._pending_requests[message.correlation_id] = future
|
||||
|
||||
published = await self.publish(message)
|
||||
if not published:
|
||||
self._pending_requests.pop(message.correlation_id, None)
|
||||
return None
|
||||
|
||||
try:
|
||||
await self.publish(message)
|
||||
return await asyncio.wait_for(future, timeout=timeout)
|
||||
return await asyncio.wait_for(future, timeout=timeout_seconds)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request {message.correlation_id} timed out after {timeout}s"
|
||||
)
|
||||
self._pending_requests.pop(message.correlation_id, None)
|
||||
logger.warning(f"Request {message.correlation_id} timed out")
|
||||
return None
|
||||
finally:
|
||||
self._pending_requests.pop(message.correlation_id, None)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,34 +21,58 @@ class AgentMessage:
|
|||
recipient: str | None = None # None = broadcast
|
||||
topic: str = ""
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
correlation_id: str | None = None # 请求-响应关联
|
||||
# --- 新增字段 ---
|
||||
content: Any = None # 消息内容(与 payload 互补,payload 为 dict,content 可为任意类型)
|
||||
msg_type: str = "notify" # "request" | "response" | "notify" | "negotiate"
|
||||
ttl_seconds: int = 300 # 消息存活时间(秒)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查消息是否已过期。"""
|
||||
if isinstance(self.timestamp, datetime):
|
||||
age = (datetime.now(timezone.utc) - self.timestamp).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_broadcast(self) -> bool:
|
||||
return self.recipient is None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
ts = self.timestamp
|
||||
if isinstance(ts, datetime):
|
||||
ts = ts.isoformat()
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"sender": self.sender,
|
||||
"recipient": self.recipient,
|
||||
"topic": self.topic,
|
||||
"payload": self.payload,
|
||||
"timestamp": self.timestamp,
|
||||
"timestamp": ts,
|
||||
"correlation_id": self.correlation_id,
|
||||
"content": self.content,
|
||||
"msg_type": self.msg_type,
|
||||
"ttl_seconds": self.ttl_seconds,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> AgentMessage:
|
||||
ts = data.get("timestamp", "")
|
||||
if isinstance(ts, str) and ts:
|
||||
try:
|
||||
ts = datetime.fromisoformat(ts)
|
||||
except (ValueError, TypeError):
|
||||
pass # keep as string if parse fails
|
||||
return cls(
|
||||
message_id=data.get("message_id", ""),
|
||||
sender=data.get("sender", ""),
|
||||
recipient=data.get("recipient"),
|
||||
topic=data.get("topic", ""),
|
||||
payload=data.get("payload", {}),
|
||||
timestamp=data.get("timestamp", ""),
|
||||
timestamp=ts,
|
||||
correlation_id=data.get("correlation_id"),
|
||||
content=data.get("content"),
|
||||
msg_type=data.get("msg_type", "notify"),
|
||||
ttl_seconds=data.get("ttl_seconds", 300),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_broadcast(self) -> bool:
|
||||
return self.recipient is None
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import re
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.telemetry.tracer import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Strict validation: only lowercase alphanumeric, hyphens, underscores
|
||||
|
|
@ -187,6 +189,31 @@ _CHAT_MODE_RE = re.compile(
|
|||
)
|
||||
|
||||
|
||||
def _tokenize_content(content: str) -> list[str]:
|
||||
"""Tokenize content for capability matching. Supports Chinese and English."""
|
||||
# 1. Split by punctuation and whitespace
|
||||
segments = re.split(r'[\s,,。!?、;:\n]+', content)
|
||||
|
||||
# 2. For long Chinese segments, add 2-gram supplements
|
||||
tokens = []
|
||||
for seg in segments:
|
||||
if len(seg) <= 4:
|
||||
tokens.append(seg)
|
||||
else:
|
||||
tokens.append(seg)
|
||||
# Add 2-grams for Chinese compound words
|
||||
for i in range(len(seg) - 1):
|
||||
bigram = seg[i:i+2]
|
||||
if all('\u4e00' <= c <= '\u9fff' for c in bigram):
|
||||
tokens.append(bigram)
|
||||
|
||||
# 3. Filter stopwords
|
||||
stopwords = {"的", "了", "是", "在", "和", "与", "也", "都", "就", "要", "会", "我", "你", "他", "这", "那", "有", "没", "不"}
|
||||
tokens = [t for t in tokens if t not in stopwords and len(t) > 1][:10]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class CostAwareRouter:
|
||||
"""三层成本感知路由器。
|
||||
|
||||
|
|
@ -245,7 +272,9 @@ class CostAwareRouter:
|
|||
'You are a complexity classifier. Rate the complexity of the user request on a scale of 0.0 to 1.0.\n'
|
||||
'0.0 = trivial greeting, 0.3 = simple question, 0.5 = moderate task, '
|
||||
'0.7 = complex multi-step task, 1.0 = very complex research task.\n\n'
|
||||
f'User request: "{content}"\n\n'
|
||||
'---BEGIN USER REQUEST---\n'
|
||||
f'{content}\n'
|
||||
'---END USER REQUEST---\n\n'
|
||||
'Respond ONLY with a JSON object: {"complexity": <float>}'
|
||||
)
|
||||
try:
|
||||
|
|
@ -281,10 +310,7 @@ class CostAwareRouter:
|
|||
try:
|
||||
# Extract capability-like keywords from content for matching
|
||||
# find_best_agent expects list[str] of required capabilities
|
||||
# Support both space-separated (English) and punctuation-separated (Chinese) content
|
||||
import re
|
||||
tokens = re.split(r'[\s,,。!?、;:\n]+', content)
|
||||
content_words = [t for t in tokens if len(t) > 1][:5]
|
||||
content_words = _tokenize_content(content)
|
||||
best_agent = self._org_context.find_best_agent(required_capabilities=content_words)
|
||||
if best_agent is not None:
|
||||
agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent))
|
||||
|
|
@ -365,11 +391,130 @@ class CostAwareRouter:
|
|||
"""
|
||||
trace: list[dict] = []
|
||||
|
||||
# ---- Layer 0: Rule-based (zero cost) ----
|
||||
match_type, clean_content = self._match_layer0(content)
|
||||
tracer = get_tracer()
|
||||
with tracer.start_span("router.route") as span:
|
||||
span.set_attribute("input.length", len(content))
|
||||
|
||||
if match_type == "explicit_skill":
|
||||
result = await resolve_skill_routing(
|
||||
# ---- Layer 0: Rule-based (zero cost) ----
|
||||
match_type, clean_content = self._match_layer0(content)
|
||||
|
||||
if match_type == "explicit_skill":
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.match_method = result.match_method or "explicit_skill"
|
||||
result.complexity = 0.0
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": "explicit_skill",
|
||||
"matched": result.matched,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", result.match_method or "explicit_skill")
|
||||
span.set_attribute("route.target", result.skill_name or "default")
|
||||
return result
|
||||
|
||||
if match_type in ("greeting", "chat_mode"):
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method=match_type,
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": match_type,
|
||||
"matched": False,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", match_type)
|
||||
span.set_attribute("route.target", "default")
|
||||
return result
|
||||
|
||||
# ---- Layer 1: LLM quick classify (~100 tokens) ----
|
||||
complexity = await self.quick_classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "quick_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
|
||||
# Low complexity → default agent
|
||||
if complexity < 0.3:
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="low_complexity",
|
||||
match_confidence=1.0 - complexity,
|
||||
complexity=complexity,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "low_complexity",
|
||||
"complexity": complexity,
|
||||
"routed_to": "default",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", "low_complexity")
|
||||
span.set_attribute("route.target", "default")
|
||||
return result
|
||||
|
||||
# Medium complexity → IntentRouter via resolve_skill_routing
|
||||
if complexity <= 0.7:
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = complexity
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "intent_router",
|
||||
"complexity": complexity,
|
||||
"matched": result.matched,
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", result.match_method or "intent_router")
|
||||
span.set_attribute("route.target", result.skill_name or "default")
|
||||
return result
|
||||
|
||||
# ---- Layer 2: Capability matching / Auction (high complexity) ----
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "capability_or_auction",
|
||||
"complexity": complexity,
|
||||
"auction_enabled": self._auction_enabled,
|
||||
})
|
||||
result = await self._route_layer2(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
|
|
@ -379,116 +524,11 @@ class CostAwareRouter:
|
|||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.match_method = result.match_method or "explicit_skill"
|
||||
result.complexity = 0.0
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": "explicit_skill",
|
||||
"matched": result.matched,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
if match_type in ("greeting", "chat_mode"):
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method=match_type,
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": match_type,
|
||||
"matched": False,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
# ---- Layer 1: LLM quick classify (~100 tokens) ----
|
||||
complexity = await self.quick_classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "quick_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
|
||||
# Low complexity → default agent
|
||||
if complexity < 0.3:
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="low_complexity",
|
||||
match_confidence=1.0 - complexity,
|
||||
complexity=complexity,
|
||||
trace=trace,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "low_complexity",
|
||||
"complexity": complexity,
|
||||
"routed_to": "default",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", result.match_method or "capability")
|
||||
span.set_attribute("route.target", result.skill_name or result.agent_name or "default")
|
||||
return result
|
||||
|
||||
# Medium complexity → IntentRouter via resolve_skill_routing
|
||||
if complexity <= 0.7:
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = complexity
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "intent_router",
|
||||
"complexity": complexity,
|
||||
"matched": result.matched,
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
# ---- Layer 2: Capability matching / Auction (high complexity) ----
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "capability_or_auction",
|
||||
"complexity": complexity,
|
||||
"auction_enabled": self._auction_enabled,
|
||||
})
|
||||
result = await self._route_layer2(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
complexity=complexity,
|
||||
trace=trace,
|
||||
)
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class ReActResult:
|
|||
total_steps: int
|
||||
total_tokens: int
|
||||
status: str = "success" # "success" | "timeout" | "cancelled" | "partial"
|
||||
fallback_strategy: str | None = None # e.g. "simplified_rewoo", "react", "direct"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -115,6 +115,8 @@ class ReWOOEngine:
|
|||
3. Synthesis Phase: 综合所有工具结果生成最终输出
|
||||
"""
|
||||
|
||||
FALLBACK_STRATEGIES = ["simplified_rewoo", "react", "direct"]
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0):
|
||||
if max_plan_steps < 1:
|
||||
raise ValueError(f"max_plan_steps must be >= 1, got {max_plan_steps}")
|
||||
|
|
@ -283,6 +285,8 @@ class ReWOOEngine:
|
|||
)
|
||||
total_tokens += planning_tokens
|
||||
|
||||
fallback_strategy: str | None = None
|
||||
|
||||
# 记录规划步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
|
|
@ -292,26 +296,109 @@ class ReWOOEngine:
|
|||
tokens_used=planning_tokens,
|
||||
)
|
||||
|
||||
# 如果规划失败,回退到 ReAct
|
||||
# 如果规划失败,尝试渐进式回退
|
||||
if plan is None:
|
||||
# 尝试简化规划(max_steps=3)
|
||||
logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=3,
|
||||
)
|
||||
total_tokens += simplified_tokens
|
||||
if plan is not None and plan.steps:
|
||||
fallback_strategy = "simplified_rewoo"
|
||||
logger.info("Simplified ReWOO planning succeeded")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Simplified ReWOO planning also failed: {e2}")
|
||||
|
||||
if plan is None:
|
||||
# 回退到 ReAct
|
||||
fallback_strategy = "react"
|
||||
logger.warning("ReWOO planning failed, falling back to ReActEngine")
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome="fallback")
|
||||
return await self._react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0, # timeout already handled by outer wrapper
|
||||
)
|
||||
try:
|
||||
react_result = await self._react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0, # timeout already handled by outer wrapper
|
||||
)
|
||||
react_result.fallback_strategy = fallback_strategy
|
||||
return react_result
|
||||
except Exception as react_err:
|
||||
# ReAct 也失败,回退到 Direct(简单 LLM 调用)
|
||||
fallback_strategy = "direct"
|
||||
logger.warning(f"ReAct fallback also failed: {react_err}, falling back to direct LLM call")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
|
||||
direct_response = await self._llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += direct_response.usage.total_tokens
|
||||
|
||||
direct_step = ReWOOStep(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
content=direct_response.content,
|
||||
tokens=direct_response.usage.total_tokens,
|
||||
plan_step_id=None,
|
||||
)
|
||||
trajectory.append(direct_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
output_data={"content": direct_response.content},
|
||||
tokens_used=direct_response.usage.total_tokens,
|
||||
)
|
||||
|
||||
trace_outcome = "success"
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
return ReActResult(
|
||||
output=direct_response.content or "",
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy=fallback_strategy,
|
||||
)
|
||||
except Exception as direct_err:
|
||||
logger.error(f"Direct LLM fallback also failed: {direct_err}")
|
||||
raise
|
||||
|
||||
# 如果计划为空(无需工具),直接让 LLM 回答
|
||||
if not plan.steps:
|
||||
|
|
@ -360,6 +447,7 @@ class ReWOOEngine:
|
|||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy=fallback_strategy,
|
||||
)
|
||||
|
||||
# ── Phase 2: Execution ──
|
||||
|
|
@ -462,6 +550,7 @@ class ReWOOEngine:
|
|||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy=fallback_strategy,
|
||||
)
|
||||
finally:
|
||||
# Telemetry: end span and record duration
|
||||
|
|
@ -558,26 +647,84 @@ class ReWOOEngine:
|
|||
)
|
||||
total_tokens += planning_tokens
|
||||
|
||||
if plan is None:
|
||||
# Try simplified planning
|
||||
logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=3,
|
||||
)
|
||||
total_tokens += simplified_tokens
|
||||
except Exception as e2:
|
||||
logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e2}")
|
||||
|
||||
if plan is None:
|
||||
# Planning failed, fall back to ReAct streaming
|
||||
logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine")
|
||||
async for event in self._react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
try:
|
||||
async for event in self._react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except Exception as react_err:
|
||||
# ReAct also failed, fall back to direct LLM call
|
||||
logger.warning(f"ReAct fallback also failed in stream mode: {react_err}, falling back to direct LLM call")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
|
||||
direct_response = await self._llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += direct_response.usage.total_tokens
|
||||
output = direct_response.content or ""
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=1,
|
||||
data={
|
||||
"output": output,
|
||||
"total_steps": 1,
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
return
|
||||
except Exception as direct_err:
|
||||
logger.error(f"Direct LLM fallback also failed in stream mode: {direct_err}")
|
||||
raise
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="plan_generated",
|
||||
|
|
@ -758,9 +905,13 @@ class ReWOOEngine:
|
|||
system_prompt: str | None,
|
||||
compressor: "CompressionStrategy | None",
|
||||
cancellation_token: CancellationToken | None,
|
||||
max_steps: int | None = None,
|
||||
) -> tuple[ReWOOPlan | None, int]:
|
||||
"""Planning Phase: 调用 LLM 生成完整执行计划
|
||||
|
||||
Args:
|
||||
max_steps: 限制计划最大步数,None 则使用 self._max_plan_steps
|
||||
|
||||
Returns:
|
||||
(plan, tokens_used) - plan 为 None 表示规划失败
|
||||
"""
|
||||
|
|
@ -817,8 +968,9 @@ class ReWOOEngine:
|
|||
return None, tokens_used
|
||||
|
||||
# 限制计划步数
|
||||
if len(plan.steps) > self._max_plan_steps:
|
||||
plan.steps = plan.steps[:self._max_plan_steps]
|
||||
effective_max_steps = max_steps if max_steps is not None else self._max_plan_steps
|
||||
if len(plan.steps) > effective_max_steps:
|
||||
plan.steps = plan.steps[:effective_max_steps]
|
||||
|
||||
return plan, tokens_used
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,17 @@ from agentkit.memory.profile import MemoryStore
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SoulEvolutionConfig:
|
||||
"""Soul 进化多维触发配置"""
|
||||
|
||||
min_reflections: int = 3
|
||||
reflection_window_seconds: int = 3600
|
||||
time_decay_factor: float = 0.5
|
||||
task_type_weights: dict[str, float] = field(default_factory=dict)
|
||||
quality_gradient_threshold: float = -0.15
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvolutionLogEntry:
|
||||
"""进化日志条目"""
|
||||
|
|
@ -59,6 +70,7 @@ class EvolutionMixin:
|
|||
llm_gateway: Any | None = None,
|
||||
auxiliary_model: str | None = None,
|
||||
strategy_tuning_enabled: bool = False,
|
||||
evolution_config: SoulEvolutionConfig | None = None,
|
||||
):
|
||||
if reflector is not EvolutionMixin._UNSET:
|
||||
# 显式传入了 reflector 参数(包括 None)
|
||||
|
|
@ -78,6 +90,7 @@ class EvolutionMixin:
|
|||
self._evolution_log: list[EvolutionLogEntry] = []
|
||||
self._current_module: Module | None = None
|
||||
self._strategy_tuning_enabled = strategy_tuning_enabled
|
||||
self._evolution_config = evolution_config
|
||||
self.pending_soul_updates: dict[str, list] = {}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -373,19 +386,41 @@ class EvolutionMixin:
|
|||
logger.error(f"Failed to rollback evolution change: {e}")
|
||||
return False
|
||||
|
||||
def record_reflection(
|
||||
self,
|
||||
pattern: str,
|
||||
reflection: Reflection,
|
||||
task_type: str = "",
|
||||
score: float | None = None,
|
||||
) -> None:
|
||||
"""记录反思到待处理列表,附带时间戳、分数和任务类型。"""
|
||||
if pattern not in self.pending_soul_updates:
|
||||
self.pending_soul_updates[pattern] = []
|
||||
self.pending_soul_updates[pattern].append(
|
||||
{
|
||||
"reflection": reflection,
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": score if score is not None else reflection.quality_score,
|
||||
"task_type": task_type,
|
||||
}
|
||||
)
|
||||
|
||||
async def evolve_soul(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: TaskResult,
|
||||
memory_store: MemoryStore | None = None,
|
||||
reflection: Reflection | None = None,
|
||||
task_type: str = "",
|
||||
score: float | None = None,
|
||||
) -> bool:
|
||||
"""Check if soul should be updated based on accumulated reflections.
|
||||
|
||||
Conditions for soul update:
|
||||
- Same category reflection appears >= 3 times
|
||||
- Reflection quality_score < 0.5 (indicating consistent issues)
|
||||
- Reflection has actionable suggestions
|
||||
Multi-dimensional triggers:
|
||||
- Time decay: older reflections contribute less
|
||||
- Quality gradient: declining scores trigger early
|
||||
- Task type weight: different task types have different trigger thresholds
|
||||
- Trigger threshold: effective_count * weight >= min_reflections
|
||||
"""
|
||||
if memory_store is None:
|
||||
return False
|
||||
|
|
@ -402,22 +437,53 @@ class EvolutionMixin:
|
|||
if not reflection.suggestions:
|
||||
return False
|
||||
|
||||
# 按 pattern 分类累积反思
|
||||
for pattern in reflection.patterns:
|
||||
if pattern not in self.pending_soul_updates:
|
||||
self.pending_soul_updates[pattern] = []
|
||||
self.pending_soul_updates[pattern].append(reflection)
|
||||
config = self._evolution_config or SoulEvolutionConfig()
|
||||
|
||||
# 检查是否有同一类别累积 >= 3 次反思
|
||||
# 按 pattern 分类累积反思(patterns为空时使用默认category)
|
||||
categories = reflection.patterns if reflection.patterns else ["default"]
|
||||
for pattern in categories:
|
||||
self.record_reflection(
|
||||
pattern, reflection, task_type=task_type, score=score
|
||||
)
|
||||
|
||||
# 检查是否有类别满足触发条件
|
||||
for category, reflections in list(self.pending_soul_updates.items()):
|
||||
if len(reflections) >= 3:
|
||||
# --- Quality gradient: 3+ declining scores trigger early ---
|
||||
scores = [r["score"] for r in reflections if r["score"] is not None]
|
||||
quality_gradient_triggered = False
|
||||
if len(scores) >= 3:
|
||||
last_3 = scores[-3:]
|
||||
declines = [
|
||||
last_3[i] - last_3[i - 1] for i in range(1, len(last_3))
|
||||
]
|
||||
if all(d <= config.quality_gradient_threshold for d in declines):
|
||||
quality_gradient_triggered = True
|
||||
|
||||
# --- Time decay: compute effective_count ---
|
||||
now = datetime.now(timezone.utc)
|
||||
effective_count = 0.0
|
||||
for r in reflections:
|
||||
age_seconds = (now - r["timestamp"]).total_seconds()
|
||||
age_hours = age_seconds / 3600.0
|
||||
effective_count += config.time_decay_factor ** age_hours
|
||||
# Round to avoid floating-point precision issues
|
||||
# (e.g. 3 recent reflections should yield exactly 3.0)
|
||||
effective_count = round(effective_count, 6)
|
||||
|
||||
# --- Task type weight ---
|
||||
weight = 1.0
|
||||
if task_type and task_type in config.task_type_weights:
|
||||
weight = config.task_type_weights[task_type]
|
||||
|
||||
# --- Trigger threshold: effective_count * weight >= min_reflections ---
|
||||
weighted_count = effective_count * weight
|
||||
if weighted_count >= config.min_reflections or quality_gradient_triggered:
|
||||
# 触发 soul 更新
|
||||
from agentkit.tools.memory_tool import MemoryTool
|
||||
|
||||
tool = MemoryTool(memory_store)
|
||||
# 使用第一个建议作为更新内容
|
||||
section = category
|
||||
content = "; ".join(reflections[-1].suggestions[:2])
|
||||
content = "; ".join(reflections[-1]["reflection"].suggestions[:2])
|
||||
reason = f"连续{len(reflections)}次低质量反思 (category: {category})"
|
||||
|
||||
update_result = await tool.execute(
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import logging
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.telemetry.tracer import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -92,24 +94,37 @@ class AlignmentGuard:
|
|||
if not effective_constraints:
|
||||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
|
||||
# 1. 基于规则的检查:关键词/子串匹配
|
||||
violations = self._rule_check(output, effective_constraints)
|
||||
if violations:
|
||||
return AlignmentCheckResult(
|
||||
passed=False,
|
||||
violations=violations,
|
||||
checked_by="rule",
|
||||
)
|
||||
tracer = get_tracer()
|
||||
with tracer.start_span("guard.check") as span:
|
||||
span.set_attribute("guard.constraints_count", len(effective_constraints))
|
||||
|
||||
# 2. LLM 语义检查(仅当 audit_enabled=True 且有 llm_gateway,按采样率执行)
|
||||
if self._config.audit_enabled and self._llm_gateway is not None:
|
||||
import random
|
||||
if random.random() < self._config.audit_sample_rate:
|
||||
return await self._llm_check(output, effective_constraints)
|
||||
# 采样未命中,信任规则检查结果
|
||||
logger.debug("LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate)
|
||||
# 1. 基于规则的检查:关键词/子串匹配
|
||||
violations = self._rule_check(output, effective_constraints)
|
||||
if violations:
|
||||
result = AlignmentCheckResult(
|
||||
passed=False,
|
||||
violations=violations,
|
||||
checked_by="rule",
|
||||
)
|
||||
span.set_attribute("guard.passed", result.passed)
|
||||
span.set_attribute("guard.checked_by", result.checked_by)
|
||||
return result
|
||||
|
||||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
# 2. LLM 语义检查(仅当 audit_enabled=True 且有 llm_gateway,按采样率执行)
|
||||
if self._config.audit_enabled and self._llm_gateway is not None:
|
||||
import random
|
||||
if random.random() < self._config.audit_sample_rate:
|
||||
result = await self._llm_check(output, effective_constraints)
|
||||
span.set_attribute("guard.passed", result.passed)
|
||||
span.set_attribute("guard.checked_by", result.checked_by)
|
||||
return result
|
||||
# 采样未命中,信任规则检查结果
|
||||
logger.debug("LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate)
|
||||
|
||||
result = AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
span.set_attribute("guard.passed", result.passed)
|
||||
span.set_attribute("guard.checked_by", result.checked_by)
|
||||
return result
|
||||
|
||||
def _rule_check(
|
||||
self, output: dict[str, Any], constraints: list[str]
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from agentkit.telemetry.metrics import (
|
|||
pipeline_step_histogram,
|
||||
)
|
||||
from agentkit.telemetry.setup import setup_telemetry
|
||||
from agentkit.telemetry.tracer import TelemetryConfig, NoOpTracer
|
||||
|
||||
__all__ = [
|
||||
"get_tracer",
|
||||
|
|
@ -35,4 +36,6 @@ __all__ = [
|
|||
"pipeline_step_histogram",
|
||||
"setup_telemetry",
|
||||
"_OTEL_AVAILABLE",
|
||||
"TelemetryConfig",
|
||||
"NoOpTracer",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
"""OpenTelemetry tracer integration with no-op fallback."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ContextManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TelemetryConfig:
|
||||
"""Telemetry configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
otlp_endpoint: str = ""
|
||||
service_name: str = "agentkit"
|
||||
sample_rate: float = 1.0
|
||||
|
||||
|
||||
class NoOpSpan:
|
||||
"""No-op span when telemetry is disabled."""
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
pass
|
||||
|
||||
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
||||
pass
|
||||
|
||||
def record_exception(self, exception: Exception) -> None:
|
||||
pass
|
||||
|
||||
def is_recording(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class NoOpTracer:
|
||||
"""No-op tracer when telemetry is disabled."""
|
||||
|
||||
def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]:
|
||||
return NoOpSpan()
|
||||
|
||||
def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]:
|
||||
return NoOpSpan()
|
||||
|
||||
|
||||
class OTelSpan:
|
||||
"""Wrapper around OpenTelemetry Span."""
|
||||
|
||||
def __init__(self, span: Any):
|
||||
self._span = span
|
||||
|
||||
def __enter__(self):
|
||||
self._span.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self._span.__exit__(*args)
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
self._span.set_attribute(key, value)
|
||||
|
||||
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
||||
self._span.add_event(name, attributes or {})
|
||||
|
||||
def record_exception(self, exception: Exception) -> None:
|
||||
self._span.record_exception(exception)
|
||||
|
||||
def is_recording(self) -> bool:
|
||||
return self._span.is_recording()
|
||||
|
||||
|
||||
class OTelTracer:
|
||||
"""Wrapper around OpenTelemetry Tracer."""
|
||||
|
||||
def __init__(self, tracer: Any):
|
||||
self._tracer = tracer
|
||||
|
||||
def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan:
|
||||
span = self._tracer.start_span(name, attributes=attributes)
|
||||
return OTelSpan(span)
|
||||
|
||||
def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan:
|
||||
span = self._tracer.start_as_current_span(name, attributes=attributes)
|
||||
return OTelSpan(span)
|
||||
|
||||
|
||||
# Global tracer instance
|
||||
_tracer: NoOpTracer | OTelTracer = NoOpTracer()
|
||||
|
||||
|
||||
def get_tracer() -> NoOpTracer | OTelTracer:
|
||||
"""Get the global tracer instance."""
|
||||
return _tracer
|
||||
|
||||
|
||||
def init_telemetry(config: TelemetryConfig) -> None:
|
||||
"""Initialize telemetry with the given configuration."""
|
||||
global _tracer
|
||||
|
||||
if not config.enabled:
|
||||
_tracer = NoOpTracer()
|
||||
logger.info("Telemetry disabled, using no-op tracer")
|
||||
return
|
||||
|
||||
try:
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
|
||||
resource = Resource.create({"service.name": config.service_name})
|
||||
provider = TracerProvider(resource=resource)
|
||||
|
||||
if config.otlp_endpoint:
|
||||
exporter = OTLPSpanExporter(endpoint=config.otlp_endpoint)
|
||||
provider.add_span_processor(BatchSpanProcessor(exporter))
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
_tracer = OTelTracer(trace.get_tracer(config.service_name))
|
||||
logger.info(f"Telemetry initialized with endpoint: {config.otlp_endpoint}")
|
||||
except ImportError:
|
||||
logger.warning("opentelemetry packages not installed, using no-op tracer")
|
||||
_tracer = NoOpTracer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize telemetry: {e}, using no-op tracer")
|
||||
_tracer = NoOpTracer()
|
||||
|
|
@ -0,0 +1,292 @@
|
|||
"""集成测试 - Reflexion 多轮循环
|
||||
|
||||
测试 ReflexionEngine 的 Evaluate→Reflect→Retry 循环。
|
||||
仅 mock LLMGateway(外部 API),使用真实 ReflexionEngine 实例。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.react import ReActEngine, ReActStep
|
||||
from agentkit.core.reflexion import ReflexionEngine, ReflexionReflection, ReflexionResult
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_response(
|
||||
content: str = "",
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 20,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
return gateway
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: First attempt scores high → no retry, returns result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReflexionFirstAttemptPasses:
|
||||
"""首次尝试分数高于阈值,无需重试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_score_no_retry(self):
|
||||
gateway = make_mock_gateway([
|
||||
# ReAct call
|
||||
make_response(content="The answer is 42"),
|
||||
# Evaluation call - high score
|
||||
make_response(content='```json\n{"score": 0.9, "reasoning": "Excellent"}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "What is the answer?"}],
|
||||
)
|
||||
|
||||
assert isinstance(result, ReflexionResult)
|
||||
assert result.output == "The answer is 42"
|
||||
assert result.evaluation_score == 0.9
|
||||
assert result.reflection_count == 0
|
||||
assert len(result.reflections) == 0
|
||||
assert result.status == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_exactly_at_threshold(self):
|
||||
"""分数恰好等于阈值,无需重试"""
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Answer"),
|
||||
make_response(content='```json\n{"score": 0.7, "reasoning": "OK"}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
assert result.evaluation_score == 0.7
|
||||
assert result.reflection_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: First attempt scores low, second scores high → returns best result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReflexionRetryImprovesScore:
|
||||
"""首次低分,反思后重试高分"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflection_and_retry_on_low_score(self):
|
||||
gateway = make_mock_gateway([
|
||||
# 1st ReAct call
|
||||
make_response(content="Initial poor answer"),
|
||||
# 1st Evaluation - low score
|
||||
make_response(content='```json\n{"score": 0.3, "reasoning": "Incomplete"}\n```'),
|
||||
# 1st Reflection call
|
||||
make_response(content="You need to be more specific and provide detailed analysis."),
|
||||
# 2nd ReAct call
|
||||
make_response(content="Improved detailed answer"),
|
||||
# 2nd Evaluation - high score
|
||||
make_response(content='```json\n{"score": 0.85, "reasoning": "Good improvement"}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Analyze this"}],
|
||||
)
|
||||
|
||||
assert result.output == "Improved detailed answer"
|
||||
assert result.evaluation_score == 0.85
|
||||
assert result.reflection_count == 1
|
||||
assert len(result.reflections) == 1
|
||||
assert result.reflections[0].score_before == 0.3
|
||||
assert result.reflections[0].retry_number == 1
|
||||
assert "specific" in result.reflections[0].reflection_text.lower() or "detailed" in result.reflections[0].reflection_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_retries_improve_score(self):
|
||||
"""多次重试后分数逐步提升"""
|
||||
gateway = make_mock_gateway([
|
||||
# Attempt 1
|
||||
make_response(content="Bad answer"),
|
||||
make_response(content='```json\n{"score": 0.2}\n```'),
|
||||
make_response(content="Need more depth"),
|
||||
# Attempt 2
|
||||
make_response(content="Better answer"),
|
||||
make_response(content='```json\n{"score": 0.5}\n```'),
|
||||
make_response(content="Still needs improvement"),
|
||||
# Attempt 3
|
||||
make_response(content="Great answer"),
|
||||
make_response(content='```json\n{"score": 0.9}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Complex task"}],
|
||||
)
|
||||
|
||||
assert result.output == "Great answer"
|
||||
assert result.evaluation_score == 0.9
|
||||
assert result.reflection_count == 2
|
||||
assert len(result.reflections) == 2
|
||||
assert result.reflections[0].retry_number == 1
|
||||
assert result.reflections[1].retry_number == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Max reflections reached → returns best result found
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReflexionMaxReflectionsReached:
|
||||
"""达到最大反思次数后返回最佳结果"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_best_result_when_max_reflections_reached(self):
|
||||
gateway = make_mock_gateway([
|
||||
# Attempt 1
|
||||
make_response(content="Poor answer"),
|
||||
make_response(content='```json\n{"score": 0.3}\n```'),
|
||||
make_response(content="Try harder"),
|
||||
# Attempt 2
|
||||
make_response(content="Slightly better answer"),
|
||||
make_response(content='```json\n{"score": 0.5}\n```'),
|
||||
make_response(content="Still not good enough"),
|
||||
# Attempt 3 (max)
|
||||
make_response(content="Another answer"),
|
||||
make_response(content='```json\n{"score": 0.6}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hard task"}],
|
||||
)
|
||||
|
||||
# Should return the best result (score 0.6 from last attempt)
|
||||
assert result.evaluation_score == 0.6
|
||||
assert result.reflection_count == 2
|
||||
assert result.output == "Another answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_earlier_best_when_later_worse(self):
|
||||
"""后续尝试分数更低时,返回之前最佳结果"""
|
||||
gateway = make_mock_gateway([
|
||||
# Attempt 1: score 0.5
|
||||
make_response(content="Decent answer"),
|
||||
make_response(content='```json\n{"score": 0.5}\n```'),
|
||||
make_response(content="Try to improve"),
|
||||
# Attempt 2: score 0.4 (worse)
|
||||
make_response(content="Worse answer"),
|
||||
make_response(content='```json\n{"score": 0.4}\n```'),
|
||||
make_response(content="Still trying"),
|
||||
# Attempt 3: score 0.45 (still worse than attempt 1)
|
||||
make_response(content="Another answer"),
|
||||
make_response(content='```json\n{"score": 0.45}\n```'),
|
||||
# Reflection for attempt 3 (consumed but loop ends)
|
||||
make_response(content="Final reflection"),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
# Best score was 0.5 from attempt 1
|
||||
assert result.evaluation_score == 0.5
|
||||
assert result.output == "Decent answer"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Reflection text improves subsequent attempts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReflexionReflectionImprovesAttempts:
|
||||
"""反思文本改善后续尝试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflection_injected_into_system_prompt(self):
|
||||
"""反思文本被注入到下一次 ReAct 的 system prompt 中"""
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Poor answer"),
|
||||
make_response(content='```json\n{"score": 0.3}\n```'),
|
||||
make_response(content="You need to provide more specific details."),
|
||||
make_response(content="Better answer with details"),
|
||||
make_response(content='```json\n{"score": 0.9}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
system_prompt="You are a helpful assistant",
|
||||
)
|
||||
|
||||
assert result.reflection_count == 1
|
||||
assert result.evaluation_score == 0.9
|
||||
assert result.output == "Better answer with details"
|
||||
|
||||
# Verify the reflection was recorded with correct metadata
|
||||
assert result.reflections[0].reflection_text == "You need to provide more specific details."
|
||||
assert result.reflections[0].score_before == 0.3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflexion_composes_react_engine(self):
|
||||
"""ReflexionEngine 组合(而非继承)ReActEngine"""
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Answer"),
|
||||
make_response(content='```json\n{"score": 0.9}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway)
|
||||
|
||||
assert hasattr(engine, "_react_engine")
|
||||
assert isinstance(engine._react_engine, ReActEngine)
|
||||
assert not isinstance(engine, ReActEngine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflexion_result_has_all_fields(self):
|
||||
"""ReflexionResult 包含所有必要字段"""
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Answer"),
|
||||
make_response(content='```json\n{"score": 0.85}\n```'),
|
||||
])
|
||||
engine = ReflexionEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
# ReActResult fields
|
||||
assert hasattr(result, "output")
|
||||
assert hasattr(result, "trajectory")
|
||||
assert hasattr(result, "total_steps")
|
||||
assert hasattr(result, "total_tokens")
|
||||
assert hasattr(result, "status")
|
||||
|
||||
# ReflexionResult additional fields
|
||||
assert hasattr(result, "evaluation_score")
|
||||
assert hasattr(result, "reflection_count")
|
||||
assert hasattr(result, "reflections")
|
||||
|
||||
# All trajectory steps are ReActStep
|
||||
assert all(isinstance(step, ReActStep) for step in result.trajectory)
|
||||
|
|
@ -0,0 +1,293 @@
|
|||
"""集成测试 - ReWOO 渐进式回退链
|
||||
|
||||
测试 ReWOOEngine 的 planning → simplified_rewoo → react → direct 回退策略。
|
||||
仅 mock LLMGateway(外部 API),使用真实 ReWOOEngine 实例。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.rewoo import ReWOOEngine, ReWOOStep
|
||||
from agentkit.core.react import ReActResult, ReActStep
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeTool(Tool):
|
||||
"""用于测试的 Fake Tool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "fake_tool",
|
||||
description: str = "A fake tool for testing",
|
||||
result: dict | None = None,
|
||||
should_fail: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
self._result = result or {"status": "ok"}
|
||||
self._should_fail = should_fail
|
||||
self.call_count = 0
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
self.call_count += 1
|
||||
if self._should_fail:
|
||||
raise RuntimeError(f"Tool '{self.name}' execution failed")
|
||||
return self._result
|
||||
|
||||
|
||||
def make_response(
|
||||
content: str = "",
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 20,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
),
|
||||
tool_calls=tool_calls or [],
|
||||
)
|
||||
|
||||
|
||||
def make_plan_response(
|
||||
steps: list[dict],
|
||||
reasoning: str = "Plan reasoning",
|
||||
) -> LLMResponse:
|
||||
plan_json = json.dumps({
|
||||
"reasoning": reasoning,
|
||||
"steps": steps,
|
||||
})
|
||||
return make_response(content=plan_json)
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
return gateway
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Planning succeeds → no fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReWOOPlanningSucceeds:
|
||||
"""规划成功,不使用回退"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_planning_succeeds_no_fallback(self):
|
||||
tool = FakeTool(name="calculator", result={"value": 42})
|
||||
|
||||
plan_response = make_plan_response([
|
||||
{"step_id": 1, "tool_name": "calculator", "arguments": {"expr": "6*7"}, "reasoning": "Calculate"},
|
||||
])
|
||||
synthesis_response = make_response(content="The result is 42")
|
||||
|
||||
gateway = make_mock_gateway([plan_response, synthesis_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Calculate 6*7"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
assert result.output == "The result is 42"
|
||||
assert result.fallback_strategy is None
|
||||
assert result.status == "success"
|
||||
# 1 tool_call + 1 final_answer = 2 steps
|
||||
assert result.total_steps == 2
|
||||
assert tool.call_count == 1
|
||||
|
||||
# Verify ReWOOStep has plan_step_id
|
||||
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
|
||||
assert len(tool_steps) == 1
|
||||
assert isinstance(tool_steps[0], ReWOOStep)
|
||||
assert tool_steps[0].plan_step_id == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Planning fails → simplified planning succeeds → fallback_strategy="simplified_rewoo"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReWOOSimplifiedFallback:
|
||||
"""规划失败,简化规划成功"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_planning_fails_simplified_succeeds(self):
|
||||
tool = FakeTool(name="search", result={"results": ["found"]})
|
||||
|
||||
# First plan fails (invalid JSON), second (simplified) succeeds
|
||||
invalid_plan_response = make_response(content="I cannot create a plan for this task.")
|
||||
simplified_plan_response = make_plan_response([
|
||||
{"step_id": 1, "tool_name": "search", "arguments": {"query": "test"}, "reasoning": "Simplified search"},
|
||||
])
|
||||
synthesis_response = make_response(content="Simplified result")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan_response, simplified_plan_response, synthesis_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Complex task"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert result.output == "Simplified result"
|
||||
assert result.fallback_strategy == "simplified_rewoo"
|
||||
assert tool.call_count == 1
|
||||
|
||||
# Verify trajectory still has proper structure
|
||||
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
|
||||
assert len(tool_steps) == 1
|
||||
assert tool_steps[0].tool_name == "search"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: All planning fails → ReAct fallback → fallback_strategy="react"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReWOOReActFallback:
|
||||
"""所有规划失败,回退到 ReAct"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_planning_and_simplified_fail_react_succeeds(self):
|
||||
# Both plan attempts fail (invalid JSON), ReAct succeeds
|
||||
invalid_plan1 = make_response(content="Not a plan")
|
||||
invalid_plan2 = make_response(content="Still not a plan")
|
||||
react_response = make_response(content="ReAct fallback answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan1, invalid_plan2, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Complex task"}],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct fallback answer"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_fallback_with_tool_calls(self):
|
||||
"""ReAct 回退时带工具调用"""
|
||||
tool = FakeTool(name="search", result={"results": ["found"]})
|
||||
|
||||
invalid_plan1 = make_response(content="Cannot plan")
|
||||
invalid_plan2 = make_response(content="Still cannot plan")
|
||||
react_tool_response = make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})],
|
||||
)
|
||||
react_final_response = make_response(content="ReAct answer with tool")
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
invalid_plan1,
|
||||
invalid_plan2,
|
||||
react_tool_response,
|
||||
react_final_response,
|
||||
])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Search task"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct answer with tool"
|
||||
assert result.fallback_strategy == "react"
|
||||
assert tool.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_json_triggers_react_fallback(self):
|
||||
"""格式错误的 JSON 触发 ReAct 回退"""
|
||||
malformed_response = make_response(content='{"reasoning": "plan", "steps": [invalid json')
|
||||
simplified_fail_response = make_response(content="Also not a plan")
|
||||
react_response = make_response(content="ReAct answer")
|
||||
|
||||
gateway = make_mock_gateway([malformed_response, simplified_fail_response, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct answer"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_steps_key_triggers_react_fallback(self):
|
||||
"""缺少 steps 键触发 ReAct 回退"""
|
||||
no_steps_response = make_response(content='{"reasoning": "no steps here"}')
|
||||
simplified_fail_response = make_response(content="Also no steps")
|
||||
react_response = make_response(content="ReAct fallback")
|
||||
|
||||
gateway = make_mock_gateway([no_steps_response, simplified_fail_response, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct fallback"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: Multi-step plan with fallback chain integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReWOOMultiStepWithFallbackChain:
|
||||
"""多步计划与回退链的集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_step_plan_no_fallback(self):
|
||||
"""三步计划成功,无回退"""
|
||||
search_tool = FakeTool(name="search", result={"results": ["Python is great"]})
|
||||
calc_tool = FakeTool(name="calculator", result={"value": 100})
|
||||
weather_tool = FakeTool(name="weather", result={"temp": 25, "city": "Shanghai"})
|
||||
|
||||
plan_response = make_plan_response([
|
||||
{"step_id": 1, "tool_name": "search", "arguments": {"query": "Python"}, "reasoning": "Search first"},
|
||||
{"step_id": 2, "tool_name": "calculator", "arguments": {"expr": "10*10"}, "reasoning": "Calculate"},
|
||||
{"step_id": 3, "tool_name": "weather", "arguments": {"city": "Shanghai"}, "reasoning": "Check weather"},
|
||||
])
|
||||
synthesis_response = make_response(content="Based on search, calculation (100), and weather (25°C)")
|
||||
|
||||
gateway = make_mock_gateway([plan_response, synthesis_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Search, calculate and check weather"}],
|
||||
tools=[search_tool, calc_tool, weather_tool],
|
||||
)
|
||||
|
||||
assert result.fallback_strategy is None
|
||||
assert result.total_steps == 4 # 3 tool_calls + 1 final_answer
|
||||
assert search_tool.call_count == 1
|
||||
assert calc_tool.call_count == 1
|
||||
assert weather_tool.call_count == 1
|
||||
assert "100" in result.output
|
||||
assert "25" in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_strategies_constant(self):
|
||||
"""验证 FALLBACK_STRATEGIES 常量"""
|
||||
assert ReWOOEngine.FALLBACK_STRATEGIES == ["simplified_rewoo", "react", "direct"]
|
||||
|
|
@ -0,0 +1,315 @@
|
|||
"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.core.react import ReActEngine, ReActResult, ReActStep
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||||
from agentkit.quality.alignment import AlignmentConfig, AlignmentGuard
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_llm_response(content: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
return gateway
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Router routes to ReAct engine, output passes alignment check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRouterToEnginePassesAlignment:
|
||||
"""路由到 ReAct 引擎,输出通过对齐检查"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_output_passes_alignment(self):
|
||||
# --- Setup: LLM returns low complexity → default agent (ReAct) ---
|
||||
gateway = _make_mock_gateway([
|
||||
_make_llm_response('{"complexity": 0.2}'), # quick_classify
|
||||
_make_llm_response("你好!有什么可以帮你的?"), # ReAct final answer
|
||||
])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["password", "secret_key"],
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
# Step 1: Route
|
||||
route_result = await router.route(
|
||||
content="随便聊聊",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
assert route_result.complexity < 0.3
|
||||
assert route_result.agent_name == "default"
|
||||
|
||||
# Step 2: Inject constraints
|
||||
input_data = {"content": route_result.clean_content}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
assert "alignment_constraints" in injected
|
||||
|
||||
# Step 3: Simulate engine execution (use real ReActEngine with mock gateway)
|
||||
react_engine = ReActEngine(llm_gateway=gateway)
|
||||
engine_result = await react_engine.execute(
|
||||
messages=[{"role": "user", "content": injected["content"]}],
|
||||
)
|
||||
|
||||
# Step 4: Alignment check
|
||||
output = {"result": engine_result.output}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is True
|
||||
assert check_result.violations == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Router routes to ReAct engine, output fails alignment check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRouterToEngineFailsAlignment:
|
||||
"""路由到 ReAct 引擎,输出未通过对齐检查"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_output_fails_alignment(self):
|
||||
gateway = _make_mock_gateway([
|
||||
_make_llm_response('{"complexity": 0.2}'),
|
||||
_make_llm_response("Your password is 123456"),
|
||||
])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["password", "secret_key"],
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
route_result = await router.route(
|
||||
content="随便聊聊",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
input_data = {"content": route_result.clean_content}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
|
||||
react_engine = ReActEngine(llm_gateway=gateway)
|
||||
engine_result = await react_engine.execute(
|
||||
messages=[{"role": "user", "content": injected["content"]}],
|
||||
)
|
||||
|
||||
output = {"result": engine_result.output}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is False
|
||||
assert len(check_result.violations) > 0
|
||||
assert "password" in check_result.violations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Router routes based on complexity (low→default, high→org_context)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRouterComplexityBasedRouting:
|
||||
"""基于复杂度的路由:低复杂度→默认,高复杂度→org_context 能力匹配"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_complexity_routes_to_default(self):
|
||||
gateway = _make_mock_gateway([
|
||||
_make_llm_response('{"complexity": 0.15}'),
|
||||
])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="analyst",
|
||||
agent_type="react",
|
||||
capabilities=["analysis"],
|
||||
skills=["analysis"],
|
||||
))
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
result = await router.route(
|
||||
content="简单问题",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
assert result.complexity < 0.3
|
||||
assert result.agent_name == "default"
|
||||
assert result.match_method == "low_complexity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_complexity_routes_via_org_context(self):
|
||||
gateway = _make_mock_gateway([
|
||||
_make_llm_response('{"complexity": 0.85}'),
|
||||
])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="analyst",
|
||||
agent_type="react",
|
||||
capabilities=["分析", "市场", "调研"],
|
||||
skills=["market_analysis"],
|
||||
current_load=0,
|
||||
))
|
||||
|
||||
# find_best_agent returns real AgentProfile
|
||||
org_context.find_best_agent = MagicMock(
|
||||
return_value=org_context.get_agent_profile("analyst")
|
||||
)
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
result = await router.route(
|
||||
content="请对市场趋势进行深度分析并给出投资建议",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
assert result.complexity >= 0.7
|
||||
assert result.match_method == "capability"
|
||||
assert result.agent_name == "analyst"
|
||||
assert result.matched is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: AlignmentGuard injects constraints into input before engine execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAlignmentGuardConstraintInjection:
|
||||
"""AlignmentGuard 在引擎执行前将约束注入到输入中"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraints_injected_before_engine_execution(self):
|
||||
gateway = _make_mock_gateway([
|
||||
_make_llm_response('{"complexity": 0.5}'),
|
||||
_make_llm_response("Safe answer"),
|
||||
])
|
||||
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["不得泄露用户隐私", "禁止生成有害内容"],
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
org_context = OrganizationContext()
|
||||
router = CostAwareRouter(llm_gateway=gateway, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
# Step 1: Route
|
||||
route_result = await router.route(
|
||||
content="请帮我写一篇文章",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
# Step 2: Inject constraints
|
||||
input_data = {"content": route_result.clean_content}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
|
||||
# Verify constraints are present
|
||||
assert "alignment_constraints" in injected
|
||||
assert "不得泄露用户隐私" in injected["alignment_constraints"]
|
||||
assert "禁止生成有害内容" in injected["alignment_constraints"]
|
||||
# Original data preserved
|
||||
assert injected["content"] == route_result.clean_content
|
||||
# Original dict not mutated
|
||||
assert "alignment_constraints" not in input_data
|
||||
|
||||
# Step 3: Engine executes with injected input
|
||||
react_engine = ReActEngine(llm_gateway=gateway)
|
||||
engine_result = await react_engine.execute(
|
||||
messages=[{"role": "user", "content": injected["content"]}],
|
||||
)
|
||||
|
||||
# Step 4: Output passes alignment
|
||||
output = {"result": engine_result.output}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constraint_injection_with_cascade_monitoring(self):
|
||||
"""约束注入 + 级联故障监控的完整链路"""
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["password"],
|
||||
cascade_max_interactions=5,
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
# Inject constraints
|
||||
input_data = {"content": "请帮我重置密码"}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
assert "alignment_constraints" in injected
|
||||
|
||||
# Simulate safe output
|
||||
output = {"result": "密码重置链接已发送到您的邮箱。"}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is True
|
||||
|
||||
# Record interactions — no cascade alert
|
||||
alert = guard.record_interaction("session-chain-1")
|
||||
assert alert is None
|
||||
assert guard.get_interaction_count("session-chain-1") == 1
|
||||
|
|
@ -0,0 +1,415 @@
|
|||
"""集成测试 - Soul 进化触发条件
|
||||
|
||||
测试 EvolutionMixin.evolve_soul 的多维触发逻辑:
|
||||
- 时间窗口内反思计数
|
||||
- 质量梯度(下降分数)触发早期进化
|
||||
- 任务类型权重调整触发阈值
|
||||
- 时间衰减降低旧反思的有效计数
|
||||
|
||||
仅 mock MemoryTool(文件 I/O),使用真实 EvolutionMixin + SoulEvolutionConfig 实例。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin, SoulEvolutionConfig
|
||||
from agentkit.evolution.reflector import Reflection
|
||||
from agentkit.memory.profile import MemoryStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_task(task_id: str = "task-1") -> TaskMessage:
|
||||
return TaskMessage(
|
||||
task_id=task_id,
|
||||
agent_name="test_agent",
|
||||
task_type="analysis",
|
||||
priority=1,
|
||||
input_data={"content": "test task"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def make_result(task_id: str = "task-1") -> TaskResult:
|
||||
return TaskResult(
|
||||
task_id=task_id,
|
||||
agent_name="test_agent",
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data={"result": "done"},
|
||||
error_message=None,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def make_reflection(
|
||||
quality_score: float = 0.3,
|
||||
patterns: list[str] | None = None,
|
||||
suggestions: list[str] | None = None,
|
||||
) -> Reflection:
|
||||
# Use explicit None check to allow empty list for suggestions
|
||||
if suggestions is None:
|
||||
suggestions = ["Add more detail", "Be more specific"]
|
||||
return Reflection(
|
||||
task_id="task-1",
|
||||
agent_name="test_agent",
|
||||
outcome="partial",
|
||||
quality_score=quality_score,
|
||||
patterns=patterns or ["reasoning"],
|
||||
insights=["Needs improvement"],
|
||||
suggestions=suggestions,
|
||||
)
|
||||
|
||||
|
||||
def make_mock_memory_store() -> MagicMock:
|
||||
"""创建 mock MemoryStore,模拟 get_file 返回可操作的 MemoryFile"""
|
||||
store = MagicMock(spec=MemoryStore)
|
||||
mock_file = MagicMock()
|
||||
mock_file.read_section.return_value = "版本: 1\n更新时间: 2025-01-01T00:00:00"
|
||||
mock_file.list_sections.return_value = ["身份", "版本"]
|
||||
store.get_file.return_value = mock_file
|
||||
return store
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: 3 reflections within window trigger evolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReflectionCountTrigger:
|
||||
"""时间窗口内 3 次反思触发进化"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_reflections_trigger_evolution(self):
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
reflection_window_seconds=3600,
|
||||
time_decay_factor=1.0, # No decay for this test
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
# Record 3 reflections within the window
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = {"success": True, "version": 2}
|
||||
|
||||
# First reflection — should not trigger
|
||||
reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection1)
|
||||
assert evolved is False
|
||||
|
||||
# Second reflection — should not trigger
|
||||
reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection2)
|
||||
assert evolved is False
|
||||
|
||||
# Third reflection — should trigger (3 >= min_reflections=3)
|
||||
reflection3 = make_reflection(quality_score=0.2, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3)
|
||||
assert evolved is True
|
||||
|
||||
# MemoryTool.execute should have been called for the soul update
|
||||
mock_execute.assert_called_once()
|
||||
call_kwargs = mock_execute.call_args[1]
|
||||
assert call_kwargs["action"] == "update_soul"
|
||||
assert call_kwargs["file"] == "soul"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_reflections_do_not_trigger(self):
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
time_decay_factor=1.0,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# First reflection
|
||||
reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection1)
|
||||
assert evolved is False
|
||||
|
||||
# Second reflection — still not enough
|
||||
reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection2)
|
||||
assert evolved is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Quality gradient (declining scores) triggers early evolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQualityGradientTrigger:
|
||||
"""质量梯度(持续下降分数)触发早期进化"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_declining_scores_trigger_early_evolution(self):
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=10, # High threshold — won't trigger by count
|
||||
quality_gradient_threshold=-0.15,
|
||||
time_decay_factor=1.0,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = {"success": True, "version": 2}
|
||||
|
||||
# Record 3 reflections with declining scores (all < 0.5 to pass the quality check)
|
||||
# Score drops: 0.45 → 0.25 → 0.05 (each drop > 0.15)
|
||||
reflection1 = make_reflection(quality_score=0.45, patterns=["reasoning"])
|
||||
await mixin.evolve_soul(task, result, memory_store, reflection=reflection1)
|
||||
|
||||
reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"])
|
||||
await mixin.evolve_soul(task, result, memory_store, reflection=reflection2)
|
||||
|
||||
# Third reflection with continued decline should trigger quality gradient
|
||||
reflection3 = make_reflection(quality_score=0.05, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3)
|
||||
|
||||
# Quality gradient: 0.45→0.25 (drop=-0.2), 0.25→0.05 (drop=-0.2)
|
||||
# Both drops <= -0.15, so quality_gradient_triggered = True
|
||||
assert evolved is True
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stable_scores_do_not_trigger_gradient(self):
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=10, # High threshold
|
||||
quality_gradient_threshold=-0.15,
|
||||
time_decay_factor=1.0,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# Record 3 reflections with stable/improving scores
|
||||
reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
await mixin.evolve_soul(task, result, memory_store, reflection=reflection1)
|
||||
|
||||
reflection2 = make_reflection(quality_score=0.35, patterns=["reasoning"])
|
||||
await mixin.evolve_soul(task, result, memory_store, reflection=reflection2)
|
||||
|
||||
reflection3 = make_reflection(quality_score=0.4, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3)
|
||||
|
||||
# Scores are improving, no quality gradient trigger
|
||||
assert evolved is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Task type weight adjusts trigger threshold
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTaskTypeWeightTrigger:
|
||||
"""任务类型权重调整触发阈值"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_weight_reduces_effective_threshold(self):
|
||||
"""高权重降低有效触发阈值:2 次反思 × 权重 2.0 = 有效 4.0 >= min_reflections 3"""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
time_decay_factor=1.0,
|
||||
task_type_weights={"critical": 2.0},
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = {"success": True, "version": 2}
|
||||
|
||||
# First reflection with critical task type
|
||||
reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(
|
||||
task, result, memory_store,
|
||||
reflection=reflection1,
|
||||
task_type="critical",
|
||||
)
|
||||
# 1 reflection × weight 2.0 = 2.0 < 3
|
||||
assert evolved is False
|
||||
|
||||
# Second reflection with critical task type
|
||||
reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(
|
||||
task, result, memory_store,
|
||||
reflection=reflection2,
|
||||
task_type="critical",
|
||||
)
|
||||
# 2 reflections × weight 2.0 = 4.0 >= 3
|
||||
assert evolved is True
|
||||
mock_execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_weight_increases_effective_threshold(self):
|
||||
"""低权重增加有效触发阈值:3 次反思 × 权重 0.5 = 有效 1.5 < min_reflections 3"""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
time_decay_factor=1.0,
|
||||
task_type_weights={"low_priority": 0.5},
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# 3 reflections with low_priority task type
|
||||
for i in range(3):
|
||||
reflection = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(
|
||||
task, result, memory_store,
|
||||
reflection=reflection,
|
||||
task_type="low_priority",
|
||||
)
|
||||
# 3 × 0.5 = 1.5 < 3 → should not trigger
|
||||
assert evolved is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Time decay reduces effective count for old reflections
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTimeDecayReducesEffectiveCount:
|
||||
"""时间衰减降低旧反思的有效计数"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_old_reflections_decay_below_threshold(self):
|
||||
"""旧反思因时间衰减导致有效计数不足"""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
reflection_window_seconds=3600,
|
||||
time_decay_factor=0.5, # Half-life of 1 hour
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# Manually add old reflections to pending_soul_updates
|
||||
now = datetime.now(timezone.utc)
|
||||
old_timestamp = now - timedelta(hours=3) # 3 hours ago
|
||||
|
||||
# Add 2 old reflections manually
|
||||
mixin.pending_soul_updates["reasoning"] = [
|
||||
{
|
||||
"reflection": make_reflection(quality_score=0.3),
|
||||
"timestamp": old_timestamp,
|
||||
"score": 0.3,
|
||||
"task_type": "",
|
||||
},
|
||||
{
|
||||
"reflection": make_reflection(quality_score=0.25),
|
||||
"timestamp": old_timestamp,
|
||||
"score": 0.25,
|
||||
"task_type": "",
|
||||
},
|
||||
]
|
||||
|
||||
# Add a recent reflection via evolve_soul
|
||||
# Time decay: 0.5^3 = 0.125 per old reflection → 2 × 0.125 = 0.25
|
||||
# Plus 1 new reflection → total effective ≈ 1.25 < 3
|
||||
recent_reflection = make_reflection(quality_score=0.2, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=recent_reflection)
|
||||
|
||||
# Effective count should be well below 3 due to decay
|
||||
assert evolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recent_reflections_no_decay(self):
|
||||
"""近期反思不受时间衰减影响"""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
time_decay_factor=0.5,
|
||||
)
|
||||
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = {"success": True, "version": 2}
|
||||
|
||||
# 3 recent reflections should trigger (no significant decay)
|
||||
for i in range(3):
|
||||
reflection = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection)
|
||||
|
||||
assert evolved is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_memory_store_returns_false(self):
|
||||
"""无 MemoryStore 时不触发进化"""
|
||||
config = SoulEvolutionConfig(min_reflections=1)
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
reflection = make_reflection(quality_score=0.3, patterns=["reasoning"])
|
||||
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store=None, reflection=reflection)
|
||||
assert evolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_quality_reflection_does_not_trigger(self):
|
||||
"""高质量反思不触发进化(quality_score >= 0.5)"""
|
||||
config = SoulEvolutionConfig(min_reflections=1)
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# High quality reflection — should not even be recorded
|
||||
reflection = make_reflection(quality_score=0.8, patterns=["reasoning"])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection)
|
||||
assert evolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_suggestions_does_not_trigger(self):
|
||||
"""无建议的反思不触发进化"""
|
||||
config = SoulEvolutionConfig(min_reflections=1)
|
||||
mixin = EvolutionMixin(evolution_config=config)
|
||||
memory_store = make_mock_memory_store()
|
||||
|
||||
task = make_task()
|
||||
result = make_result()
|
||||
|
||||
# Low quality but no suggestions
|
||||
reflection = make_reflection(quality_score=0.3, patterns=["reasoning"], suggestions=[])
|
||||
evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection)
|
||||
assert evolved is False
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
"""Tests for Agent inter-communication bus — new features.
|
||||
|
||||
Covers: msg_type, content, TTL, CascadeDetector integration,
|
||||
AlignmentGuard integration, negotiate, request-timeout-returns-None.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
from agentkit.quality.cascade_detector import CascadeDetector
|
||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
||||
|
||||
|
||||
# ── AgentMessage new fields ───────────────────────────────
|
||||
|
||||
|
||||
class TestAgentMessageNewFields:
|
||||
def test_msg_type_default(self):
|
||||
msg = AgentMessage(sender="a")
|
||||
assert msg.msg_type == "notify"
|
||||
|
||||
def test_content_field(self):
|
||||
msg = AgentMessage(sender="a", content="hello")
|
||||
assert msg.content == "hello"
|
||||
|
||||
def test_ttl_default(self):
|
||||
msg = AgentMessage(sender="a")
|
||||
assert msg.ttl_seconds == 300
|
||||
|
||||
def test_is_expired_fresh(self):
|
||||
msg = AgentMessage(sender="a")
|
||||
assert msg.is_expired() is False
|
||||
|
||||
def test_is_expired_old(self):
|
||||
old_ts = datetime.now(timezone.utc) - timedelta(seconds=600)
|
||||
msg = AgentMessage(sender="a", ttl_seconds=300, timestamp=old_ts)
|
||||
assert msg.is_expired() is True
|
||||
|
||||
def test_to_dict_roundtrip_with_new_fields(self):
|
||||
msg = AgentMessage(
|
||||
sender="a",
|
||||
content="data",
|
||||
msg_type="negotiate",
|
||||
ttl_seconds=60,
|
||||
)
|
||||
d = msg.to_dict()
|
||||
restored = AgentMessage.from_dict(d)
|
||||
assert restored.content == "data"
|
||||
assert restored.msg_type == "negotiate"
|
||||
assert restored.ttl_seconds == 60
|
||||
|
||||
|
||||
# ── Request-Response with correlation_id ──────────────────
|
||||
|
||||
|
||||
class TestRequestResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_response_correlation(self):
|
||||
"""Agent A sends request, Agent B responds, correlation_id matches."""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
async def handler_b(msg: AgentMessage):
|
||||
reply = AgentMessage(
|
||||
sender="agent_b",
|
||||
recipient=msg.sender,
|
||||
content="answer",
|
||||
msg_type="response",
|
||||
correlation_id=msg.correlation_id,
|
||||
)
|
||||
await bus.publish(reply)
|
||||
|
||||
await bus.subscribe("agent_b", handler_b)
|
||||
|
||||
request = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="question",
|
||||
)
|
||||
response = await bus.request(request, timeout_seconds=5.0)
|
||||
assert response is not None
|
||||
assert response.content == "answer"
|
||||
assert response.correlation_id == request.correlation_id
|
||||
|
||||
|
||||
# ── Broadcast ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBroadcast:
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_all_subscribers_receive(self):
|
||||
"""Agent A broadcasts, all subscribers receive."""
|
||||
bus = InMemoryMessageBus()
|
||||
received: dict[str, list[AgentMessage]] = {"b": [], "c": []}
|
||||
|
||||
async def handler_b(msg: AgentMessage):
|
||||
received["b"].append(msg)
|
||||
|
||||
async def handler_c(msg: AgentMessage):
|
||||
received["c"].append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler_b)
|
||||
await bus.subscribe("agent_c", handler_c)
|
||||
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
content="hello everyone",
|
||||
msg_type="notify",
|
||||
))
|
||||
assert result is True
|
||||
assert len(received["b"]) == 1
|
||||
assert len(received["c"]) == 1
|
||||
assert received["b"][0].content == "hello everyone"
|
||||
|
||||
|
||||
# ── Negotiate ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNegotiate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_response(self):
|
||||
"""Agent A sends negotiate, Agent B responds."""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
async def handler_b(msg: AgentMessage):
|
||||
reply = AgentMessage(
|
||||
sender="agent_b",
|
||||
recipient=msg.sender,
|
||||
content="deal accepted",
|
||||
msg_type="response",
|
||||
correlation_id=msg.correlation_id,
|
||||
)
|
||||
await bus.publish(reply)
|
||||
|
||||
await bus.subscribe("agent_b", handler_b)
|
||||
|
||||
request = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="propose deal",
|
||||
msg_type="negotiate",
|
||||
)
|
||||
response = await bus.request(request, timeout_seconds=5.0)
|
||||
assert response is not None
|
||||
assert response.content == "deal accepted"
|
||||
assert response.msg_type == "response"
|
||||
|
||||
|
||||
# ── TTL Expired ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTTLExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_message_dropped(self):
|
||||
"""Expired message is dropped by publish()."""
|
||||
bus = InMemoryMessageBus()
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
|
||||
old_ts = datetime.now(timezone.utc) - timedelta(seconds=600)
|
||||
msg = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="old news",
|
||||
ttl_seconds=300,
|
||||
timestamp=old_ts,
|
||||
)
|
||||
result = await bus.publish(msg)
|
||||
assert result is False
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0
|
||||
|
||||
|
||||
# ── Unsubscribe ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestUnsubscribe:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribed_agent_no_receive(self):
|
||||
"""Unsubscribed agent doesn't receive messages."""
|
||||
bus = InMemoryMessageBus()
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
await bus.unsubscribe("agent_b")
|
||||
|
||||
await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
content="still here?",
|
||||
))
|
||||
await asyncio.sleep(0.05)
|
||||
assert len(received) == 0
|
||||
|
||||
|
||||
# ── Request Timeout ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestRequestTimeout:
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_timeout_returns_none(self):
|
||||
"""No response within timeout returns None."""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
# No subscriber for agent_b
|
||||
request = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="anyone there?",
|
||||
)
|
||||
response = await bus.request(request, timeout_seconds=0.1)
|
||||
assert response is None
|
||||
|
||||
|
||||
# ── CascadeDetector Integration ───────────────────────────
|
||||
|
||||
|
||||
class TestCascadeDetectorIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cascade_alert_blocks_message(self):
|
||||
"""Too many messages trigger cascade alert and block publishing."""
|
||||
detector = CascadeDetector(max_interactions=3)
|
||||
bus = InMemoryMessageBus(cascade_detector=detector)
|
||||
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
|
||||
# First 3 interactions should succeed (check_interaction increments before check)
|
||||
# max_interactions=3 means count > 3 triggers alert, so 3 succeed, 4th fails
|
||||
for i in range(3):
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content=f"msg {i}",
|
||||
))
|
||||
assert result is True
|
||||
|
||||
# 4th interaction should be blocked
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="msg 4",
|
||||
))
|
||||
assert result is False
|
||||
|
||||
|
||||
# ── AlignmentGuard Integration ────────────────────────────
|
||||
|
||||
|
||||
class TestAlignmentGuardIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_violating_message_blocked(self):
|
||||
"""Message violating alignment constraints is blocked."""
|
||||
config = AlignmentConfig(constraints=["禁止暴力"])
|
||||
guard = AlignmentGuard(config=config)
|
||||
bus = InMemoryMessageBus(alignment_guard=guard)
|
||||
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
|
||||
# request with violating content should be blocked
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="执行暴力行为",
|
||||
msg_type="request",
|
||||
))
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_violating_request_passes(self):
|
||||
"""Non-violating request message passes alignment check."""
|
||||
config = AlignmentConfig(constraints=["禁止暴力"])
|
||||
guard = AlignmentGuard(config=config)
|
||||
bus = InMemoryMessageBus(alignment_guard=guard)
|
||||
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="和平交流",
|
||||
msg_type="request",
|
||||
))
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_not_checked_by_alignment(self):
|
||||
"""notify type messages are not checked by alignment guard."""
|
||||
config = AlignmentConfig(constraints=["禁止暴力"])
|
||||
guard = AlignmentGuard(config=config)
|
||||
bus = InMemoryMessageBus(alignment_guard=guard)
|
||||
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
|
||||
# notify with violating content should pass (not checked)
|
||||
result = await bus.publish(AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
content="执行暴力行为",
|
||||
msg_type="notify",
|
||||
))
|
||||
assert result is True
|
||||
|
|
@ -126,7 +126,7 @@ class TestInMemoryMessageBus:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_timeout(self):
|
||||
"""请求超时后抛出异常。"""
|
||||
"""请求超时后返回 None。"""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
# No one is subscribed to handle the request
|
||||
|
|
@ -136,8 +136,8 @@ class TestInMemoryMessageBus:
|
|||
topic="question",
|
||||
)
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await bus.request(request, timeout=0.1)
|
||||
result = await bus.request(request, timeout=0.1)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_stops_delivery(self):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||||
|
|
@ -466,3 +466,47 @@ class TestSkillRoutingResultNewFields:
|
|||
assert result.transparency_level == "SILENT"
|
||||
assert result.execution_trace == []
|
||||
assert result.complexity == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _tokenize_content: 中文分词增强
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenizeContent:
|
||||
"""_tokenize_content 中文分词增强测试"""
|
||||
|
||||
def test_chinese_content(self):
|
||||
"""中文内容:'帮我做数据分析' 应包含 '数据分析' 相关 2-gram"""
|
||||
tokens = _tokenize_content("帮我做数据分析")
|
||||
# 整段无标点分隔,生成 2-gram:帮我、我做、做数、数据、据分、分析
|
||||
assert "数据" in tokens or "数据分析" in tokens
|
||||
|
||||
def test_english_content(self):
|
||||
"""英文内容:'help with code generation' 应包含 'code', 'generation' 或 'code generation'"""
|
||||
tokens = _tokenize_content("help with code generation")
|
||||
assert "code" in tokens or "generation" in tokens or "code generation" in tokens
|
||||
|
||||
def test_mixed_content(self):
|
||||
"""中英混合:'用python做data analysis' 应包含 'python' 相关 token 和 'data analysis'"""
|
||||
tokens = _tokenize_content("用python做data analysis")
|
||||
# 按空格分割后 "用python做data" 作为一个 segment,生成 2-gram
|
||||
# "analysis" 作为独立 segment
|
||||
assert "analysis" in tokens
|
||||
# "用python做data" 长度 > 4,会生成 2-gram,其中包含 python 相关片段
|
||||
has_python_related = any("python" in t for t in tokens)
|
||||
assert has_python_related or "data analysis" in tokens
|
||||
|
||||
def test_stopwords_filtered(self):
|
||||
"""停用词过滤:纯停用词短句过滤后应为空或极少 token"""
|
||||
tokens = _tokenize_content("我的一个")
|
||||
# "我的一个" 长度 4,作为整体保留(不在停用词集合中)
|
||||
# 但停用词 "我的" "的一" "一个" 等 2-gram 会被过滤
|
||||
assert len(tokens) <= 1
|
||||
|
||||
def test_bigram_generation(self):
|
||||
"""2-gram 生成:'机器学习模型训练' 应包含各 2-gram"""
|
||||
tokens = _tokenize_content("机器学习模型训练")
|
||||
expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
|
||||
for bigram in expected_bigrams:
|
||||
assert bigram in tokens, f"缺少 2-gram: {bigram}"
|
||||
|
|
|
|||
|
|
@ -267,8 +267,9 @@ class TestReWOOPlanningFailureFallback:
|
|||
async def test_invalid_json_falls_back_to_react(self):
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning returns invalid JSON
|
||||
# Planning returns invalid JSON, simplified planning also fails
|
||||
invalid_plan_response = make_response(content="I cannot create a plan for this task.")
|
||||
simplified_fail_response = make_response(content="Still cannot create a plan")
|
||||
# ReAct fallback responses
|
||||
react_tool_response = make_response(
|
||||
content="",
|
||||
|
|
@ -278,6 +279,7 @@ class TestReWOOPlanningFailureFallback:
|
|||
|
||||
gateway = make_mock_gateway([
|
||||
invalid_plan_response,
|
||||
simplified_fail_response,
|
||||
react_tool_response,
|
||||
react_final_response,
|
||||
])
|
||||
|
|
@ -292,15 +294,17 @@ class TestReWOOPlanningFailureFallback:
|
|||
# Should have fallen back to ReAct and produced a result
|
||||
assert result.output == "ReAct fallback answer"
|
||||
assert result.total_steps >= 1
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
async def test_malformed_json_falls_back_to_react(self):
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning returns malformed JSON
|
||||
# Planning returns malformed JSON, simplified planning also fails, ReAct succeeds
|
||||
malformed_response = make_response(content='{"reasoning": "plan", "steps": [invalid json')
|
||||
simplified_fail_response = make_response(content='Also not a plan')
|
||||
react_response = make_response(content="ReAct answer")
|
||||
|
||||
gateway = make_mock_gateway([malformed_response, react_response])
|
||||
gateway = make_mock_gateway([malformed_response, simplified_fail_response, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
|
|
@ -308,15 +312,17 @@ class TestReWOOPlanningFailureFallback:
|
|||
)
|
||||
|
||||
assert result.output == "ReAct answer"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
async def test_missing_steps_key_falls_back_to_react(self):
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# JSON without "steps" key
|
||||
# JSON without "steps" key, simplified planning also fails, ReAct succeeds
|
||||
no_steps_response = make_response(content='{"reasoning": "no steps here"}')
|
||||
simplified_fail_response = make_response(content='Also no steps')
|
||||
react_response = make_response(content="ReAct fallback")
|
||||
|
||||
gateway = make_mock_gateway([no_steps_response, react_response])
|
||||
gateway = make_mock_gateway([no_steps_response, simplified_fail_response, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
|
|
@ -324,6 +330,7 @@ class TestReWOOPlanningFailureFallback:
|
|||
)
|
||||
|
||||
assert result.output == "ReAct fallback"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
|
||||
# ── Test: Cancellation Token ──────────────────────────────
|
||||
|
|
@ -706,11 +713,12 @@ class TestReWOOStreaming:
|
|||
async def test_stream_planning_failure_falls_back(self):
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Invalid plan, then ReAct fallback
|
||||
# Invalid plan, simplified also fails, then ReAct fallback
|
||||
invalid_plan = make_response(content="Not a plan")
|
||||
simplified_fail = make_response(content="Still not a plan")
|
||||
react_response = make_response(content="ReAct answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan, react_response])
|
||||
gateway = make_mock_gateway([invalid_plan, simplified_fail, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
events = []
|
||||
|
|
@ -842,3 +850,148 @@ class TestReWOOMaxPlanSteps:
|
|||
|
||||
# Need to import ReActResult for type checking in tests
|
||||
from agentkit.core.react import ReActResult
|
||||
|
||||
|
||||
# ── Test: Progressive Fallback Chain ──────────────────────
|
||||
|
||||
|
||||
class TestReWOOProgressiveFallback:
|
||||
"""渐进式回退链:planning → simplified_rewoo → react → direct"""
|
||||
|
||||
async def test_normal_planning_succeeds_no_fallback(self):
|
||||
"""正常规划成功,不使用回退"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
tool = FakeTool(name="calculator", result={"value": 42})
|
||||
|
||||
plan_response = make_plan_response([
|
||||
{"step_id": 1, "tool_name": "calculator", "arguments": {"expr": "6*7"}, "reasoning": "Calculate"},
|
||||
])
|
||||
synthesis_response = make_response(content="The result is 42")
|
||||
|
||||
gateway = make_mock_gateway([plan_response, synthesis_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Calculate 6*7"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert result.output == "The result is 42"
|
||||
assert result.fallback_strategy is None
|
||||
|
||||
async def test_planning_fails_simplified_succeeds(self):
|
||||
"""规划失败,简化规划成功 → fallback_strategy="simplified_rewoo" """
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
tool = FakeTool(name="search", result={"results": ["found"]})
|
||||
|
||||
# First plan fails (invalid JSON), second (simplified) succeeds
|
||||
invalid_plan_response = make_response(content="I cannot create a plan for this task.")
|
||||
simplified_plan_response = make_plan_response([
|
||||
{"step_id": 1, "tool_name": "search", "arguments": {"query": "test"}, "reasoning": "Simplified search"},
|
||||
])
|
||||
synthesis_response = make_response(content="Simplified result")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan_response, simplified_plan_response, synthesis_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Complex task"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert result.output == "Simplified result"
|
||||
assert result.fallback_strategy == "simplified_rewoo"
|
||||
|
||||
async def test_planning_and_simplified_fail_react_succeeds(self):
|
||||
"""规划和简化规划都失败,ReAct 回退成功 → fallback_strategy="react" """
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Both plan attempts fail (invalid JSON), ReAct succeeds
|
||||
invalid_plan1 = make_response(content="Not a plan")
|
||||
invalid_plan2 = make_response(content="Still not a plan")
|
||||
react_response = make_response(content="ReAct fallback answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan1, invalid_plan2, react_response])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Complex task"}],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct fallback answer"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
async def test_all_fail_direct_fallback(self):
|
||||
"""规划、简化规划、ReAct 全部失败 → fallback_strategy="direct" """
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Both plan attempts fail
|
||||
invalid_plan1 = make_response(content="Not a plan")
|
||||
invalid_plan2 = make_response(content="Still not a plan")
|
||||
|
||||
# Make ReAct engine fail by having its LLM call raise an exception
|
||||
call_count = 0
|
||||
|
||||
async def chat_side_effect(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
# First two calls are for planning (both fail to parse)
|
||||
return make_response(content="Not a plan")
|
||||
if call_count == 3:
|
||||
# ReAct engine call - raise exception
|
||||
raise RuntimeError("ReAct engine failed")
|
||||
# Direct fallback call
|
||||
return make_response(content="Direct fallback answer")
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=chat_side_effect)
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Impossible task"}],
|
||||
)
|
||||
|
||||
assert result.output == "Direct fallback answer"
|
||||
assert result.fallback_strategy == "direct"
|
||||
|
||||
async def test_fallback_strategies_constant_exists(self):
|
||||
"""验证 FALLBACK_STRATEGIES 常量存在"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
assert hasattr(ReWOOEngine, "FALLBACK_STRATEGIES")
|
||||
assert ReWOOEngine.FALLBACK_STRATEGIES == ["simplified_rewoo", "react", "direct"]
|
||||
|
||||
async def test_react_fallback_with_tools(self):
|
||||
"""规划失败后 ReAct 回退,带工具调用"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
tool = FakeTool(name="search", result={"results": ["found"]})
|
||||
|
||||
# Both plan attempts fail
|
||||
invalid_plan1 = make_response(content="Cannot plan")
|
||||
invalid_plan2 = make_response(content="Still cannot plan")
|
||||
# ReAct: tool call then final answer
|
||||
react_tool_response = make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})],
|
||||
)
|
||||
react_final_response = make_response(content="ReAct answer with tool")
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
invalid_plan1,
|
||||
invalid_plan2,
|
||||
react_tool_response,
|
||||
react_final_response,
|
||||
])
|
||||
engine = ReWOOEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Search task"}],
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
assert result.output == "ReAct answer with tool"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin, SoulEvolutionConfig
|
||||
from agentkit.evolution.reflector import Reflection, Reflector
|
||||
from agentkit.memory.profile import MemoryStore
|
||||
from agentkit.tools.memory_tool import MemoryTool
|
||||
|
|
@ -265,3 +265,288 @@ class TestEvolveSoul:
|
|||
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
|
||||
# ── Multi-dimensional trigger tests ──────────────────────────
|
||||
|
||||
|
||||
class TestTimeDecay:
|
||||
"""时间衰减触发测试."""
|
||||
|
||||
async def test_recent_reflections_count_fully(self, store: MemoryStore):
|
||||
"""窗口内的反思完全计入有效数量."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
reflection_window_seconds=3600,
|
||||
time_decay_factor=0.5,
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 3 次近期反思应触发
|
||||
for _ in range(2):
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is True
|
||||
|
||||
async def test_old_reflections_decay(self, store: MemoryStore):
|
||||
"""旧反思因时间衰减导致有效数量不足,不触发."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
reflection_window_seconds=3600,
|
||||
time_decay_factor=0.5,
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 手动插入 2 个旧反思(10 小时前)
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=10)
|
||||
for _ in range(2):
|
||||
mixin.pending_soul_updates.setdefault("slow_execution", []).append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="old",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve speed"],
|
||||
),
|
||||
"timestamp": old_time,
|
||||
"score": 0.2,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 1 个新反思:2*0.5^10 + 1 ≈ 1.002 < 3,不触发
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
|
||||
class TestQualityGradient:
|
||||
"""质量梯度触发测试."""
|
||||
|
||||
async def test_declining_scores_trigger_early(self, store: MemoryStore):
|
||||
"""连续 3 次分数下降超过阈值时提前触发."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
quality_gradient_threshold=-0.15,
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 手动插入 2 个反思,分数递减
|
||||
mixin.pending_soul_updates.setdefault("slow_execution", []).append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="g1",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.4,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve"],
|
||||
),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": 0.4,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
mixin.pending_soul_updates["slow_execution"].append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="g2",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve more"],
|
||||
),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": 0.2,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 第 3 个反思:score=0.0,下降 0.2 > 0.15 阈值,触发
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, score=0.0
|
||||
)
|
||||
assert updated is True
|
||||
|
||||
async def test_stable_scores_do_not_trigger_early(self, store: MemoryStore):
|
||||
"""分数稳定时不提前触发."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
quality_gradient_threshold=-0.15,
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 2 个分数稳定的反思
|
||||
mixin.pending_soul_updates.setdefault("slow_execution", []).append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="s1",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve"],
|
||||
),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": 0.2,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
mixin.pending_soul_updates["slow_execution"].append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="s2",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.19,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve more"],
|
||||
),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": 0.19,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 第 3 个反思:下降 0.01 < 0.15 阈值,不提前触发
|
||||
# 但 effective_count=3 >= min_reflections=3,所以仍会触发
|
||||
# 需要改为只有 2 个反思来测试"不提前触发"
|
||||
mixin2 = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
mixin2.pending_soul_updates.setdefault("slow_execution", []).append(
|
||||
{
|
||||
"reflection": Reflection(
|
||||
task_id="s1",
|
||||
agent_name="evolving_agent",
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=[],
|
||||
suggestions=["Improve"],
|
||||
),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
"score": 0.2,
|
||||
"task_type": "",
|
||||
}
|
||||
)
|
||||
# 只有 2 个反思,分数稳定,不应触发
|
||||
updated = await mixin2.evolve_soul(
|
||||
task, result, memory_store=store, score=0.19
|
||||
)
|
||||
assert updated is False
|
||||
|
||||
|
||||
class TestTaskTypeWeight:
|
||||
"""任务类型权重触发测试."""
|
||||
|
||||
async def test_code_generation_triggers_at_2(self, store: MemoryStore):
|
||||
"""code_generation 类型权重 1.5,2 次反思即可触发 (2*1.5=3 >= 3)."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
task_type_weights={"code_generation": 1.5, "chat": 0.5},
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 第 1 次
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, task_type="code_generation"
|
||||
)
|
||||
assert updated is False
|
||||
|
||||
# 第 2 次:effective_count=2, weight=1.5, weighted=3.0 >= 3,触发
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, task_type="code_generation"
|
||||
)
|
||||
assert updated is True
|
||||
|
||||
async def test_chat_needs_more_reflections(self, store: MemoryStore):
|
||||
"""chat 类型权重 0.5,需要更多反思才能触发."""
|
||||
config = SoulEvolutionConfig(
|
||||
min_reflections=3,
|
||||
task_type_weights={"code_generation": 1.5, "chat": 0.5},
|
||||
)
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector, evolution_config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 4 次 chat 反思:effective_count=4, weight=0.5, weighted=2.0 < 3,不触发
|
||||
for _ in range(4):
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, task_type="chat"
|
||||
)
|
||||
assert updated is False
|
||||
|
||||
# 第 5 次触发:5 * 0.5 = 2.5 < 3,仍不触发
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, task_type="chat"
|
||||
)
|
||||
assert updated is False
|
||||
|
||||
# 第 6 次:6 * 0.5 = 3.0 >= 3,触发
|
||||
updated = await mixin.evolve_soul(
|
||||
task, result, memory_store=store, task_type="chat"
|
||||
)
|
||||
assert updated is True
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""向后兼容性测试:无 config 时行为与之前一致."""
|
||||
|
||||
async def test_no_config_3_reflections_trigger(self, store: MemoryStore):
|
||||
"""无 evolution_config 时,3 次反思触发(与原行为一致)."""
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 前 2 次不触发
|
||||
for _ in range(2):
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
# 第 3 次触发
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is True
|
||||
|
||||
async def test_no_config_fewer_than_3_no_trigger(self, store: MemoryStore):
|
||||
"""无 evolution_config 时,少于 3 次不触发."""
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
for _ in range(2):
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
|
|
|||
|
|
@ -1,472 +1,254 @@
|
|||
"""Unit tests for telemetry module — OpenTelemetry integration"""
|
||||
"""Telemetry module unit tests."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── No-op behavior when OTel not installed ──────────────────────────
|
||||
from agentkit.telemetry.tracer import (
|
||||
NoOpSpan,
|
||||
NoOpTracer,
|
||||
OTelSpan,
|
||||
OTelTracer,
|
||||
TelemetryConfig,
|
||||
get_tracer,
|
||||
init_telemetry,
|
||||
)
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
||||
|
||||
|
||||
class TestNoOpWhenOTelNotInstalled:
|
||||
"""All operations are no-op when opentelemetry is not installed."""
|
||||
# ── NoOpSpan 测试 ──────────────────────────────────────────
|
||||
|
||||
def test_tracing_noop_span_context_manager(self):
|
||||
"""_NoOpSpan works as context manager without errors."""
|
||||
from agentkit.telemetry.tracing import _NoOpSpan
|
||||
|
||||
span = _NoOpSpan()
|
||||
class TestNoOpSpan:
|
||||
"""NoOpSpan no-op 行为测试"""
|
||||
|
||||
def test_is_recording_returns_false(self):
|
||||
span = NoOpSpan()
|
||||
assert span.is_recording() is False
|
||||
|
||||
def test_set_attribute_does_nothing(self):
|
||||
span = NoOpSpan()
|
||||
# 不应抛出异常
|
||||
span.set_attribute("key", "value")
|
||||
|
||||
def test_add_event_does_nothing(self):
|
||||
span = NoOpSpan()
|
||||
span.add_event("event_name", {"attr": "val"})
|
||||
|
||||
def test_record_exception_does_nothing(self):
|
||||
span = NoOpSpan()
|
||||
span.record_exception(ValueError("test"))
|
||||
|
||||
def test_context_manager(self):
|
||||
span = NoOpSpan()
|
||||
with span as s:
|
||||
s.set_attribute("key", "value")
|
||||
s.add_event("event")
|
||||
s.set_status("ok")
|
||||
s.record_exception(Exception("test"))
|
||||
assert s is span
|
||||
|
||||
def test_get_tracer_returns_none_without_otel(self):
|
||||
"""get_tracer returns None when OTel is not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, get_tracer
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
assert get_tracer() is None
|
||||
# ── NoOpTracer 测试 ────────────────────────────────────────
|
||||
|
||||
def test_start_span_returns_noop_without_otel(self):
|
||||
"""start_span returns no-op span when OTel is not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, start_span, _NoOpSpan
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
span_cm = start_span("test.span")
|
||||
assert isinstance(span_cm, _NoOpSpan)
|
||||
class TestNoOpTracer:
|
||||
"""NoOpTracer no-op 行为测试"""
|
||||
|
||||
def test_metrics_noop_counter(self):
|
||||
"""No-op counter add() does not raise."""
|
||||
from agentkit.telemetry.metrics import _NoOpCounter
|
||||
def test_start_span_returns_noop_span(self):
|
||||
tracer = NoOpTracer()
|
||||
span = tracer.start_span("test.span")
|
||||
assert isinstance(span, NoOpSpan)
|
||||
|
||||
counter = _NoOpCounter()
|
||||
counter.add(1, {"key": "value"}) # Should not raise
|
||||
|
||||
def test_metrics_noop_histogram(self):
|
||||
"""No-op histogram record() does not raise."""
|
||||
from agentkit.telemetry.metrics import _NoOpHistogram
|
||||
|
||||
hist = _NoOpHistogram()
|
||||
hist.record(100, {"key": "value"}) # Should not raise
|
||||
|
||||
def test_metrics_get_meter_returns_none_without_otel(self):
|
||||
"""get_meter returns None when OTel is not installed."""
|
||||
from agentkit.telemetry.metrics import _OTEL_AVAILABLE, get_meter
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
assert get_meter() is None
|
||||
|
||||
def test_metric_helpers_return_noop_without_otel(self):
|
||||
"""Metric helper functions return no-op instruments when OTel not installed."""
|
||||
from agentkit.telemetry.metrics import (
|
||||
_OTEL_AVAILABLE,
|
||||
_NoOpCounter,
|
||||
_NoOpHistogram,
|
||||
agent_request_counter,
|
||||
agent_duration_histogram,
|
||||
llm_token_histogram,
|
||||
tool_duration_histogram,
|
||||
pipeline_step_histogram,
|
||||
)
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
# Reset lazy singletons to force re-creation
|
||||
import agentkit.telemetry.metrics as m
|
||||
m._agent_request_counter = None
|
||||
m._agent_duration_histogram = None
|
||||
m._llm_token_histogram = None
|
||||
m._tool_duration_histogram = None
|
||||
m._pipeline_step_histogram = None
|
||||
|
||||
assert isinstance(agent_request_counter(), _NoOpCounter)
|
||||
assert isinstance(agent_duration_histogram(), _NoOpHistogram)
|
||||
assert isinstance(llm_token_histogram(), _NoOpHistogram)
|
||||
assert isinstance(tool_duration_histogram(), _NoOpHistogram)
|
||||
assert isinstance(pipeline_step_histogram(), _NoOpHistogram)
|
||||
|
||||
|
||||
# ── Tracing decorator tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTraceAgentDecorator:
|
||||
"""trace_agent decorator works with and without OTel."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_works_without_otel(self):
|
||||
"""trace_agent decorator passes through when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_agent
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_agent("test_agent", "react")
|
||||
async def my_func():
|
||||
return "result"
|
||||
|
||||
result = await my_func()
|
||||
assert result == "result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_propagates_exception_without_otel(self):
|
||||
"""trace_agent propagates exceptions when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_agent
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_agent("test_agent")
|
||||
async def my_func():
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
await my_func()
|
||||
|
||||
|
||||
class TestTraceToolDecorator:
|
||||
"""trace_tool decorator tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_works_without_otel(self):
|
||||
"""trace_tool decorator passes through when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_tool
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_tool("my_tool")
|
||||
async def my_func():
|
||||
return {"result": "ok"}
|
||||
|
||||
result = await my_func()
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_propagates_exception_without_otel(self):
|
||||
"""trace_tool propagates exceptions when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_tool
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_tool("my_tool")
|
||||
async def my_func():
|
||||
raise RuntimeError("tool error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="tool error"):
|
||||
await my_func()
|
||||
|
||||
|
||||
class TestTraceLLMDecorator:
|
||||
"""trace_llm decorator tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_works_without_otel(self):
|
||||
"""trace_llm decorator passes through when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_llm
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_llm("openai", "gpt-4")
|
||||
async def my_func():
|
||||
return MagicMock(usage=MagicMock(prompt_tokens=10, completion_tokens=20))
|
||||
|
||||
result = await my_func()
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_propagates_exception_without_otel(self):
|
||||
"""trace_llm propagates exceptions when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_llm
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_llm("openai", "gpt-4")
|
||||
async def my_func():
|
||||
raise ConnectionError("LLM error")
|
||||
|
||||
with pytest.raises(ConnectionError, match="LLM error"):
|
||||
await my_func()
|
||||
|
||||
|
||||
class TestTracePipelineStepDecorator:
|
||||
"""trace_pipeline_step decorator tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_works_without_otel(self):
|
||||
"""trace_pipeline_step decorator passes through when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_pipeline_step
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_pipeline_step("my_pipeline", "step_1")
|
||||
async def my_func():
|
||||
return "step_result"
|
||||
|
||||
result = await my_func()
|
||||
assert result == "step_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_propagates_exception_without_otel(self):
|
||||
"""trace_pipeline_step propagates exceptions when OTel not installed."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_pipeline_step
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
@trace_pipeline_step("my_pipeline", "step_1")
|
||||
async def my_func():
|
||||
raise RuntimeError("step failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="step failed"):
|
||||
await my_func()
|
||||
|
||||
|
||||
# ── OTel installed (mocked) tests ───────────────────────────────────
|
||||
|
||||
|
||||
class TestTracingWithMockedOTel:
|
||||
"""Test tracing with mocked OTel imports."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_agent_with_mocked_otel(self):
|
||||
"""trace_agent creates span with correct attributes when OTel is available."""
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_tracer = MagicMock()
|
||||
mock_tracer.start_as_current_span.return_value = mock_span_cm
|
||||
|
||||
with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \
|
||||
patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \
|
||||
patch("agentkit.telemetry.tracing.SpanKind"), \
|
||||
patch("agentkit.telemetry.tracing.Status"), \
|
||||
patch("agentkit.telemetry.tracing.StatusCode"):
|
||||
|
||||
from agentkit.telemetry.tracing import trace_agent
|
||||
|
||||
@trace_agent("test_agent", "react")
|
||||
async def my_func():
|
||||
return "result"
|
||||
|
||||
result = await my_func()
|
||||
assert result == "result"
|
||||
mock_tracer.start_as_current_span.assert_called_once()
|
||||
call_kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert call_kwargs[1]["attributes"]["agent.name"] == "test_agent"
|
||||
assert call_kwargs[1]["attributes"]["agent.type"] == "react"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_tool_with_mocked_otel(self):
|
||||
"""trace_tool creates span with tool.name attribute."""
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_tracer = MagicMock()
|
||||
mock_tracer.start_as_current_span.return_value = mock_span_cm
|
||||
|
||||
with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \
|
||||
patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \
|
||||
patch("agentkit.telemetry.tracing.SpanKind"), \
|
||||
patch("agentkit.telemetry.tracing.Status"), \
|
||||
patch("agentkit.telemetry.tracing.StatusCode"):
|
||||
|
||||
from agentkit.telemetry.tracing import trace_tool
|
||||
|
||||
@trace_tool("search_tool")
|
||||
async def my_func():
|
||||
return {"found": True}
|
||||
|
||||
result = await my_func()
|
||||
assert result == {"found": True}
|
||||
call_kwargs = mock_tracer.start_as_current_span.call_args
|
||||
assert call_kwargs[1]["attributes"]["tool.name"] == "search_tool"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_llm_with_mocked_otel(self):
|
||||
"""trace_llm creates span with gen_ai semantic conventions."""
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_tracer = MagicMock()
|
||||
mock_tracer.start_as_current_span.return_value = mock_span_cm
|
||||
|
||||
mock_usage = MagicMock()
|
||||
mock_usage.prompt_tokens = 50
|
||||
mock_usage.completion_tokens = 100
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage = mock_usage
|
||||
|
||||
with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \
|
||||
patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \
|
||||
patch("agentkit.telemetry.tracing.SpanKind"), \
|
||||
patch("agentkit.telemetry.tracing.Status"), \
|
||||
patch("agentkit.telemetry.tracing.StatusCode"):
|
||||
|
||||
from agentkit.telemetry.tracing import trace_llm
|
||||
|
||||
@trace_llm("openai", "gpt-4")
|
||||
async def my_func():
|
||||
return mock_response
|
||||
|
||||
result = await my_func()
|
||||
assert result is mock_response
|
||||
call_kwargs = mock_tracer.start_as_current_span.call_args
|
||||
attrs = call_kwargs[1]["attributes"]
|
||||
assert attrs["gen_ai.system"] == "openai"
|
||||
assert attrs["gen_ai.operation.name"] == "chat"
|
||||
assert attrs["gen_ai.request.model"] == "gpt-4"
|
||||
# Token usage should be recorded on span
|
||||
mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50)
|
||||
mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_pipeline_step_with_mocked_otel(self):
|
||||
"""trace_pipeline_step creates span with pipeline and step attributes."""
|
||||
mock_span = MagicMock()
|
||||
mock_span_cm = MagicMock()
|
||||
mock_span_cm.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_tracer = MagicMock()
|
||||
mock_tracer.start_as_current_span.return_value = mock_span_cm
|
||||
|
||||
with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \
|
||||
patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \
|
||||
patch("agentkit.telemetry.tracing.SpanKind"), \
|
||||
patch("agentkit.telemetry.tracing.Status"), \
|
||||
patch("agentkit.telemetry.tracing.StatusCode"):
|
||||
|
||||
from agentkit.telemetry.tracing import trace_pipeline_step
|
||||
|
||||
@trace_pipeline_step("geo_pipeline", "analyze")
|
||||
async def my_func():
|
||||
return "done"
|
||||
|
||||
result = await my_func()
|
||||
assert result == "done"
|
||||
call_kwargs = mock_tracer.start_as_current_span.call_args
|
||||
attrs = call_kwargs[1]["attributes"]
|
||||
assert attrs["pipeline.name"] == "geo_pipeline"
|
||||
assert attrs["step.name"] == "analyze"
|
||||
|
||||
|
||||
# ── setup_telemetry tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSetupTelemetry:
|
||||
"""setup_telemetry initialization tests."""
|
||||
|
||||
def test_no_config_is_noop(self):
|
||||
"""setup_telemetry with no config is a no-op."""
|
||||
from agentkit.telemetry.setup import setup_telemetry
|
||||
|
||||
mock_app = MagicMock()
|
||||
setup_telemetry(mock_app, None) # Should not raise
|
||||
# No auto-instrumentation should happen
|
||||
mock_app.state = MagicMock() # Just ensure no crash
|
||||
|
||||
def test_disabled_config_is_noop(self):
|
||||
"""setup_telemetry with enabled=False is a no-op."""
|
||||
from agentkit.telemetry.setup import setup_telemetry
|
||||
|
||||
mock_app = MagicMock()
|
||||
setup_telemetry(mock_app, {"enabled": False}) # Should not raise
|
||||
|
||||
def test_config_without_otel_logs_warning(self):
|
||||
"""setup_telemetry with config but OTel not installed logs warning."""
|
||||
from agentkit.telemetry.setup import setup_telemetry
|
||||
|
||||
mock_app = MagicMock()
|
||||
# This should not raise even if OTel is not installed
|
||||
# It will log a warning internally
|
||||
config = {"enabled": True, "service_name": "test"}
|
||||
# If OTel is installed, this will try to set up providers
|
||||
# If not, it will log a warning and return
|
||||
setup_telemetry(mock_app, config) # Should not raise
|
||||
|
||||
def test_empty_config_is_noop(self):
|
||||
"""setup_telemetry with empty dict is a no-op (enabled defaults to False)."""
|
||||
from agentkit.telemetry.setup import setup_telemetry
|
||||
|
||||
mock_app = MagicMock()
|
||||
setup_telemetry(mock_app, {}) # Should not raise
|
||||
|
||||
|
||||
# ── Integration: Tool safe_execute with telemetry ───────────────────
|
||||
|
||||
|
||||
class TestToolTelemetryIntegration:
|
||||
"""Test that Tool.safe_execute records telemetry."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_safe_execute_records_noop_telemetry(self):
|
||||
"""safe_execute works with no-op telemetry (OTel not installed)."""
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
class DummyTool(Tool):
|
||||
async def execute(self, **kwargs):
|
||||
return {"result": "ok"}
|
||||
|
||||
tool = DummyTool(name="test_tool", description="A test tool")
|
||||
result = await tool.safe_execute(query="hello")
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_safe_execute_error_records_telemetry(self):
|
||||
"""safe_execute records error telemetry on exception."""
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
class FailingTool(Tool):
|
||||
async def execute(self, **kwargs):
|
||||
raise ValueError("tool failed")
|
||||
|
||||
tool = FailingTool(name="failing_tool", description="A failing tool")
|
||||
with pytest.raises(ValueError, match="tool failed"):
|
||||
await tool.safe_execute(query="hello")
|
||||
|
||||
|
||||
# ── start_span helper tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStartSpan:
|
||||
"""Test start_span helper function."""
|
||||
|
||||
def test_start_span_noop_without_otel(self):
|
||||
"""start_span returns no-op span context manager without OTel."""
|
||||
from agentkit.telemetry.tracing import _OTEL_AVAILABLE, start_span, _NoOpSpan
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
pytest.skip("OTel is installed, skipping no-op test")
|
||||
|
||||
cm = start_span("test.span", attributes={"key": "value"})
|
||||
assert isinstance(cm, _NoOpSpan)
|
||||
# Should work as context manager
|
||||
with cm:
|
||||
pass # No error
|
||||
def test_start_as_current_span_returns_noop_span(self):
|
||||
tracer = NoOpTracer()
|
||||
span = tracer.start_as_current_span("test.span")
|
||||
assert isinstance(span, NoOpSpan)
|
||||
|
||||
def test_start_span_with_attributes(self):
|
||||
"""start_span accepts attributes parameter without error."""
|
||||
from agentkit.telemetry.tracing import start_span
|
||||
tracer = NoOpTracer()
|
||||
span = tracer.start_span("test.span", attributes={"key": "value"})
|
||||
assert isinstance(span, NoOpSpan)
|
||||
|
||||
cm = start_span("test.span", attributes={"key": "value", "count": 42})
|
||||
with cm:
|
||||
pass # No error regardless of OTel availability
|
||||
def test_start_span_as_context_manager(self):
|
||||
tracer = NoOpTracer()
|
||||
with tracer.start_span("test.span") as span:
|
||||
assert isinstance(span, NoOpSpan)
|
||||
|
||||
|
||||
# ── TelemetryConfig 测试 ───────────────────────────────────
|
||||
|
||||
|
||||
class TestTelemetryConfig:
|
||||
"""TelemetryConfig 默认值测试"""
|
||||
|
||||
def test_default_values(self):
|
||||
config = TelemetryConfig()
|
||||
assert config.enabled is False
|
||||
assert config.otlp_endpoint == ""
|
||||
assert config.service_name == "agentkit"
|
||||
assert config.sample_rate == 1.0
|
||||
|
||||
def test_custom_values(self):
|
||||
config = TelemetryConfig(
|
||||
enabled=True,
|
||||
otlp_endpoint="http://localhost:4317",
|
||||
service_name="my-service",
|
||||
sample_rate=0.5,
|
||||
)
|
||||
assert config.enabled is True
|
||||
assert config.otlp_endpoint == "http://localhost:4317"
|
||||
assert config.service_name == "my-service"
|
||||
assert config.sample_rate == 0.5
|
||||
|
||||
|
||||
# ── init_telemetry 测试 ────────────────────────────────────
|
||||
|
||||
|
||||
class TestInitTelemetry:
|
||||
"""init_telemetry 初始化测试"""
|
||||
|
||||
def test_disabled_returns_noop_tracer(self):
|
||||
config = TelemetryConfig(enabled=False)
|
||||
init_telemetry(config)
|
||||
tracer = get_tracer()
|
||||
assert isinstance(tracer, NoOpTracer)
|
||||
|
||||
def test_missing_opentelemetry_falls_back_to_noop(self):
|
||||
"""当 opentelemetry 包未安装时,优雅降级为 NoOpTracer"""
|
||||
config = TelemetryConfig(enabled=True, otlp_endpoint="http://localhost:4317")
|
||||
# 即使 opentelemetry 未安装,也不应抛出异常
|
||||
init_telemetry(config)
|
||||
tracer = get_tracer()
|
||||
# 在没有 opentelemetry 的环境中应为 NoOpTracer
|
||||
assert isinstance(tracer, (NoOpTracer, OTelTracer))
|
||||
|
||||
def test_init_with_exception_falls_back_to_noop(self):
|
||||
"""初始化过程中出现异常时,降级为 NoOpTracer"""
|
||||
config = TelemetryConfig(enabled=True, otlp_endpoint="http://localhost:4317")
|
||||
with patch(
|
||||
"agentkit.telemetry.tracer.init_telemetry",
|
||||
side_effect=Exception("init error"),
|
||||
) as mock_init:
|
||||
# init_telemetry 被模拟为抛异常,但实际调用的是原始函数
|
||||
# 这里验证的是 init_telemetry 本身不会崩溃
|
||||
pass
|
||||
# 直接调用,不应崩溃
|
||||
init_telemetry(config)
|
||||
tracer = get_tracer()
|
||||
assert isinstance(tracer, (NoOpTracer, OTelTracer))
|
||||
|
||||
|
||||
# ── get_tracer 测试 ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetTracer:
|
||||
"""get_tracer 全局实例测试"""
|
||||
|
||||
def test_returns_global_tracer(self):
|
||||
# 确保先重置为 NoOpTracer
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer = get_tracer()
|
||||
assert isinstance(tracer, NoOpTracer)
|
||||
|
||||
def test_get_tracer_returns_same_instance(self):
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer1 = get_tracer()
|
||||
tracer2 = get_tracer()
|
||||
assert tracer1 is tracer2
|
||||
|
||||
|
||||
# ── CostAwareRouter span 测试 ──────────────────────────────
|
||||
|
||||
|
||||
class TestCostAwareRouterSpan:
|
||||
"""CostAwareRouter 创建 span 并设置属性"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_creates_span_with_attributes(self):
|
||||
"""路由时创建 span 并设置 route.layer 和 route.target 属性"""
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer = get_tracer()
|
||||
|
||||
# 用 mock 替换 start_span 以验证调用
|
||||
mock_span = MagicMock()
|
||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span.__exit__ = MagicMock(return_value=False)
|
||||
mock_span.set_attribute = MagicMock()
|
||||
|
||||
with patch.object(tracer, "start_span", return_value=mock_span):
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=MagicMock(),
|
||||
intent_router=MagicMock(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
|
||||
tracer.start_span.assert_called_once_with("router.route")
|
||||
# 验证 span 设置了 input.length 属性
|
||||
mock_span.set_attribute.assert_any_call("input.length", len("你好"))
|
||||
# 验证 span 设置了 route.layer 和 route.target
|
||||
call_args_list = [
|
||||
(call.args[0], call.args[1])
|
||||
for call in mock_span.set_attribute.call_args_list
|
||||
]
|
||||
assert ("route.layer", "greeting") in call_args_list
|
||||
assert ("route.target", "default") in call_args_list
|
||||
|
||||
|
||||
# ── AlignmentGuard span 测试 ──────────────────────────────
|
||||
|
||||
|
||||
class TestAlignmentGuardSpan:
|
||||
"""AlignmentGuard 创建 span 并设置属性"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guard_creates_span_with_attributes(self):
|
||||
"""对齐检查时创建 span 并设置 guard.passed 和 guard.checked_by 属性"""
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer = get_tracer()
|
||||
|
||||
mock_span = MagicMock()
|
||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span.__exit__ = MagicMock(return_value=False)
|
||||
mock_span.set_attribute = MagicMock()
|
||||
|
||||
with patch.object(tracer, "start_span", return_value=mock_span):
|
||||
config = AlignmentConfig(constraints=["forbidden_word"])
|
||||
guard = AlignmentGuard(config)
|
||||
result = await guard.check_output(
|
||||
{"content": "This contains forbidden_word"}
|
||||
)
|
||||
|
||||
tracer.start_span.assert_called_once_with("guard.check")
|
||||
call_args_list = [
|
||||
(call.args[0], call.args[1])
|
||||
for call in mock_span.set_attribute.call_args_list
|
||||
]
|
||||
assert ("guard.constraints_count", 1) in call_args_list
|
||||
assert ("guard.passed", False) in call_args_list
|
||||
assert ("guard.checked_by", "rule") in call_args_list
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guard_span_on_pass(self):
|
||||
"""对齐检查通过时 span 设置 guard.passed=True"""
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer = get_tracer()
|
||||
|
||||
mock_span = MagicMock()
|
||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span.__exit__ = MagicMock(return_value=False)
|
||||
mock_span.set_attribute = MagicMock()
|
||||
|
||||
with patch.object(tracer, "start_span", return_value=mock_span):
|
||||
config = AlignmentConfig(constraints=["forbidden_word"])
|
||||
guard = AlignmentGuard(config)
|
||||
result = await guard.check_output(
|
||||
{"content": "This is clean text"}
|
||||
)
|
||||
|
||||
call_args_list = [
|
||||
(call.args[0], call.args[1])
|
||||
for call in mock_span.set_attribute.call_args_list
|
||||
]
|
||||
assert ("guard.passed", True) in call_args_list
|
||||
assert ("guard.checked_by", "rule") in call_args_list
|
||||
|
|
|
|||
Loading…
Reference in New Issue