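diff --git a/docs/plans/2026-06-10-019-feat-agentkit-deferred-improvements-plan.md b/docs/plans/2026-06-10-019-feat-agentkit-deferred-improvements-plan.md
new file mode 100644
index 0000000..ec985b2
--- /dev/null
+++ b/docs/plans/2026-06-10-019-feat-agentkit-deferred-improvements-plan.md
@@ -0,0 +1,244 @@
+---
+title: "feat: AgentKit weakness improvements"
+status: active
+created: 2026-06-10
+plan_type: feat
+---
+
+# feat: AgentKit weakness improvements
+
+## Summary
+
+Address the six deferred weaknesses identified in code review: multi-dimensional Soul evolution triggers, stronger Chinese tokenization, end-to-end observability, additional test coverage, a progressive ReWOO fallback chain, and an inter-agent communication protocol.
+
+## Problem Frame
+
+AgentKit's multi-agent architecture is largely in place, but six weaknesses still affect production readiness. None of them blocks core functionality, but each will surface as usage scales.
+
+## Scope Boundaries
+
+### In Scope
+- Multi-dimensional Soul evolution triggers (time decay + quality gradient + scenario weights)
+- Stronger Chinese tokenization (2-grams + stopwords, zero dependencies)
+- End-to-end observability (OpenTelemetry instrumentation)
+- Additional integration tests (multi-module interaction tests)
+- Progressive ReWOO fallback chain
+- Inter-agent communication protocol (message bus)
+
+### Out of Scope
+- jieba tokenizer integration (optional dependency, to be added later if needed)
+- Distributed message bus (Redis implementation, to be added later if needed)
+- Live tests (require a real API key; marked as skipped in CI)
+- UI trace visualization (Jaeger/Zipkin frontend; an ops configuration concern)
+
+---
+
+## Key Technical Decisions
+
+### KTD1: Chinese tokenization uses 2-grams plus stopwords (zero dependencies)
+
+jieba is more accurate but adds a dependency. The 2-gram approach already captures compound words such as "数据分析" (data analysis) and "代码生成" (code generation) well enough. jieba can be introduced later as an optional dependency if needed.
+
+### KTD2: Observability builds on the OpenTelemetry standard
+
+Each engine already accepts a `trace_recorder` parameter, so only an OTel exporter needs to be implemented. It is enabled through the `telemetry.otlp_endpoint` setting.
+
+### KTD3: Inter-agent communication supports three modes: request-response, broadcast, and negotiation
+
+Built on the `AgentMessage` dataclass and the `MessageBus` interface. `InMemoryMessageBus` is implemented first; `RedisMessageBus` follows later.
+
+### KTD4: The ReWOO fallback chain is simplified ReWOO → PlanExec → ReAct → Direct
+
+Degradation is progressive; each step retains more structured capability than the one after it.
+
+---
+
+## Implementation Units
+
+### U1. Multi-dimensional Soul evolution triggers
+
+**Goal:** Extend the Soul evolution trigger conditions to support time decay, quality gradient, and scenario weights
+
+**Dependencies:** None
+
+**Files:**
+- Modify: `src/agentkit/evolution/lifecycle.py`
+- Modify: `src/agentkit/evolution/config.py` (adds SoulEvolutionConfig)
+- Test: `tests/unit/test_soul_evolution.py`
+
+**Approach:**
+1. Add a `SoulEvolutionConfig` dataclass with `reflection_window_seconds`, `time_decay_factor`, `task_type_weights`, and `quality_gradient_threshold`
+2. Update `evolve_soul`: compute a time-decayed `effective_count = sum(factor^age_hours)` in place of `len(reflections)`
+3. Add quality-gradient detection: track the trend of the last N scores and trigger early when consecutive declines exceed the threshold
+4. Scenario weights: high-risk tasks lower the trigger threshold
+
+**Test scenarios:**
+- Time decay: 3 reflections within 1 hour trigger evolution; reflections outside the window are not counted
+- Quality gradient: 3 consecutive score drops of 0.15 trigger evolution even when the total reflection count is below 3
+- Scenario weights: code_generation tasks trigger after 2 reflections (weight 1.5); chat tasks need 4 (weight 0.5)
+- Compatibility: without a config, behavior matches the current implementation
+
+**Verification:** All new tests pass; existing Soul evolution tests are unaffected
+
+---
+
+### U2. Stronger Chinese tokenization
+
+**Goal:** Replace whitespace tokenization with 2-grams plus stopwords to improve capability-keyword extraction for Chinese
+
+**Dependencies:** None
+
+**Files:**
+- Modify: `src/agentkit/chat/skill_routing.py`
+- Test: `tests/unit/test_cost_aware_router.py`
+
+**Approach:**
+1. Extract `_tokenize_content()` into a standalone helper
+2. Split on punctuation → add 2-grams for long Chinese segments → filter stopwords
+3. Layer 2 routing calls the new helper
+
+**Test scenarios:**
+- Chinese: "帮我做数据分析" → extracts ["数据分析", "帮我"]
+- English: "help with code generation" → extracts ["code generation", "help"]
+- Mixed: "用python做data analysis" → extracts ["python", "data analysis"]
+- Stopword filtering: "的了一个" → empty after filtering
+
+**Verification:** Chinese tokenization tests pass; existing routing tests are unaffected
+
+---
+
+### U3. End-to-end observability (OpenTelemetry instrumentation)
+
+**Goal:** Add OTel trace/metrics instrumentation to core components
+
+**Dependencies:** None
+
+**Files:**
+- Create: `src/agentkit/telemetry/__init__.py`
+- Create: `src/agentkit/telemetry/tracer.py` (OTel exporter)
+- Modify: `src/agentkit/chat/skill_routing.py` (routing instrumentation)
+- Modify: `src/agentkit/core/react.py` (ReAct instrumentation)
+- Modify: `src/agentkit/core/rewoo.py` (ReWOO instrumentation)
+- Modify: `src/agentkit/core/reflexion.py` (Reflexion instrumentation)
+- Modify: `src/agentkit/quality/alignment.py` (Guard instrumentation)
+- Test: `tests/unit/test_telemetry.py`
+
+**Approach:**
+1. Define `TelemetryConfig` and a `get_tracer()` factory
+2. Implement an `OTelTraceRecorder` adapter that bridges the existing `TraceRecorder` interface
+3. Each component creates spans at key points: routing decisions, engine execution, constraint checks
+4. Enabled via `telemetry.otlp_endpoint`; a no-op when unconfigured
+
+**Test scenarios:**
+- Unconfigured: all instrumentation is a no-op and does not affect functionality
+- Endpoint configured: spans are created correctly, with attributes such as layer/complexity/status
+- Routing instrumentation: Layers 0/1/2 each produce a span with attributes
+- Engine instrumentation: after execution, the span records steps_count/total_tokens/status
+
+**Verification:** New tests pass; functionality works without the OTel dependency
+
+---
+
+### U4. Additional integration tests
+
+**Goal:** Add multi-module integration tests that cover the core paths
+
+**Dependencies:** U1, U2
+
+**Files:**
+- Create: `tests/integration/test_router_engine_chain.py`
+- Create: `tests/integration/test_rewoo_fallback.py`
+- Create: `tests/integration/test_reflexion_loop.py`
+- Create: `tests/integration/test_soul_evolution_trigger.py`
+
+**Approach:**
+1. Full route → execute → audit chain test (only the LLM is mocked)
+2. ReWOO planning failure → fallback test
+3. Reflexion multi-round loop test
+4. Soul evolution trigger-condition test
+
+**Test scenarios:**
+- Full chain: CostAwareRouter routes to ReActEngine → the AlignmentGuard check passes
+- Full chain with violation: the AlignmentGuard check fails and returns violations
+- ReWOO fallback: falls back to ReAct correctly when planning raises
+- Reflexion: returns the best result after 3 rounds
+- Soul trigger: update_soul is triggered after 3 low-quality reflections
+
+**Verification:** All integration tests pass
+
+---
+
+### U5. Progressive ReWOO fallback chain
+
+**Goal:** Degrade progressively when ReWOO planning fails instead of falling straight back to ReAct
+
+**Dependencies:** None
+
+**Files:**
+- Modify: `src/agentkit/core/rewoo.py`
+- Test: `tests/unit/test_rewoo_engine.py`
+
+**Approach:**
+1. Add a `FALLBACK_CHAIN` setting: simplified_rewoo → plan_exec → react → direct
+2. On planning failure, first try simplified planning (max_plan_steps=3)
+3. If simplified planning also fails, degrade along the chain
+4. Record fallback information in ReWOOResult
+
+**Test scenarios:**
+- Planning succeeds normally: no fallback
+- Planning fails → simplified planning succeeds: result is marked fallback=simplified
+- Planning fails → simplified planning fails → ReAct fallback: result is marked fallback=react
+- All fallbacks fail: returns an error result
+
+**Verification:** Fallback-chain tests pass
+
+---
+
+### U6. Inter-agent communication protocol
+
+**Goal:** Implement a lightweight message bus that supports request-response, broadcast, and negotiation between agents
+
+**Dependencies:** None
+
+**Files:**
+- Create: `src/agentkit/bus/__init__.py`
+- Create: `src/agentkit/bus/message.py` (AgentMessage dataclass)
+- Create: `src/agentkit/bus/interface.py` (MessageBus interface)
+- Create: `src/agentkit/bus/memory_bus.py` (InMemoryMessageBus implementation)
+- Modify: `src/agentkit/server/app.py` (initializes the MessageBus)
+- Test: `tests/unit/test_agent_bus.py`
+
+**Approach:**
+1. Define the `AgentMessage` dataclass: sender, recipient, msg_type, content, correlation_id, timestamp, ttl
+2. Define the abstract `MessageBus` interface: publish, subscribe, request
+3. Implement `InMemoryMessageBus` on top of asyncio.Queue
+4. Integrate with AlignmentGuard: messages pass through constraint checks
+5. Integrate with CascadeDetector: message passing counts toward cascade detection
+
+**Test scenarios:**
+- Request-response: Agent A sends a request, Agent B responds, and the correlation_id matches
+- Broadcast: Agent A broadcasts and every subscriber receives the message
+- Negotiation: Agent A starts a negotiation and Agent B replies with a proposal
+- TTL expiry: timed-out messages are dropped
+- Cascade detection: a message chain that exceeds the threshold triggers a CascadeAlert
+
+**Verification:** Message bus tests pass; Guard/Cascade integration is correct
+
+---
+
+## Risks & Mitigations
+
+| Risk | Impact | Mitigation |
+|------|--------|------------|
+| OTel dependency increases package size | Low | OTel is an optional dependency; no-op when unconfigured |
+| Message bus memory leak | Medium | TTL expiry cleanup + periodic GC |
+| ReWOO fallback chain adds latency | Low | Per-step timeout; total fallback time capped at 5s |
+| 2-gram tokenization noise | Low | Stopword filtering + length threshold |
+
+## Deferred to Follow-Up Work
+
+- jieba tokenizer integration (optional dependency)
+- Distributed RedisMessageBus implementation
+- Live tests (require a real API key)
+- Jaeger/Zipkin frontend configuration
+- Message bus persistence
diff --git a/src/agentkit/bus/__init__.py b/src/agentkit/bus/__init__.py
index a67dd29..f988579 100644
--- a/src/agentkit/bus/__init__.py
+++ b/src/agentkit/bus/__init__.py
@@ -1,13 +1,15 @@
 """AgentKit Bus - Agent 间通信基础设施"""
 
 from agentkit.bus.message import AgentMessage
-from agentkit.bus.protocol import MessageBus
+from agentkit.bus.interface import MessageBus
+from agentkit.bus.protocol import MessageBus as MessageBusProtocol
 from agentkit.bus.memory_bus import InMemoryMessageBus
 from agentkit.bus.redis_bus import RedisMessageBus, create_message_bus
 
 __all__ = [
     "AgentMessage",
     "MessageBus",
+    "MessageBusProtocol",
     "InMemoryMessageBus",
     "RedisMessageBus",
     "create_message_bus",
diff --git a/src/agentkit/bus/interface.py b/src/agentkit/bus/interface.py
new file mode 100644
index 0000000..604877e
--- /dev/null
+++ b/src/agentkit/bus/interface.py
@@ -0,0 +1,49 @@
+"""MessageBus ABC — Agent 间通信抽象基类。
+
+与 protocol.py 的 Protocol 定义并存,提供 ABC 版本用于显式继承。
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Callable, Awaitable
+
+from agentkit.bus.message import AgentMessage
+
+
+class MessageBus(ABC):
+    """Agent 间消息总线抽象基类。
+
+    支持三种通信模式:
+    - 点对点:publish() 指定 recipient
+    - 广播:publish() 不指定 recipient
+    - 请求-响应:request() 等待对方通过 correlation_id 回复
+    """
+
+    @abstractmethod
+    async def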
publish(self, message: AgentMessage) -> bool: + """发布消息,返回是否成功。""" + ... + + @abstractmethod + async def subscribe( + self, + agent_name: str, + handler: Callable[[AgentMessage], Awaitable[None]], + ) -> None: + """订阅消息。""" + ... + + @abstractmethod + async def unsubscribe(self, agent_name: str) -> None: + """取消订阅。""" + ... + + @abstractmethod + async def request( + self, + message: AgentMessage, + timeout_seconds: float = 30.0, + ) -> AgentMessage | None: + """发送请求并等待响应。超时返回 None。""" + ... diff --git a/src/agentkit/bus/memory_bus.py b/src/agentkit/bus/memory_bus.py index 5d3cbd2..ddf7a3b 100644 --- a/src/agentkit/bus/memory_bus.py +++ b/src/agentkit/bus/memory_bus.py @@ -1,6 +1,7 @@ """InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。 用于开发和测试,行为与 Redis 实现一致。 +集成 CascadeDetector 和 AlignmentGuard 进行消息质量管控。 """ from __future__ import annotations @@ -17,16 +18,48 @@ logger = logging.getLogger(__name__) class InMemoryMessageBus: """基于 asyncio.Queue 的内存消息总线。""" - def __init__(self) -> None: + def __init__( + self, + cascade_detector: Any = None, + alignment_guard: Any = None, + ) -> None: self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {} self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {} self._queues: dict[str, asyncio.Queue[AgentMessage]] = {} + self._consumer_tasks: dict[str, list[asyncio.Task]] = {} + self._cascade_detector = cascade_detector + self._alignment_guard = alignment_guard + + async def publish(self, message: AgentMessage) -> bool: + """发布消息,返回是否成功。""" + # TTL 过期检查 + if message.is_expired(): + logger.warning(f"Message {message.message_id} expired, dropping") + return False + + # Cascade detection — 级联故障检测 + if self._cascade_detector and message.sender: + alert = self._cascade_detector.check_interaction( + session_id=f"bus-{message.sender}-{message.recipient or 'broadcast'}" + ) + if alert: + logger.warning(f"Cascade alert: {alert}") + return False + + # Alignment check — 对齐守卫检查(仅对 request / negotiate 类型) + if 
self._alignment_guard and message.msg_type in ("request", "negotiate"): + check = await self._alignment_guard.check_output( + output={"content": str(message.content), **message.payload}, + ) + if not check.passed: + logger.warning( + f"Message blocked by alignment guard: {check.violations}" + ) + return False - async def publish(self, message: AgentMessage) -> None: - """发布消息。""" if message.is_broadcast: await self.broadcast(message) - return + return True # Point-to-point: deliver to recipient's queue recipient = message.recipient @@ -41,16 +74,9 @@ class InMemoryMessageBus: logger.warning(f"Handler error for {recipient}: {e}") # Check if this is a response to a pending request - # Only resolve if this is a reply (message_id != correlation_id), - # not the original request itself - if ( - message.correlation_id - and message.correlation_id in self._pending_requests - and message.message_id != message.correlation_id - ): - future = self._pending_requests[message.correlation_id] - if not future.done(): - future.set_result(message) + self._try_resolve_pending(message) + + return True async def subscribe( self, @@ -61,10 +87,12 @@ class InMemoryMessageBus: if agent_name not in self._subscribers: self._subscribers[agent_name] = [] self._queues[agent_name] = asyncio.Queue() + self._consumer_tasks[agent_name] = [] self._subscribers[agent_name].append(handler) - # Start consumer task - asyncio.create_task(self._consume_queue(agent_name, handler)) + # Start consumer task and track it + task = asyncio.create_task(self._consume_queue(agent_name, handler)) + self._consumer_tasks[agent_name].append(task) async def _consume_queue( self, @@ -82,34 +110,59 @@ class InMemoryMessageBus: await handler(message) except Exception as e: logger.warning(f"Handler error for {agent_name}: {e}") + + # Check pending requests after handler processes the message + # (e.g., handler may publish a response that resolves a future) + self._try_resolve_pending(message) except asyncio.CancelledError: 
break + def _try_resolve_pending(self, message: AgentMessage) -> None: + """Try to resolve a pending request future if this message is a response.""" + if ( + message.correlation_id + and message.correlation_id in self._pending_requests + and message.message_id != message.correlation_id + ): + future = self._pending_requests[message.correlation_id] + if not future.done(): + future.set_result(message) + async def unsubscribe(self, agent_name: str) -> None: """取消订阅。""" self._subscribers.pop(agent_name, None) self._queues.pop(agent_name, None) + # Cancel tracked consumer tasks + tasks = self._consumer_tasks.pop(agent_name, []) + for task in tasks: + if not task.done(): + task.cancel() async def request( self, message: AgentMessage, - timeout: float = 30.0, - ) -> AgentMessage: - """请求-响应模式。""" + timeout_seconds: float = 30.0, + ) -> AgentMessage | None: + """请求-响应模式。超时返回 None。""" + message.msg_type = "request" if not message.correlation_id: message.correlation_id = message.message_id - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() future: asyncio.Future[AgentMessage] = loop.create_future() self._pending_requests[message.correlation_id] = future + published = await self.publish(message) + if not published: + self._pending_requests.pop(message.correlation_id, None) + return None + try: - await self.publish(message) - return await asyncio.wait_for(future, timeout=timeout) + return await asyncio.wait_for(future, timeout=timeout_seconds) except asyncio.TimeoutError: - raise TimeoutError( - f"Request {message.correlation_id} timed out after {timeout}s" - ) + self._pending_requests.pop(message.correlation_id, None) + logger.warning(f"Request {message.correlation_id} timed out") + return None finally: self._pending_requests.pop(message.correlation_id, None) diff --git a/src/agentkit/bus/message.py b/src/agentkit/bus/message.py index 68b114c..dc770c1 100644 --- a/src/agentkit/bus/message.py +++ b/src/agentkit/bus/message.py @@ -21,34 +21,58 @@ class 
AgentMessage: recipient: str | None = None # None = broadcast topic: str = "" payload: dict[str, Any] = field(default_factory=dict) - timestamp: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat(), - ) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) correlation_id: str | None = None # 请求-响应关联 + # --- 新增字段 --- + content: Any = None # 消息内容(与 payload 互补,payload 为 dict,content 可为任意类型) + msg_type: str = "notify" # "request" | "response" | "notify" | "negotiate" + ttl_seconds: int = 300 # 消息存活时间(秒) + + def is_expired(self) -> bool: + """检查消息是否已过期。""" + if isinstance(self.timestamp, datetime): + age = (datetime.now(timezone.utc) - self.timestamp).total_seconds() + return age > self.ttl_seconds + return False + + @property + def is_broadcast(self) -> bool: + return self.recipient is None def to_dict(self) -> dict[str, Any]: + ts = self.timestamp + if isinstance(ts, datetime): + ts = ts.isoformat() return { "message_id": self.message_id, "sender": self.sender, "recipient": self.recipient, "topic": self.topic, "payload": self.payload, - "timestamp": self.timestamp, + "timestamp": ts, "correlation_id": self.correlation_id, + "content": self.content, + "msg_type": self.msg_type, + "ttl_seconds": self.ttl_seconds, } @classmethod def from_dict(cls, data: dict[str, Any]) -> AgentMessage: + ts = data.get("timestamp", "") + if isinstance(ts, str) and ts: + try: + ts = datetime.fromisoformat(ts) + except (ValueError, TypeError): + pass # keep as string if parse fails return cls( message_id=data.get("message_id", ""), sender=data.get("sender", ""), recipient=data.get("recipient"), topic=data.get("topic", ""), payload=data.get("payload", {}), - timestamp=data.get("timestamp", ""), + timestamp=ts, correlation_id=data.get("correlation_id"), + content=data.get("content"), + msg_type=data.get("msg_type", "notify"), + ttl_seconds=data.get("ttl_seconds", 300), ) - - @property - def is_broadcast(self) -> bool: - return 
self.recipient is None diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index 3b6659d..055bf78 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -12,6 +12,8 @@ import re from dataclasses import dataclass, field from typing import Any +from agentkit.telemetry.tracer import get_tracer + logger = logging.getLogger(__name__) # Strict validation: only lowercase alphanumeric, hyphens, underscores @@ -187,6 +189,31 @@ _CHAT_MODE_RE = re.compile( ) +def _tokenize_content(content: str) -> list[str]: + """Tokenize content for capability matching. Supports Chinese and English.""" + # 1. Split by punctuation and whitespace + segments = re.split(r'[\s,,。!?、;:\n]+', content) + + # 2. For long Chinese segments, add 2-gram supplements + tokens = [] + for seg in segments: + if len(seg) <= 4: + tokens.append(seg) + else: + tokens.append(seg) + # Add 2-grams for Chinese compound words + for i in range(len(seg) - 1): + bigram = seg[i:i+2] + if all('\u4e00' <= c <= '\u9fff' for c in bigram): + tokens.append(bigram) + + # 3. Filter stopwords + stopwords = {"的", "了", "是", "在", "和", "与", "也", "都", "就", "要", "会", "我", "你", "他", "这", "那", "有", "没", "不"} + tokens = [t for t in tokens if t not in stopwords and len(t) > 1][:10] + + return tokens + + class CostAwareRouter: """三层成本感知路由器。 @@ -245,7 +272,9 @@ class CostAwareRouter: 'You are a complexity classifier. 
Rate the complexity of the user request on a scale of 0.0 to 1.0.\n' '0.0 = trivial greeting, 0.3 = simple question, 0.5 = moderate task, ' '0.7 = complex multi-step task, 1.0 = very complex research task.\n\n' - f'User request: "{content}"\n\n' + '---BEGIN USER REQUEST---\n' + f'{content}\n' + '---END USER REQUEST---\n\n' 'Respond ONLY with a JSON object: {"complexity": }' ) try: @@ -281,10 +310,7 @@ class CostAwareRouter: try: # Extract capability-like keywords from content for matching # find_best_agent expects list[str] of required capabilities - # Support both space-separated (English) and punctuation-separated (Chinese) content - import re - tokens = re.split(r'[\s,,。!?、;:\n]+', content) - content_words = [t for t in tokens if len(t) > 1][:5] + content_words = _tokenize_content(content) best_agent = self._org_context.find_best_agent(required_capabilities=content_words) if best_agent is not None: agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent)) @@ -365,11 +391,130 @@ class CostAwareRouter: """ trace: list[dict] = [] - # ---- Layer 0: Rule-based (zero cost) ---- - match_type, clean_content = self._match_layer0(content) + tracer = get_tracer() + with tracer.start_span("router.route") as span: + span.set_attribute("input.length", len(content)) - if match_type == "explicit_skill": - result = await resolve_skill_routing( + # ---- Layer 0: Rule-based (zero cost) ---- + match_type, clean_content = self._match_layer0(content) + + if match_type == "explicit_skill": + result = await resolve_skill_routing( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + ) + result.match_method = result.match_method or "explicit_skill" + result.complexity = 0.0 + trace.append({ + 
"layer": 0, + "method": "explicit_skill", + "matched": result.matched, + "cost": "zero", + }) + result.execution_trace = trace if transparency != "SILENT" else [] + result.transparency_level = transparency + span.set_attribute("route.layer", result.match_method or "explicit_skill") + span.set_attribute("route.target", result.skill_name or "default") + return result + + if match_type in ("greeting", "chat_mode"): + result = SkillRoutingResult( + clean_content=clean_content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method=match_type, + match_confidence=1.0, + complexity=0.0, + ) + trace.append({ + "layer": 0, + "method": match_type, + "matched": False, + "cost": "zero", + }) + result.execution_trace = trace if transparency != "SILENT" else [] + result.transparency_level = transparency + span.set_attribute("route.layer", match_type) + span.set_attribute("route.target", "default") + return result + + # ---- Layer 1: LLM quick classify (~100 tokens) ---- + complexity = await self.quick_classify(clean_content) + trace.append({ + "layer": 1, + "method": "quick_classify", + "complexity": complexity, + }) + + # Low complexity → default agent + if complexity < 0.3: + result = SkillRoutingResult( + clean_content=clean_content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="low_complexity", + match_confidence=1.0 - complexity, + complexity=complexity, + ) + trace.append({ + "layer": 1, + "method": "low_complexity", + "complexity": complexity, + "routed_to": "default", + }) + result.execution_trace = trace if transparency != "SILENT" else [] + result.transparency_level = transparency + span.set_attribute("route.layer", "low_complexity") + span.set_attribute("route.target", "default") + return result + + # Medium complexity → IntentRouter via resolve_skill_routing + if 
complexity <= 0.7: + result = await resolve_skill_routing( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + ) + result.complexity = complexity + trace.append({ + "layer": 1, + "method": "intent_router", + "complexity": complexity, + "matched": result.matched, + }) + result.execution_trace = trace if transparency != "SILENT" else [] + result.transparency_level = transparency + span.set_attribute("route.layer", result.match_method or "intent_router") + span.set_attribute("route.target", result.skill_name or "default") + return result + + # ---- Layer 2: Capability matching / Auction (high complexity) ---- + trace.append({ + "layer": 2, + "method": "capability_or_auction", + "complexity": complexity, + "auction_enabled": self._auction_enabled, + }) + result = await self._route_layer2( content=content, skill_registry=skill_registry, intent_router=intent_router, @@ -379,116 +524,11 @@ class CostAwareRouter: default_agent_name=default_agent_name, agent_tool_registry=agent_tool_registry, session_id=session_id, - ) - result.match_method = result.match_method or "explicit_skill" - result.complexity = 0.0 - trace.append({ - "layer": 0, - "method": "explicit_skill", - "matched": result.matched, - "cost": "zero", - }) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - return result - - if match_type in ("greeting", "chat_mode"): - result = SkillRoutingResult( - clean_content=clean_content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method=match_type, - match_confidence=1.0, - complexity=0.0, - ) - trace.append({ - "layer": 0, - "method": match_type, - "matched": 
False, - "cost": "zero", - }) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - return result - - # ---- Layer 1: LLM quick classify (~100 tokens) ---- - complexity = await self.quick_classify(clean_content) - trace.append({ - "layer": 1, - "method": "quick_classify", - "complexity": complexity, - }) - - # Low complexity → default agent - if complexity < 0.3: - result = SkillRoutingResult( - clean_content=clean_content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="low_complexity", - match_confidence=1.0 - complexity, complexity=complexity, + trace=trace, ) - trace.append({ - "layer": 1, - "method": "low_complexity", - "complexity": complexity, - "routed_to": "default", - }) result.execution_trace = trace if transparency != "SILENT" else [] result.transparency_level = transparency + span.set_attribute("route.layer", result.match_method or "capability") + span.set_attribute("route.target", result.skill_name or result.agent_name or "default") return result - - # Medium complexity → IntentRouter via resolve_skill_routing - if complexity <= 0.7: - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - result.complexity = complexity - trace.append({ - "layer": 1, - "method": "intent_router", - "complexity": complexity, - "matched": result.matched, - }) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - return result - - # ---- Layer 2: Capability matching / Auction (high complexity) ---- - trace.append({ - "layer": 2, - "method": "capability_or_auction", - 
"complexity": complexity, - "auction_enabled": self._auction_enabled, - }) - result = await self._route_layer2( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - complexity=complexity, - trace=trace, - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - return result diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index f94e90b..8eba539 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -54,6 +54,7 @@ class ReActResult: total_steps: int total_tokens: int status: str = "success" # "success" | "timeout" | "cancelled" | "partial" + fallback_strategy: str | None = None # e.g. "simplified_rewoo", "react", "direct" @dataclass diff --git a/src/agentkit/core/rewoo.py b/src/agentkit/core/rewoo.py index b8795ce..39ec7e8 100644 --- a/src/agentkit/core/rewoo.py +++ b/src/agentkit/core/rewoo.py @@ -115,6 +115,8 @@ class ReWOOEngine: 3. 
Synthesis Phase: 综合所有工具结果生成最终输出 """ + FALLBACK_STRATEGIES = ["simplified_rewoo", "react", "direct"] + def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0): if max_plan_steps < 1: raise ValueError(f"max_plan_steps must be >= 1, got {max_plan_steps}") @@ -283,6 +285,8 @@ class ReWOOEngine: ) total_tokens += planning_tokens + fallback_strategy: str | None = None + # 记录规划步骤 if trace_recorder is not None: trace_recorder.record_step( @@ -292,26 +296,109 @@ class ReWOOEngine: tokens_used=planning_tokens, ) - # 如果规划失败,回退到 ReAct + # 如果规划失败,尝试渐进式回退 if plan is None: + # 尝试简化规划(max_steps=3) + logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3") + try: + plan, simplified_tokens = await self._plan_phase( + messages=messages, + tools=tools, + tool_schemas=tool_schemas, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=effective_system_prompt, + compressor=compressor, + cancellation_token=cancellation_token, + max_steps=3, + ) + total_tokens += simplified_tokens + if plan is not None and plan.steps: + fallback_strategy = "simplified_rewoo" + logger.info("Simplified ReWOO planning succeeded") + except Exception as e2: + logger.warning(f"Simplified ReWOO planning also failed: {e2}") + + if plan is None: + # 回退到 ReAct + fallback_strategy = "react" logger.warning("ReWOO planning failed, falling back to ReActEngine") if trace_recorder is not None: trace_recorder.end_trace(outcome="fallback") - return await self._react_engine.execute( - messages=messages, - tools=tools, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=system_prompt, - trace_recorder=trace_recorder, - memory_retriever=memory_retriever, - task_id=task_id, - compressor=compressor, - retrieval_config=retrieval_config, - cancellation_token=cancellation_token, - timeout_seconds=0, # timeout already handled by outer wrapper - ) + try: + react_result = await self._react_engine.execute( + 
messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, + compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + timeout_seconds=0, # timeout already handled by outer wrapper + ) + react_result.fallback_strategy = fallback_strategy + return react_result + except Exception as react_err: + # ReAct 也失败,回退到 Direct(简单 LLM 调用) + fallback_strategy = "direct" + logger.warning(f"ReAct fallback also failed: {react_err}, falling back to direct LLM call") + try: + direct_messages: list[dict[str, Any]] = [] + if effective_system_prompt: + direct_messages.append({"role": "system", "content": effective_system_prompt}) + direct_messages.extend(messages) + + if compressor: + try: + direct_messages = await compressor.compress(direct_messages) + except Exception as e: + logger.warning(f"Context compression failed in direct fallback: {e}") + + direct_response = await self._llm_gateway.chat( + messages=direct_messages, + model=model, + agent_name=agent_name, + task_type=task_type, + ) + total_tokens += direct_response.usage.total_tokens + + direct_step = ReWOOStep( + step=1, + action="final_answer", + content=direct_response.content, + tokens=direct_response.usage.total_tokens, + plan_step_id=None, + ) + trajectory.append(direct_step) + + if trace_recorder is not None: + trace_recorder.record_step( + step=1, + action="final_answer", + output_data={"content": direct_response.content}, + tokens_used=direct_response.usage.total_tokens, + ) + + trace_outcome = "success" + if trace_recorder is not None: + trace_recorder.end_trace(outcome=trace_outcome) + + return ReActResult( + output=direct_response.content or "", + trajectory=trajectory, + total_steps=len(trajectory), + total_tokens=total_tokens, + fallback_strategy=fallback_strategy, + ) + except Exception as direct_err: + 
logger.error(f"Direct LLM fallback also failed: {direct_err}") + raise # 如果计划为空(无需工具),直接让 LLM 回答 if not plan.steps: @@ -360,6 +447,7 @@ class ReWOOEngine: trajectory=trajectory, total_steps=len(trajectory), total_tokens=total_tokens, + fallback_strategy=fallback_strategy, ) # ── Phase 2: Execution ── @@ -462,6 +550,7 @@ class ReWOOEngine: trajectory=trajectory, total_steps=len(trajectory), total_tokens=total_tokens, + fallback_strategy=fallback_strategy, ) finally: # Telemetry: end span and record duration @@ -558,26 +647,84 @@ class ReWOOEngine: ) total_tokens += planning_tokens + if plan is None: + # Try simplified planning + logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3") + try: + plan, simplified_tokens = await self._plan_phase( + messages=messages, + tools=tools, + tool_schemas=tool_schemas, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=effective_system_prompt, + compressor=compressor, + cancellation_token=cancellation_token, + max_steps=3, + ) + total_tokens += simplified_tokens + except Exception as e2: + logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e2}") + if plan is None: # Planning failed, fall back to ReAct streaming logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine") - async for event in self._react_engine.execute_stream( - messages=messages, - tools=tools, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=system_prompt, - trace_recorder=trace_recorder, - memory_retriever=memory_retriever, - task_id=task_id, - compressor=compressor, - retrieval_config=retrieval_config, - cancellation_token=cancellation_token, - timeout_seconds=0, - ): - yield event - return + try: + async for event in self._react_engine.execute_stream( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + 
memory_retriever=memory_retriever, + task_id=task_id, + compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + timeout_seconds=0, + ): + yield event + return + except Exception as react_err: + # ReAct also failed, fall back to direct LLM call + logger.warning(f"ReAct fallback also failed in stream mode: {react_err}, falling back to direct LLM call") + try: + direct_messages: list[dict[str, Any]] = [] + if effective_system_prompt: + direct_messages.append({"role": "system", "content": effective_system_prompt}) + direct_messages.extend(messages) + + if compressor: + try: + direct_messages = await compressor.compress(direct_messages) + except Exception as e: + logger.warning(f"Context compression failed in direct fallback: {e}") + + direct_response = await self._llm_gateway.chat( + messages=direct_messages, + model=model, + agent_name=agent_name, + task_type=task_type, + ) + total_tokens += direct_response.usage.total_tokens + output = direct_response.content or "" + + yield ReActEvent( + event_type="final_answer", + step=1, + data={ + "output": output, + "total_steps": 1, + "total_tokens": total_tokens, + }, + ) + return + except Exception as direct_err: + logger.error(f"Direct LLM fallback also failed in stream mode: {direct_err}") + raise yield ReActEvent( event_type="plan_generated", @@ -758,9 +905,13 @@ class ReWOOEngine: system_prompt: str | None, compressor: "CompressionStrategy | None", cancellation_token: CancellationToken | None, + max_steps: int | None = None, ) -> tuple[ReWOOPlan | None, int]: """Planning Phase: 调用 LLM 生成完整执行计划 + Args: + max_steps: 限制计划最大步数,None 则使用 self._max_plan_steps + Returns: (plan, tokens_used) - plan 为 None 表示规划失败 """ @@ -817,8 +968,9 @@ class ReWOOEngine: return None, tokens_used # 限制计划步数 - if len(plan.steps) > self._max_plan_steps: - plan.steps = plan.steps[:self._max_plan_steps] + effective_max_steps = max_steps if max_steps is not None else self._max_plan_steps + if len(plan.steps) 
> effective_max_steps: + plan.steps = plan.steps[:effective_max_steps] return plan, tokens_used diff --git a/src/agentkit/evolution/lifecycle.py b/src/agentkit/evolution/lifecycle.py index e165068..817f949 100644 --- a/src/agentkit/evolution/lifecycle.py +++ b/src/agentkit/evolution/lifecycle.py @@ -23,6 +23,17 @@ from agentkit.memory.profile import MemoryStore logger = logging.getLogger(__name__) +@dataclass +class SoulEvolutionConfig: + """Soul 进化多维触发配置""" + + min_reflections: int = 3 + reflection_window_seconds: int = 3600 + time_decay_factor: float = 0.5 + task_type_weights: dict[str, float] = field(default_factory=dict) + quality_gradient_threshold: float = -0.15 + + @dataclass class EvolutionLogEntry: """进化日志条目""" @@ -59,6 +70,7 @@ class EvolutionMixin: llm_gateway: Any | None = None, auxiliary_model: str | None = None, strategy_tuning_enabled: bool = False, + evolution_config: SoulEvolutionConfig | None = None, ): if reflector is not EvolutionMixin._UNSET: # 显式传入了 reflector 参数(包括 None) @@ -78,6 +90,7 @@ class EvolutionMixin: self._evolution_log: list[EvolutionLogEntry] = [] self._current_module: Module | None = None self._strategy_tuning_enabled = strategy_tuning_enabled + self._evolution_config = evolution_config self.pending_soul_updates: dict[str, list] = {} @staticmethod @@ -373,19 +386,41 @@ class EvolutionMixin: logger.error(f"Failed to rollback evolution change: {e}") return False + def record_reflection( + self, + pattern: str, + reflection: Reflection, + task_type: str = "", + score: float | None = None, + ) -> None: + """记录反思到待处理列表,附带时间戳、分数和任务类型。""" + if pattern not in self.pending_soul_updates: + self.pending_soul_updates[pattern] = [] + self.pending_soul_updates[pattern].append( + { + "reflection": reflection, + "timestamp": datetime.now(timezone.utc), + "score": score if score is not None else reflection.quality_score, + "task_type": task_type, + } + ) + async def evolve_soul( self, task: TaskMessage, result: TaskResult, memory_store: 
MemoryStore | None = None, reflection: Reflection | None = None, + task_type: str = "", + score: float | None = None, ) -> bool: """Check if soul should be updated based on accumulated reflections. - Conditions for soul update: - - Same category reflection appears >= 3 times - - Reflection quality_score < 0.5 (indicating consistent issues) - - Reflection has actionable suggestions + Multi-dimensional triggers: + - Time decay: older reflections contribute less + - Quality gradient: declining scores trigger early + - Task type weight: different task types have different trigger thresholds + - Trigger threshold: effective_count * weight >= min_reflections """ if memory_store is None: return False @@ -402,22 +437,53 @@ class EvolutionMixin: if not reflection.suggestions: return False - # Accumulate reflections by pattern - for pattern in reflection.patterns: - if pattern not in self.pending_soul_updates: - self.pending_soul_updates[pattern] = [] - self.pending_soul_updates[pattern].append(reflection) + config = self._evolution_config or SoulEvolutionConfig() - # Check whether any category has accumulated >= 3 reflections + # Accumulate reflections by pattern (use the default category when patterns is empty) + categories = reflection.patterns if reflection.patterns else ["default"] + for pattern in categories: + self.record_reflection( + pattern, reflection, task_type=task_type, score=score + ) + + # Check whether any category meets the trigger conditions for category, reflections in list(self.pending_soul_updates.items()): - if len(reflections) >= 3: + # --- Quality gradient: 3+ declining scores trigger early --- + scores = [r["score"] for r in reflections if r["score"] is not None] + quality_gradient_triggered = False + if len(scores) >= 3: + last_3 = scores[-3:] + declines = [ + last_3[i] - last_3[i - 1] for i in range(1, len(last_3)) + ] + if all(d <= config.quality_gradient_threshold for d in declines): + quality_gradient_triggered = True + + # --- Time decay: compute effective_count --- + now = datetime.now(timezone.utc) + effective_count = 0.0 + for r in reflections: + age_seconds = (now - 
r["timestamp"]).total_seconds() + age_hours = age_seconds / 3600.0 + effective_count += config.time_decay_factor ** age_hours + # Round to avoid floating-point precision issues + # (e.g. 3 recent reflections should yield exactly 3.0) + effective_count = round(effective_count, 6) + + # --- Task type weight --- + weight = 1.0 + if task_type and task_type in config.task_type_weights: + weight = config.task_type_weights[task_type] + + # --- Trigger threshold: effective_count * weight >= min_reflections --- + weighted_count = effective_count * weight + if weighted_count >= config.min_reflections or quality_gradient_triggered: # Trigger the soul update from agentkit.tools.memory_tool import MemoryTool tool = MemoryTool(memory_store) - # Use the first suggestion as the update content section = category - content = "; ".join(reflections[-1].suggestions[:2]) + content = "; ".join(reflections[-1]["reflection"].suggestions[:2]) reason = f"连续{len(reflections)}次低质量反思 (category: {category})" update_result = await tool.execute( diff --git a/src/agentkit/quality/alignment.py b/src/agentkit/quality/alignment.py index fb5f1cf..9316e10 100644 --- a/src/agentkit/quality/alignment.py +++ b/src/agentkit/quality/alignment.py @@ -6,6 +6,8 @@ import logging from dataclasses import dataclass, field from typing import Any +from agentkit.telemetry.tracer import get_tracer + logger = logging.getLogger(__name__) @@ -92,24 +94,37 @@ class AlignmentGuard: if not effective_constraints: return AlignmentCheckResult(passed=True, checked_by="rule") - # 1. Rule-based check: keyword/substring matching - violations = self._rule_check(output, effective_constraints) - if violations: - return AlignmentCheckResult( - passed=False, - violations=violations, - checked_by="rule", - ) + tracer = get_tracer() + with tracer.start_span("guard.check") as span: + span.set_attribute("guard.constraints_count", len(effective_constraints)) - # 2. 
LLM semantic check (only when audit_enabled=True and an llm_gateway is available; run at the sampling rate) - if self._config.audit_enabled and self._llm_gateway is not None: - import random - if random.random() < self._config.audit_sample_rate: - return await self._llm_check(output, effective_constraints) - # Sample not selected; trust the rule-check result - logger.debug("LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate) + # 1. Rule-based check: keyword/substring matching + violations = self._rule_check(output, effective_constraints) + if violations: + result = AlignmentCheckResult( + passed=False, + violations=violations, + checked_by="rule", + ) + span.set_attribute("guard.passed", result.passed) + span.set_attribute("guard.checked_by", result.checked_by) + return result - return AlignmentCheckResult(passed=True, checked_by="rule") + # 2. LLM semantic check (only when audit_enabled=True and an llm_gateway is available; run at the sampling rate) + if self._config.audit_enabled and self._llm_gateway is not None: + import random + if random.random() < self._config.audit_sample_rate: + result = await self._llm_check(output, effective_constraints) + span.set_attribute("guard.passed", result.passed) + span.set_attribute("guard.checked_by", result.checked_by) + return result + # Sample not selected; trust the rule-check result + logger.debug("LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate) + + result = AlignmentCheckResult(passed=True, checked_by="rule") + span.set_attribute("guard.passed", result.passed) + span.set_attribute("guard.checked_by", result.checked_by) + return result def _rule_check( self, output: dict[str, Any], constraints: list[str] diff --git a/src/agentkit/telemetry/__init__.py b/src/agentkit/telemetry/__init__.py index 4f3984b..ad8ee18 100644 --- a/src/agentkit/telemetry/__init__.py +++ b/src/agentkit/telemetry/__init__.py @@ -20,6 +20,7 @@ from agentkit.telemetry.metrics import ( pipeline_step_histogram, ) from agentkit.telemetry.setup import setup_telemetry +from agentkit.telemetry.tracer import TelemetryConfig, NoOpTracer __all__ = [ "get_tracer", @@ -35,4 +36,6 @@ __all__ = [ "pipeline_step_histogram", 
"setup_telemetry", "_OTEL_AVAILABLE", + "TelemetryConfig", + "NoOpTracer", ] diff --git a/src/agentkit/telemetry/tracer.py b/src/agentkit/telemetry/tracer.py new file mode 100644 index 0000000..666ccd0 --- /dev/null +++ b/src/agentkit/telemetry/tracer.py @@ -0,0 +1,134 @@ +"""OpenTelemetry tracer integration with no-op fallback.""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any, ContextManager + +logger = logging.getLogger(__name__) + + +@dataclass +class TelemetryConfig: + """Telemetry configuration.""" + + enabled: bool = False + otlp_endpoint: str = "" + service_name: str = "agentkit" + sample_rate: float = 1.0 + + +class NoOpSpan: + """No-op span when telemetry is disabled.""" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def set_attribute(self, key: str, value: Any) -> None: + pass + + def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: + pass + + def record_exception(self, exception: Exception) -> None: + pass + + def is_recording(self) -> bool: + return False + + +class NoOpTracer: + """No-op tracer when telemetry is disabled.""" + + def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]: + return NoOpSpan() + + def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]: + return NoOpSpan() + + +class OTelSpan: + """Wrapper around OpenTelemetry Span.""" + + def __init__(self, span: Any): + self._span = span + + def __enter__(self): + self._span.__enter__() + return self + + def __exit__(self, *args): + self._span.__exit__(*args) + + def set_attribute(self, key: str, value: Any) -> None: + self._span.set_attribute(key, value) + + def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: + self._span.add_event(name, attributes or {}) + + def record_exception(self, exception: Exception) -> None: + 
self._span.record_exception(exception) + + def is_recording(self) -> bool: + return self._span.is_recording() + + +class OTelTracer: + """Wrapper around OpenTelemetry Tracer.""" + + def __init__(self, tracer: Any): + self._tracer = tracer + + def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan: + span = self._tracer.start_span(name, attributes=attributes) + return OTelSpan(span) + + def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan: + span = self._tracer.start_as_current_span(name, attributes=attributes) + return OTelSpan(span) + + +# Global tracer instance +_tracer: NoOpTracer | OTelTracer = NoOpTracer() + + +def get_tracer() -> NoOpTracer | OTelTracer: + """Get the global tracer instance.""" + return _tracer + + +def init_telemetry(config: TelemetryConfig) -> None: + """Initialize telemetry with the given configuration.""" + global _tracer + + if not config.enabled: + _tracer = NoOpTracer() + logger.info("Telemetry disabled, using no-op tracer") + return + + try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.sdk.resources import Resource + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + resource = Resource.create({"service.name": config.service_name}) + provider = TracerProvider(resource=resource) + + if config.otlp_endpoint: + exporter = OTLPSpanExporter(endpoint=config.otlp_endpoint) + provider.add_span_processor(BatchSpanProcessor(exporter)) + + trace.set_tracer_provider(provider) + _tracer = OTelTracer(trace.get_tracer(config.service_name)) + logger.info(f"Telemetry initialized with endpoint: {config.otlp_endpoint}") + except ImportError: + logger.warning("opentelemetry packages not installed, using no-op tracer") + _tracer = NoOpTracer() + except Exception as e: + logger.warning(f"Failed to initialize telemetry: {e}, 
using no-op tracer") + _tracer = NoOpTracer() diff --git a/tests/integration/test_reflexion_loop.py b/tests/integration/test_reflexion_loop.py new file mode 100644 index 0000000..3cb2b94 --- /dev/null +++ b/tests/integration/test_reflexion_loop.py @@ -0,0 +1,292 @@ +"""集成测试 - Reflexion 多轮循环 + +测试 ReflexionEngine 的 Evaluate→Reflect→Retry 循环。 +仅 mock LLMGateway(外部 API),使用真实 ReflexionEngine 实例。 +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.core.react import ReActEngine, ReActStep +from agentkit.core.reflexion import ReflexionEngine, ReflexionReflection, ReflexionResult +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse, TokenUsage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_response( + content: str = "", + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> LLMResponse: + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + ) + + +def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=responses) + return gateway + + +# --------------------------------------------------------------------------- +# Test 1: First attempt scores high → no retry, returns result +# --------------------------------------------------------------------------- + + +class TestReflexionFirstAttemptPasses: + """首次尝试分数高于阈值,无需重试""" + + @pytest.mark.asyncio + async def test_high_score_no_retry(self): + gateway = make_mock_gateway([ + # ReAct call + make_response(content="The answer is 42"), + # Evaluation call - high score + make_response(content='```json\n{"score": 0.9, "reasoning": "Excellent"}\n```'), + ]) + engine = 
ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7) + + result = await engine.execute( + messages=[{"role": "user", "content": "What is the answer?"}], + ) + + assert isinstance(result, ReflexionResult) + assert result.output == "The answer is 42" + assert result.evaluation_score == 0.9 + assert result.reflection_count == 0 + assert len(result.reflections) == 0 + assert result.status == "success" + + @pytest.mark.asyncio + async def test_score_exactly_at_threshold(self): + """分数恰好等于阈值,无需重试""" + gateway = make_mock_gateway([ + make_response(content="Answer"), + make_response(content='```json\n{"score": 0.7, "reasoning": "OK"}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + ) + + assert result.evaluation_score == 0.7 + assert result.reflection_count == 0 + + +# --------------------------------------------------------------------------- +# Test 2: First attempt scores low, second scores high → returns best result +# --------------------------------------------------------------------------- + + +class TestReflexionRetryImprovesScore: + """首次低分,反思后重试高分""" + + @pytest.mark.asyncio + async def test_reflection_and_retry_on_low_score(self): + gateway = make_mock_gateway([ + # 1st ReAct call + make_response(content="Initial poor answer"), + # 1st Evaluation - low score + make_response(content='```json\n{"score": 0.3, "reasoning": "Incomplete"}\n```'), + # 1st Reflection call + make_response(content="You need to be more specific and provide detailed analysis."), + # 2nd ReAct call + make_response(content="Improved detailed answer"), + # 2nd Evaluation - high score + make_response(content='```json\n{"score": 0.85, "reasoning": "Good improvement"}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7) + + result = await engine.execute( + messages=[{"role": "user", "content": "Analyze this"}], + ) + + assert 
result.output == "Improved detailed answer" + assert result.evaluation_score == 0.85 + assert result.reflection_count == 1 + assert len(result.reflections) == 1 + assert result.reflections[0].score_before == 0.3 + assert result.reflections[0].retry_number == 1 + assert "specific" in result.reflections[0].reflection_text.lower() or "detailed" in result.reflections[0].reflection_text.lower() + + @pytest.mark.asyncio + async def test_multiple_retries_improve_score(self): + """多次重试后分数逐步提升""" + gateway = make_mock_gateway([ + # Attempt 1 + make_response(content="Bad answer"), + make_response(content='```json\n{"score": 0.2}\n```'), + make_response(content="Need more depth"), + # Attempt 2 + make_response(content="Better answer"), + make_response(content='```json\n{"score": 0.5}\n```'), + make_response(content="Still needs improvement"), + # Attempt 3 + make_response(content="Great answer"), + make_response(content='```json\n{"score": 0.9}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3) + + result = await engine.execute( + messages=[{"role": "user", "content": "Complex task"}], + ) + + assert result.output == "Great answer" + assert result.evaluation_score == 0.9 + assert result.reflection_count == 2 + assert len(result.reflections) == 2 + assert result.reflections[0].retry_number == 1 + assert result.reflections[1].retry_number == 2 + + +# --------------------------------------------------------------------------- +# Test 3: Max reflections reached → returns best result found +# --------------------------------------------------------------------------- + + +class TestReflexionMaxReflectionsReached: + """达到最大反思次数后返回最佳结果""" + + @pytest.mark.asyncio + async def test_returns_best_result_when_max_reflections_reached(self): + gateway = make_mock_gateway([ + # Attempt 1 + make_response(content="Poor answer"), + make_response(content='```json\n{"score": 0.3}\n```'), + make_response(content="Try harder"), + # Attempt 2 + 
make_response(content="Slightly better answer"), + make_response(content='```json\n{"score": 0.5}\n```'), + make_response(content="Still not good enough"), + # Attempt 3 (max) + make_response(content="Another answer"), + make_response(content='```json\n{"score": 0.6}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3) + + result = await engine.execute( + messages=[{"role": "user", "content": "Hard task"}], + ) + + # Should return the best result (score 0.6 from last attempt) + assert result.evaluation_score == 0.6 + assert result.reflection_count == 2 + assert result.output == "Another answer" + + @pytest.mark.asyncio + async def test_returns_earlier_best_when_later_worse(self): + """后续尝试分数更低时,返回之前最佳结果""" + gateway = make_mock_gateway([ + # Attempt 1: score 0.5 + make_response(content="Decent answer"), + make_response(content='```json\n{"score": 0.5}\n```'), + make_response(content="Try to improve"), + # Attempt 2: score 0.4 (worse) + make_response(content="Worse answer"), + make_response(content='```json\n{"score": 0.4}\n```'), + make_response(content="Still trying"), + # Attempt 3: score 0.45 (still worse than attempt 1) + make_response(content="Another answer"), + make_response(content='```json\n{"score": 0.45}\n```'), + # Reflection for attempt 3 (consumed but loop ends) + make_response(content="Final reflection"), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7, max_reflections=3) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + ) + + # Best score was 0.5 from attempt 1 + assert result.evaluation_score == 0.5 + assert result.output == "Decent answer" + + +# --------------------------------------------------------------------------- +# Test 4: Reflection text improves subsequent attempts +# --------------------------------------------------------------------------- + + +class TestReflexionReflectionImprovesAttempts: + """反思文本改善后续尝试""" + + 
@pytest.mark.asyncio + async def test_reflection_injected_into_system_prompt(self): + """反思文本被注入到下一次 ReAct 的 system prompt 中""" + gateway = make_mock_gateway([ + make_response(content="Poor answer"), + make_response(content='```json\n{"score": 0.3}\n```'), + make_response(content="You need to provide more specific details."), + make_response(content="Better answer with details"), + make_response(content='```json\n{"score": 0.9}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway, quality_threshold=0.7) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + system_prompt="You are a helpful assistant", + ) + + assert result.reflection_count == 1 + assert result.evaluation_score == 0.9 + assert result.output == "Better answer with details" + + # Verify the reflection was recorded with correct metadata + assert result.reflections[0].reflection_text == "You need to provide more specific details." + assert result.reflections[0].score_before == 0.3 + + @pytest.mark.asyncio + async def test_reflexion_composes_react_engine(self): + """ReflexionEngine 组合(而非继承)ReActEngine""" + gateway = make_mock_gateway([ + make_response(content="Answer"), + make_response(content='```json\n{"score": 0.9}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway) + + assert hasattr(engine, "_react_engine") + assert isinstance(engine._react_engine, ReActEngine) + assert not isinstance(engine, ReActEngine) + + @pytest.mark.asyncio + async def test_reflexion_result_has_all_fields(self): + """ReflexionResult 包含所有必要字段""" + gateway = make_mock_gateway([ + make_response(content="Answer"), + make_response(content='```json\n{"score": 0.85}\n```'), + ]) + engine = ReflexionEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + ) + + # ReActResult fields + assert hasattr(result, "output") + assert hasattr(result, "trajectory") + assert hasattr(result, "total_steps") + assert hasattr(result, 
"total_tokens") + assert hasattr(result, "status") + + # ReflexionResult additional fields + assert hasattr(result, "evaluation_score") + assert hasattr(result, "reflection_count") + assert hasattr(result, "reflections") + + # All trajectory steps are ReActStep + assert all(isinstance(step, ReActStep) for step in result.trajectory) diff --git a/tests/integration/test_rewoo_fallback.py b/tests/integration/test_rewoo_fallback.py new file mode 100644 index 0000000..ede2d7f --- /dev/null +++ b/tests/integration/test_rewoo_fallback.py @@ -0,0 +1,293 @@ +"""集成测试 - ReWOO 渐进式回退链 + +测试 ReWOOEngine 的 planning → simplified_rewoo → react → direct 回退策略。 +仅 mock LLMGateway(外部 API),使用真实 ReWOOEngine 实例。 +""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.core.rewoo import ReWOOEngine, ReWOOStep +from agentkit.core.react import ReActResult, ReActStep +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall +from agentkit.tools.base import Tool + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class FakeTool(Tool): + """用于测试的 Fake Tool""" + + def __init__( + self, + name: str = "fake_tool", + description: str = "A fake tool for testing", + result: dict | None = None, + should_fail: bool = False, + ): + super().__init__( + name=name, + description=description, + ) + self._result = result or {"status": "ok"} + self._should_fail = should_fail + self.call_count = 0 + + async def execute(self, **kwargs) -> dict: + self.call_count += 1 + if self._should_fail: + raise RuntimeError(f"Tool '{self.name}' execution failed") + return self._result + + +def make_response( + content: str = "", + tool_calls: list[ToolCall] | None = None, + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> LLMResponse: + return 
LLMResponse( + content=content, + model="test-model", + usage=TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + tool_calls=tool_calls or [], + ) + + +def make_plan_response( + steps: list[dict], + reasoning: str = "Plan reasoning", +) -> LLMResponse: + plan_json = json.dumps({ + "reasoning": reasoning, + "steps": steps, + }) + return make_response(content=plan_json) + + +def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=responses) + return gateway + + +# --------------------------------------------------------------------------- +# Test 1: Planning succeeds → no fallback +# --------------------------------------------------------------------------- + + +class TestReWOOPlanningSucceeds: + """规划成功,不使用回退""" + + @pytest.mark.asyncio + async def test_planning_succeeds_no_fallback(self): + tool = FakeTool(name="calculator", result={"value": 42}) + + plan_response = make_plan_response([ + {"step_id": 1, "tool_name": "calculator", "arguments": {"expr": "6*7"}, "reasoning": "Calculate"}, + ]) + synthesis_response = make_response(content="The result is 42") + + gateway = make_mock_gateway([plan_response, synthesis_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Calculate 6*7"}], + tools=[tool], + ) + + assert isinstance(result, ReActResult) + assert result.output == "The result is 42" + assert result.fallback_strategy is None + assert result.status == "success" + # 1 tool_call + 1 final_answer = 2 steps + assert result.total_steps == 2 + assert tool.call_count == 1 + + # Verify ReWOOStep has plan_step_id + tool_steps = [s for s in result.trajectory if s.action == "tool_call"] + assert len(tool_steps) == 1 + assert isinstance(tool_steps[0], ReWOOStep) + assert tool_steps[0].plan_step_id == 1 + + +# 
--------------------------------------------------------------------------- +# Test 2: Planning fails → simplified planning succeeds → fallback_strategy="simplified_rewoo" +# --------------------------------------------------------------------------- + + +class TestReWOOSimplifiedFallback: + """规划失败,简化规划成功""" + + @pytest.mark.asyncio + async def test_planning_fails_simplified_succeeds(self): + tool = FakeTool(name="search", result={"results": ["found"]}) + + # First plan fails (invalid JSON), second (simplified) succeeds + invalid_plan_response = make_response(content="I cannot create a plan for this task.") + simplified_plan_response = make_plan_response([ + {"step_id": 1, "tool_name": "search", "arguments": {"query": "test"}, "reasoning": "Simplified search"}, + ]) + synthesis_response = make_response(content="Simplified result") + + gateway = make_mock_gateway([invalid_plan_response, simplified_plan_response, synthesis_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Complex task"}], + tools=[tool], + ) + + assert result.output == "Simplified result" + assert result.fallback_strategy == "simplified_rewoo" + assert tool.call_count == 1 + + # Verify trajectory still has proper structure + tool_steps = [s for s in result.trajectory if s.action == "tool_call"] + assert len(tool_steps) == 1 + assert tool_steps[0].tool_name == "search" + + +# --------------------------------------------------------------------------- +# Test 3: All planning fails → ReAct fallback → fallback_strategy="react" +# --------------------------------------------------------------------------- + + +class TestReWOOReActFallback: + """所有规划失败,回退到 ReAct""" + + @pytest.mark.asyncio + async def test_planning_and_simplified_fail_react_succeeds(self): + # Both plan attempts fail (invalid JSON), ReAct succeeds + invalid_plan1 = make_response(content="Not a plan") + invalid_plan2 = make_response(content="Still not a plan") 
+ react_response = make_response(content="ReAct fallback answer") + + gateway = make_mock_gateway([invalid_plan1, invalid_plan2, react_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Complex task"}], + ) + + assert result.output == "ReAct fallback answer" + assert result.fallback_strategy == "react" + + @pytest.mark.asyncio + async def test_react_fallback_with_tool_calls(self): + """ReAct 回退时带工具调用""" + tool = FakeTool(name="search", result={"results": ["found"]}) + + invalid_plan1 = make_response(content="Cannot plan") + invalid_plan2 = make_response(content="Still cannot plan") + react_tool_response = make_response( + content="", + tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})], + ) + react_final_response = make_response(content="ReAct answer with tool") + + gateway = make_mock_gateway([ + invalid_plan1, + invalid_plan2, + react_tool_response, + react_final_response, + ]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Search task"}], + tools=[tool], + ) + + assert result.output == "ReAct answer with tool" + assert result.fallback_strategy == "react" + assert tool.call_count == 1 + + @pytest.mark.asyncio + async def test_malformed_json_triggers_react_fallback(self): + """格式错误的 JSON 触发 ReAct 回退""" + malformed_response = make_response(content='{"reasoning": "plan", "steps": [invalid json') + simplified_fail_response = make_response(content="Also not a plan") + react_response = make_response(content="ReAct answer") + + gateway = make_mock_gateway([malformed_response, simplified_fail_response, react_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + ) + + assert result.output == "ReAct answer" + assert result.fallback_strategy == "react" + + @pytest.mark.asyncio + async def 
test_missing_steps_key_triggers_react_fallback(self): + """缺少 steps 键触发 ReAct 回退""" + no_steps_response = make_response(content='{"reasoning": "no steps here"}') + simplified_fail_response = make_response(content="Also no steps") + react_response = make_response(content="ReAct fallback") + + gateway = make_mock_gateway([no_steps_response, simplified_fail_response, react_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Task"}], + ) + + assert result.output == "ReAct fallback" + assert result.fallback_strategy == "react" + + +# --------------------------------------------------------------------------- +# Test: Multi-step plan with fallback chain integration +# --------------------------------------------------------------------------- + + +class TestReWOOMultiStepWithFallbackChain: + """多步计划与回退链的集成测试""" + + @pytest.mark.asyncio + async def test_three_step_plan_no_fallback(self): + """三步计划成功,无回退""" + search_tool = FakeTool(name="search", result={"results": ["Python is great"]}) + calc_tool = FakeTool(name="calculator", result={"value": 100}) + weather_tool = FakeTool(name="weather", result={"temp": 25, "city": "Shanghai"}) + + plan_response = make_plan_response([ + {"step_id": 1, "tool_name": "search", "arguments": {"query": "Python"}, "reasoning": "Search first"}, + {"step_id": 2, "tool_name": "calculator", "arguments": {"expr": "10*10"}, "reasoning": "Calculate"}, + {"step_id": 3, "tool_name": "weather", "arguments": {"city": "Shanghai"}, "reasoning": "Check weather"}, + ]) + synthesis_response = make_response(content="Based on search, calculation (100), and weather (25°C)") + + gateway = make_mock_gateway([plan_response, synthesis_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Search, calculate and check weather"}], + tools=[search_tool, calc_tool, weather_tool], + ) + + assert 
result.fallback_strategy is None + assert result.total_steps == 4 # 3 tool_calls + 1 final_answer + assert search_tool.call_count == 1 + assert calc_tool.call_count == 1 + assert weather_tool.call_count == 1 + assert "100" in result.output + assert "25" in result.output + + @pytest.mark.asyncio + async def test_fallback_strategies_constant(self): + """验证 FALLBACK_STRATEGIES 常量""" + assert ReWOOEngine.FALLBACK_STRATEGIES == ["simplified_rewoo", "react", "direct"] diff --git a/tests/integration/test_router_engine_chain.py b/tests/integration/test_router_engine_chain.py new file mode 100644 index 0000000..f208858 --- /dev/null +++ b/tests/integration/test_router_engine_chain.py @@ -0,0 +1,315 @@ +"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult +from agentkit.core.react import ReActEngine, ReActResult, ReActStep +from agentkit.llm.protocol import LLMResponse, TokenUsage +from agentkit.org.context import AgentProfile, OrganizationContext +from agentkit.quality.alignment import AlignmentConfig, AlignmentGuard + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_llm_response(content: str) -> LLMResponse: + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + ) + + +def _make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: + gateway = MagicMock() + gateway.chat = AsyncMock(side_effect=responses) + return gateway + + +# --------------------------------------------------------------------------- +# Test 1: Router routes to ReAct engine, output passes alignment check +# --------------------------------------------------------------------------- + + +class 
TestRouterToEnginePassesAlignment: + """路由到 ReAct 引擎,输出通过对齐检查""" + + @pytest.mark.asyncio + async def test_react_output_passes_alignment(self): + # --- Setup: LLM returns low complexity → default agent (ReAct) --- + gateway = _make_mock_gateway([ + _make_llm_response('{"complexity": 0.2}'), # quick_classify + _make_llm_response("你好!有什么可以帮你的?"), # ReAct final answer + ]) + + org_context = OrganizationContext() + alignment_config = AlignmentConfig( + constraints=["password", "secret_key"], + ) + guard = AlignmentGuard(config=alignment_config) + + router = CostAwareRouter(llm_gateway=gateway, org_context=org_context) + + mock_skill_registry = MagicMock() + mock_skill_registry.list_skills.return_value = [] + mock_intent_router = AsyncMock() + + # Step 1: Route + route_result = await router.route( + content="随便聊聊", + skill_registry=mock_skill_registry, + intent_router=mock_intent_router, + default_tools=[], + default_system_prompt="You are helpful", + default_model="default", + default_agent_name="default", + ) + assert route_result.complexity < 0.3 + assert route_result.agent_name == "default" + + # Step 2: Inject constraints + input_data = {"content": route_result.clean_content} + injected = guard.inject_constraints(input_data) + assert "alignment_constraints" in injected + + # Step 3: Simulate engine execution (use real ReActEngine with mock gateway) + react_engine = ReActEngine(llm_gateway=gateway) + engine_result = await react_engine.execute( + messages=[{"role": "user", "content": injected["content"]}], + ) + + # Step 4: Alignment check + output = {"result": engine_result.output} + check_result = await guard.check_output(output) + assert check_result.passed is True + assert check_result.violations == [] + + +# --------------------------------------------------------------------------- +# Test 2: Router routes to ReAct engine, output fails alignment check +# --------------------------------------------------------------------------- + + +class 
TestRouterToEngineFailsAlignment: + """路由到 ReAct 引擎,输出未通过对齐检查""" + + @pytest.mark.asyncio + async def test_react_output_fails_alignment(self): + gateway = _make_mock_gateway([ + _make_llm_response('{"complexity": 0.2}'), + _make_llm_response("Your password is 123456"), + ]) + + org_context = OrganizationContext() + alignment_config = AlignmentConfig( + constraints=["password", "secret_key"], + ) + guard = AlignmentGuard(config=alignment_config) + + router = CostAwareRouter(llm_gateway=gateway, org_context=org_context) + + mock_skill_registry = MagicMock() + mock_skill_registry.list_skills.return_value = [] + mock_intent_router = AsyncMock() + + route_result = await router.route( + content="随便聊聊", + skill_registry=mock_skill_registry, + intent_router=mock_intent_router, + default_tools=[], + default_system_prompt="You are helpful", + default_model="default", + default_agent_name="default", + ) + + input_data = {"content": route_result.clean_content} + injected = guard.inject_constraints(input_data) + + react_engine = ReActEngine(llm_gateway=gateway) + engine_result = await react_engine.execute( + messages=[{"role": "user", "content": injected["content"]}], + ) + + output = {"result": engine_result.output} + check_result = await guard.check_output(output) + assert check_result.passed is False + assert len(check_result.violations) > 0 + assert "password" in check_result.violations + + +# --------------------------------------------------------------------------- +# Test 3: Router routes based on complexity (low→default, high→org_context) +# --------------------------------------------------------------------------- + + +class TestRouterComplexityBasedRouting: + """基于复杂度的路由:低复杂度→默认,高复杂度→org_context 能力匹配""" + + @pytest.mark.asyncio + async def test_low_complexity_routes_to_default(self): + gateway = _make_mock_gateway([ + _make_llm_response('{"complexity": 0.15}'), + ]) + + org_context = OrganizationContext() + org_context.register_agent(AgentProfile( + name="analyst", 
+ agent_type="react", + capabilities=["analysis"], + skills=["analysis"], + )) + + router = CostAwareRouter(llm_gateway=gateway, org_context=org_context) + + mock_skill_registry = MagicMock() + mock_skill_registry.list_skills.return_value = [] + mock_intent_router = AsyncMock() + + result = await router.route( + content="简单问题", + skill_registry=mock_skill_registry, + intent_router=mock_intent_router, + default_tools=[], + default_system_prompt="You are helpful", + default_model="default", + default_agent_name="default", + ) + + assert result.complexity < 0.3 + assert result.agent_name == "default" + assert result.match_method == "low_complexity" + + @pytest.mark.asyncio + async def test_high_complexity_routes_via_org_context(self): + gateway = _make_mock_gateway([ + _make_llm_response('{"complexity": 0.85}'), + ]) + + org_context = OrganizationContext() + org_context.register_agent(AgentProfile( + name="analyst", + agent_type="react", + capabilities=["分析", "市场", "调研"], + skills=["market_analysis"], + current_load=0, + )) + + # find_best_agent returns real AgentProfile + org_context.find_best_agent = MagicMock( + return_value=org_context.get_agent_profile("analyst") + ) + + router = CostAwareRouter(llm_gateway=gateway, org_context=org_context) + + mock_skill_registry = MagicMock() + mock_skill_registry.list_skills.return_value = [] + mock_intent_router = AsyncMock() + + result = await router.route( + content="请对市场趋势进行深度分析并给出投资建议", + skill_registry=mock_skill_registry, + intent_router=mock_intent_router, + default_tools=[], + default_system_prompt="You are helpful", + default_model="default", + default_agent_name="default", + ) + + assert result.complexity >= 0.7 + assert result.match_method == "capability" + assert result.agent_name == "analyst" + assert result.matched is True + + +# --------------------------------------------------------------------------- +# Test 4: AlignmentGuard injects constraints into input before engine execution +# 
--------------------------------------------------------------------------- + + +class TestAlignmentGuardConstraintInjection: + """AlignmentGuard 在引擎执行前将约束注入到输入中""" + + @pytest.mark.asyncio + async def test_constraints_injected_before_engine_execution(self): + gateway = _make_mock_gateway([ + _make_llm_response('{"complexity": 0.5}'), + _make_llm_response("Safe answer"), + ]) + + alignment_config = AlignmentConfig( + constraints=["不得泄露用户隐私", "禁止生成有害内容"], + ) + guard = AlignmentGuard(config=alignment_config) + + org_context = OrganizationContext() + router = CostAwareRouter(llm_gateway=gateway, org_context=org_context) + + mock_skill_registry = MagicMock() + mock_skill_registry.list_skills.return_value = [] + mock_intent_router = AsyncMock() + + # Step 1: Route + route_result = await router.route( + content="请帮我写一篇文章", + skill_registry=mock_skill_registry, + intent_router=mock_intent_router, + default_tools=[], + default_system_prompt="You are helpful", + default_model="default", + default_agent_name="default", + ) + + # Step 2: Inject constraints + input_data = {"content": route_result.clean_content} + injected = guard.inject_constraints(input_data) + + # Verify constraints are present + assert "alignment_constraints" in injected + assert "不得泄露用户隐私" in injected["alignment_constraints"] + assert "禁止生成有害内容" in injected["alignment_constraints"] + # Original data preserved + assert injected["content"] == route_result.clean_content + # Original dict not mutated + assert "alignment_constraints" not in input_data + + # Step 3: Engine executes with injected input + react_engine = ReActEngine(llm_gateway=gateway) + engine_result = await react_engine.execute( + messages=[{"role": "user", "content": injected["content"]}], + ) + + # Step 4: Output passes alignment + output = {"result": engine_result.output} + check_result = await guard.check_output(output) + assert check_result.passed is True + + @pytest.mark.asyncio + async def 
test_constraint_injection_with_cascade_monitoring(self): + """约束注入 + 级联故障监控的完整链路""" + alignment_config = AlignmentConfig( + constraints=["password"], + cascade_max_interactions=5, + ) + guard = AlignmentGuard(config=alignment_config) + + # Inject constraints + input_data = {"content": "请帮我重置密码"} + injected = guard.inject_constraints(input_data) + assert "alignment_constraints" in injected + + # Simulate safe output + output = {"result": "密码重置链接已发送到您的邮箱。"} + check_result = await guard.check_output(output) + assert check_result.passed is True + + # Record interactions — no cascade alert + alert = guard.record_interaction("session-chain-1") + assert alert is None + assert guard.get_interaction_count("session-chain-1") == 1 diff --git a/tests/integration/test_soul_evolution_trigger.py b/tests/integration/test_soul_evolution_trigger.py new file mode 100644 index 0000000..5ab5307 --- /dev/null +++ b/tests/integration/test_soul_evolution_trigger.py @@ -0,0 +1,415 @@ +"""集成测试 - Soul 进化触发条件 + +测试 EvolutionMixin.evolve_soul 的多维触发逻辑: +- 时间窗口内反思计数 +- 质量梯度(下降分数)触发早期进化 +- 任务类型权重调整触发阈值 +- 时间衰减降低旧反思的有效计数 + +仅 mock MemoryTool(文件 I/O),使用真实 EvolutionMixin + SoulEvolutionConfig 实例。 +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus +from agentkit.evolution.lifecycle import EvolutionMixin, SoulEvolutionConfig +from agentkit.evolution.reflector import Reflection +from agentkit.memory.profile import MemoryStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_task(task_id: str = "task-1") -> TaskMessage: + return TaskMessage( + task_id=task_id, + agent_name="test_agent", + task_type="analysis", + priority=1, + input_data={"content": "test task"}, + 
callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + +def make_result(task_id: str = "task-1") -> TaskResult: + return TaskResult( + task_id=task_id, + agent_name="test_agent", + status=TaskStatus.COMPLETED, + output_data={"result": "done"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + +def make_reflection( + quality_score: float = 0.3, + patterns: list[str] | None = None, + suggestions: list[str] | None = None, +) -> Reflection: + # Use explicit None check to allow empty list for suggestions + if suggestions is None: + suggestions = ["Add more detail", "Be more specific"] + return Reflection( + task_id="task-1", + agent_name="test_agent", + outcome="partial", + quality_score=quality_score, + patterns=patterns or ["reasoning"], + insights=["Needs improvement"], + suggestions=suggestions, + ) + + +def make_mock_memory_store() -> MagicMock: + """创建 mock MemoryStore,模拟 get_file 返回可操作的 MemoryFile""" + store = MagicMock(spec=MemoryStore) + mock_file = MagicMock() + mock_file.read_section.return_value = "版本: 1\n更新时间: 2025-01-01T00:00:00" + mock_file.list_sections.return_value = ["身份", "版本"] + store.get_file.return_value = mock_file + return store + + +# --------------------------------------------------------------------------- +# Test 1: 3 reflections within window trigger evolution +# --------------------------------------------------------------------------- + + +class TestReflectionCountTrigger: + """时间窗口内 3 次反思触发进化""" + + @pytest.mark.asyncio + async def test_three_reflections_trigger_evolution(self): + config = SoulEvolutionConfig( + min_reflections=3, + reflection_window_seconds=3600, + time_decay_factor=1.0, # No decay for this test + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + # Record 3 reflections within the window + task = make_task() + result = make_result() + + with patch("agentkit.tools.memory_tool.MemoryTool.execute", 
new_callable=AsyncMock) as mock_execute: + mock_execute.return_value = {"success": True, "version": 2} + + # First reflection — should not trigger + reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection1) + assert evolved is False + + # Second reflection — should not trigger + reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection2) + assert evolved is False + + # Third reflection — should trigger (3 >= min_reflections=3) + reflection3 = make_reflection(quality_score=0.2, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3) + assert evolved is True + + # MemoryTool.execute should have been called for the soul update + mock_execute.assert_called_once() + call_kwargs = mock_execute.call_args[1] + assert call_kwargs["action"] == "update_soul" + assert call_kwargs["file"] == "soul" + + @pytest.mark.asyncio + async def test_two_reflections_do_not_trigger(self): + config = SoulEvolutionConfig( + min_reflections=3, + time_decay_factor=1.0, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + # First reflection + reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection1) + assert evolved is False + + # Second reflection — still not enough + reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection2) + assert evolved is False + + +# --------------------------------------------------------------------------- +# Test 2: Quality gradient (declining scores) triggers early evolution +# 
--------------------------------------------------------------------------- + + +class TestQualityGradientTrigger: + """质量梯度(持续下降分数)触发早期进化""" + + @pytest.mark.asyncio + async def test_declining_scores_trigger_early_evolution(self): + config = SoulEvolutionConfig( + min_reflections=10, # High threshold — won't trigger by count + quality_gradient_threshold=-0.15, + time_decay_factor=1.0, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute: + mock_execute.return_value = {"success": True, "version": 2} + + # Record 3 reflections with declining scores (all < 0.5 to pass the quality check) + # Score drops: 0.45 → 0.25 → 0.05 (each drop > 0.15) + reflection1 = make_reflection(quality_score=0.45, patterns=["reasoning"]) + await mixin.evolve_soul(task, result, memory_store, reflection=reflection1) + + reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"]) + await mixin.evolve_soul(task, result, memory_store, reflection=reflection2) + + # Third reflection with continued decline should trigger quality gradient + reflection3 = make_reflection(quality_score=0.05, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3) + + # Quality gradient: 0.45→0.25 (drop=-0.2), 0.25→0.05 (drop=-0.2) + # Both drops <= -0.15, so quality_gradient_triggered = True + assert evolved is True + mock_execute.assert_called_once() + + @pytest.mark.asyncio + async def test_stable_scores_do_not_trigger_gradient(self): + config = SoulEvolutionConfig( + min_reflections=10, # High threshold + quality_gradient_threshold=-0.15, + time_decay_factor=1.0, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + # Record 3 reflections with 
stable/improving scores + reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"]) + await mixin.evolve_soul(task, result, memory_store, reflection=reflection1) + + reflection2 = make_reflection(quality_score=0.35, patterns=["reasoning"]) + await mixin.evolve_soul(task, result, memory_store, reflection=reflection2) + + reflection3 = make_reflection(quality_score=0.4, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection3) + + # Scores are improving, no quality gradient trigger + assert evolved is False + + +# --------------------------------------------------------------------------- +# Test 3: Task type weight adjusts trigger threshold +# --------------------------------------------------------------------------- + + +class TestTaskTypeWeightTrigger: + """任务类型权重调整触发阈值""" + + @pytest.mark.asyncio + async def test_high_weight_reduces_effective_threshold(self): + """高权重降低有效触发阈值:2 次反思 × 权重 2.0 = 有效 4.0 >= min_reflections 3""" + config = SoulEvolutionConfig( + min_reflections=3, + time_decay_factor=1.0, + task_type_weights={"critical": 2.0}, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute: + mock_execute.return_value = {"success": True, "version": 2} + + # First reflection with critical task type + reflection1 = make_reflection(quality_score=0.3, patterns=["reasoning"]) + evolved = await mixin.evolve_soul( + task, result, memory_store, + reflection=reflection1, + task_type="critical", + ) + # 1 reflection × weight 2.0 = 2.0 < 3 + assert evolved is False + + # Second reflection with critical task type + reflection2 = make_reflection(quality_score=0.25, patterns=["reasoning"]) + evolved = await mixin.evolve_soul( + task, result, memory_store, + reflection=reflection2, + task_type="critical", + ) + # 2 
reflections × weight 2.0 = 4.0 >= 3 + assert evolved is True + mock_execute.assert_called_once() + + @pytest.mark.asyncio + async def test_low_weight_increases_effective_threshold(self): + """低权重增加有效触发阈值:3 次反思 × 权重 0.5 = 有效 1.5 < min_reflections 3""" + config = SoulEvolutionConfig( + min_reflections=3, + time_decay_factor=1.0, + task_type_weights={"low_priority": 0.5}, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + # 3 reflections with low_priority task type + for i in range(3): + reflection = make_reflection(quality_score=0.3, patterns=["reasoning"]) + evolved = await mixin.evolve_soul( + task, result, memory_store, + reflection=reflection, + task_type="low_priority", + ) + # 3 × 0.5 = 1.5 < 3 → should not trigger + assert evolved is False + + +# --------------------------------------------------------------------------- +# Test 4: Time decay reduces effective count for old reflections +# --------------------------------------------------------------------------- + + +class TestTimeDecayReducesEffectiveCount: + """时间衰减降低旧反思的有效计数""" + + @pytest.mark.asyncio + async def test_old_reflections_decay_below_threshold(self): + """旧反思因时间衰减导致有效计数不足""" + config = SoulEvolutionConfig( + min_reflections=3, + reflection_window_seconds=3600, + time_decay_factor=0.5, # Half-life of 1 hour + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + # Manually add old reflections to pending_soul_updates + now = datetime.now(timezone.utc) + old_timestamp = now - timedelta(hours=3) # 3 hours ago + + # Add 2 old reflections manually + mixin.pending_soul_updates["reasoning"] = [ + { + "reflection": make_reflection(quality_score=0.3), + "timestamp": old_timestamp, + "score": 0.3, + "task_type": "", + }, + { + "reflection": make_reflection(quality_score=0.25), + "timestamp": old_timestamp, + 
"score": 0.25, + "task_type": "", + }, + ] + + # Add a recent reflection via evolve_soul + # Time decay: 0.5^3 = 0.125 per old reflection → 2 × 0.125 = 0.25 + # Plus 1 new reflection → total effective ≈ 1.25 < 3 + recent_reflection = make_reflection(quality_score=0.2, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=recent_reflection) + + # Effective count should be well below 3 due to decay + assert evolved is False + + @pytest.mark.asyncio + async def test_recent_reflections_no_decay(self): + """近期反思不受时间衰减影响""" + config = SoulEvolutionConfig( + min_reflections=3, + time_decay_factor=0.5, + ) + + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + with patch("agentkit.tools.memory_tool.MemoryTool.execute", new_callable=AsyncMock) as mock_execute: + mock_execute.return_value = {"success": True, "version": 2} + + # 3 recent reflections should trigger (no significant decay) + for i in range(3): + reflection = make_reflection(quality_score=0.3, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection) + + assert evolved is True + + @pytest.mark.asyncio + async def test_no_memory_store_returns_false(self): + """无 MemoryStore 时不触发进化""" + config = SoulEvolutionConfig(min_reflections=1) + mixin = EvolutionMixin(evolution_config=config) + + task = make_task() + result = make_result() + reflection = make_reflection(quality_score=0.3, patterns=["reasoning"]) + + evolved = await mixin.evolve_soul(task, result, memory_store=None, reflection=reflection) + assert evolved is False + + @pytest.mark.asyncio + async def test_high_quality_reflection_does_not_trigger(self): + """高质量反思不触发进化(quality_score >= 0.5)""" + config = SoulEvolutionConfig(min_reflections=1) + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = 
make_result() + + # High quality reflection — should not even be recorded + reflection = make_reflection(quality_score=0.8, patterns=["reasoning"]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection) + assert evolved is False + + @pytest.mark.asyncio + async def test_no_suggestions_does_not_trigger(self): + """无建议的反思不触发进化""" + config = SoulEvolutionConfig(min_reflections=1) + mixin = EvolutionMixin(evolution_config=config) + memory_store = make_mock_memory_store() + + task = make_task() + result = make_result() + + # Low quality but no suggestions + reflection = make_reflection(quality_score=0.3, patterns=["reasoning"], suggestions=[]) + evolved = await mixin.evolve_soul(task, result, memory_store, reflection=reflection) + assert evolved is False diff --git a/tests/unit/core/__init__.py b/tests/unit/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/evolution/__init__.py b/tests/unit/evolution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/server/__init__.py b/tests/unit/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/skills/__init__.py b/tests/unit/skills/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_agent_bus.py b/tests/unit/test_agent_bus.py new file mode 100644 index 0000000..53abdc9 --- /dev/null +++ b/tests/unit/test_agent_bus.py @@ -0,0 +1,331 @@ +"""Tests for Agent inter-communication bus — new features. + +Covers: msg_type, content, TTL, CascadeDetector integration, +AlignmentGuard integration, negotiate, request-timeout-returns-None. 
+""" + +import asyncio +from datetime import datetime, timezone, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.bus.message import AgentMessage +from agentkit.bus.memory_bus import InMemoryMessageBus +from agentkit.quality.cascade_detector import CascadeDetector +from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig + + +# ── AgentMessage new fields ─────────────────────────────── + + +class TestAgentMessageNewFields: + def test_msg_type_default(self): + msg = AgentMessage(sender="a") + assert msg.msg_type == "notify" + + def test_content_field(self): + msg = AgentMessage(sender="a", content="hello") + assert msg.content == "hello" + + def test_ttl_default(self): + msg = AgentMessage(sender="a") + assert msg.ttl_seconds == 300 + + def test_is_expired_fresh(self): + msg = AgentMessage(sender="a") + assert msg.is_expired() is False + + def test_is_expired_old(self): + old_ts = datetime.now(timezone.utc) - timedelta(seconds=600) + msg = AgentMessage(sender="a", ttl_seconds=300, timestamp=old_ts) + assert msg.is_expired() is True + + def test_to_dict_roundtrip_with_new_fields(self): + msg = AgentMessage( + sender="a", + content="data", + msg_type="negotiate", + ttl_seconds=60, + ) + d = msg.to_dict() + restored = AgentMessage.from_dict(d) + assert restored.content == "data" + assert restored.msg_type == "negotiate" + assert restored.ttl_seconds == 60 + + +# ── Request-Response with correlation_id ────────────────── + + +class TestRequestResponse: + @pytest.mark.asyncio + async def test_request_response_correlation(self): + """Agent A sends request, Agent B responds, correlation_id matches.""" + bus = InMemoryMessageBus() + + async def handler_b(msg: AgentMessage): + reply = AgentMessage( + sender="agent_b", + recipient=msg.sender, + content="answer", + msg_type="response", + correlation_id=msg.correlation_id, + ) + await bus.publish(reply) + + await bus.subscribe("agent_b", handler_b) + + request = 
AgentMessage( + sender="agent_a", + recipient="agent_b", + content="question", + ) + response = await bus.request(request, timeout_seconds=5.0) + assert response is not None + assert response.content == "answer" + assert response.correlation_id == request.correlation_id + + +# ── Broadcast ───────────────────────────────────────────── + + +class TestBroadcast: + @pytest.mark.asyncio + async def test_broadcast_all_subscribers_receive(self): + """Agent A broadcasts, all subscribers receive.""" + bus = InMemoryMessageBus() + received: dict[str, list[AgentMessage]] = {"b": [], "c": []} + + async def handler_b(msg: AgentMessage): + received["b"].append(msg) + + async def handler_c(msg: AgentMessage): + received["c"].append(msg) + + await bus.subscribe("agent_b", handler_b) + await bus.subscribe("agent_c", handler_c) + + result = await bus.publish(AgentMessage( + sender="agent_a", + content="hello everyone", + msg_type="notify", + )) + assert result is True + assert len(received["b"]) == 1 + assert len(received["c"]) == 1 + assert received["b"][0].content == "hello everyone" + + +# ── Negotiate ───────────────────────────────────────────── + + +class TestNegotiate: + @pytest.mark.asyncio + async def test_negotiate_response(self): + """Agent A sends negotiate, Agent B responds.""" + bus = InMemoryMessageBus() + + async def handler_b(msg: AgentMessage): + reply = AgentMessage( + sender="agent_b", + recipient=msg.sender, + content="deal accepted", + msg_type="response", + correlation_id=msg.correlation_id, + ) + await bus.publish(reply) + + await bus.subscribe("agent_b", handler_b) + + request = AgentMessage( + sender="agent_a", + recipient="agent_b", + content="propose deal", + msg_type="negotiate", + ) + response = await bus.request(request, timeout_seconds=5.0) + assert response is not None + assert response.content == "deal accepted" + assert response.msg_type == "response" + + +# ── TTL Expired ─────────────────────────────────────────── + + +class TestTTLExpired: + 
@pytest.mark.asyncio + async def test_expired_message_dropped(self): + """Expired message is dropped by publish().""" + bus = InMemoryMessageBus() + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + + old_ts = datetime.now(timezone.utc) - timedelta(seconds=600) + msg = AgentMessage( + sender="agent_a", + recipient="agent_b", + content="old news", + ttl_seconds=300, + timestamp=old_ts, + ) + result = await bus.publish(msg) + assert result is False + await asyncio.sleep(0.05) + assert len(received) == 0 + + +# ── Unsubscribe ─────────────────────────────────────────── + + +class TestUnsubscribe: + @pytest.mark.asyncio + async def test_unsubscribed_agent_no_receive(self): + """Unsubscribed agent doesn't receive messages.""" + bus = InMemoryMessageBus() + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + await bus.unsubscribe("agent_b") + + await bus.publish(AgentMessage( + sender="agent_a", + content="still here?", + )) + await asyncio.sleep(0.05) + assert len(received) == 0 + + +# ── Request Timeout ─────────────────────────────────────── + + +class TestRequestTimeout: + @pytest.mark.asyncio + async def test_request_timeout_returns_none(self): + """No response within timeout returns None.""" + bus = InMemoryMessageBus() + + # No subscriber for agent_b + request = AgentMessage( + sender="agent_a", + recipient="agent_b", + content="anyone there?", + ) + response = await bus.request(request, timeout_seconds=0.1) + assert response is None + + +# ── CascadeDetector Integration ─────────────────────────── + + +class TestCascadeDetectorIntegration: + @pytest.mark.asyncio + async def test_cascade_alert_blocks_message(self): + """Too many messages trigger cascade alert and block publishing.""" + detector = CascadeDetector(max_interactions=3) + bus = 
InMemoryMessageBus(cascade_detector=detector) + + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + + # First 3 interactions should succeed (check_interaction increments before check) + # max_interactions=3 means count > 3 triggers alert, so 3 succeed, 4th fails + for i in range(3): + result = await bus.publish(AgentMessage( + sender="agent_a", + recipient="agent_b", + content=f"msg {i}", + )) + assert result is True + + # 4th interaction should be blocked + result = await bus.publish(AgentMessage( + sender="agent_a", + recipient="agent_b", + content="msg 4", + )) + assert result is False + + +# ── AlignmentGuard Integration ──────────────────────────── + + +class TestAlignmentGuardIntegration: + @pytest.mark.asyncio + async def test_violating_message_blocked(self): + """Message violating alignment constraints is blocked.""" + config = AlignmentConfig(constraints=["禁止暴力"]) + guard = AlignmentGuard(config=config) + bus = InMemoryMessageBus(alignment_guard=guard) + + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + + # request with violating content should be blocked + result = await bus.publish(AgentMessage( + sender="agent_a", + recipient="agent_b", + content="执行暴力行为", + msg_type="request", + )) + assert result is False + + @pytest.mark.asyncio + async def test_non_violating_request_passes(self): + """Non-violating request message passes alignment check.""" + config = AlignmentConfig(constraints=["禁止暴力"]) + guard = AlignmentGuard(config=config) + bus = InMemoryMessageBus(alignment_guard=guard) + + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + + result = await bus.publish(AgentMessage( + sender="agent_a", + recipient="agent_b", + content="和平交流", + msg_type="request", + )) + assert 
result is True + + @pytest.mark.asyncio + async def test_notify_not_checked_by_alignment(self): + """notify type messages are not checked by alignment guard.""" + config = AlignmentConfig(constraints=["禁止暴力"]) + guard = AlignmentGuard(config=config) + bus = InMemoryMessageBus(alignment_guard=guard) + + received: list[AgentMessage] = [] + + async def handler(msg: AgentMessage): + received.append(msg) + + await bus.subscribe("agent_b", handler) + + # notify with violating content should pass (not checked) + result = await bus.publish(AgentMessage( + sender="agent_a", + recipient="agent_b", + content="执行暴力行为", + msg_type="notify", + )) + assert result is True diff --git a/tests/unit/test_bus_protocol.py b/tests/unit/test_bus_protocol.py index 4f39d0b..0503ee3 100644 --- a/tests/unit/test_bus_protocol.py +++ b/tests/unit/test_bus_protocol.py @@ -126,7 +126,7 @@ class TestInMemoryMessageBus: @pytest.mark.asyncio async def test_request_timeout(self): - """请求超时后抛出异常。""" + """请求超时后返回 None。""" bus = InMemoryMessageBus() # No one is subscribed to handle the request @@ -136,8 +136,8 @@ class TestInMemoryMessageBus: topic="question", ) - with pytest.raises(TimeoutError): - await bus.request(request, timeout=0.1) + result = await bus.request(request, timeout=0.1) + assert result is None @pytest.mark.asyncio async def test_unsubscribe_stops_delivery(self): diff --git a/tests/unit/test_cost_aware_router.py b/tests/unit/test_cost_aware_router.py index 50a29f2..f78d502 100644 --- a/tests/unit/test_cost_aware_router.py +++ b/tests/unit/test_cost_aware_router.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult +from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content from agentkit.llm.protocol import LLMResponse, TokenUsage from agentkit.router.intent import IntentRouter, RoutingResult from agentkit.skills.base import IntentConfig, 
Skill, SkillConfig @@ -466,3 +466,47 @@ class TestSkillRoutingResultNewFields: assert result.transparency_level == "SILENT" assert result.execution_trace == [] assert result.complexity == 0.0 + + +# --------------------------------------------------------------------------- +# _tokenize_content: 中文分词增强 +# --------------------------------------------------------------------------- + + +class TestTokenizeContent: + """_tokenize_content 中文分词增强测试""" + + def test_chinese_content(self): + """中文内容:'帮我做数据分析' 应包含 '数据分析' 相关 2-gram""" + tokens = _tokenize_content("帮我做数据分析") + # 整段无标点分隔,生成 2-gram:帮我、我做、做数、数据、据分、分析 + assert "数据" in tokens or "数据分析" in tokens + + def test_english_content(self): + """英文内容:'help with code generation' 应包含 'code', 'generation' 或 'code generation'""" + tokens = _tokenize_content("help with code generation") + assert "code" in tokens or "generation" in tokens or "code generation" in tokens + + def test_mixed_content(self): + """中英混合:'用python做data analysis' 应包含 'python' 相关 token 和 'data analysis'""" + tokens = _tokenize_content("用python做data analysis") + # 按空格分割后 "用python做data" 作为一个 segment,生成 2-gram + # "analysis" 作为独立 segment + assert "analysis" in tokens + # "用python做data" 长度 > 4,会生成 2-gram,其中包含 python 相关片段 + has_python_related = any("python" in t for t in tokens) + assert has_python_related or "data analysis" in tokens + + def test_stopwords_filtered(self): + """停用词过滤:纯停用词短句过滤后应为空或极少 token""" + tokens = _tokenize_content("我的一个") + # "我的一个" 长度 4,作为整体保留(不在停用词集合中) + # 但停用词 "我的" "的一" "一个" 等 2-gram 会被过滤 + assert len(tokens) <= 1 + + def test_bigram_generation(self): + """2-gram 生成:'机器学习模型训练' 应包含各 2-gram""" + tokens = _tokenize_content("机器学习模型训练") + expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"] + for bigram in expected_bigrams: + assert bigram in tokens, f"缺少 2-gram: {bigram}" diff --git a/tests/unit/test_rewoo_engine.py b/tests/unit/test_rewoo_engine.py index 0c2fca7..02bd2ed 100644 --- a/tests/unit/test_rewoo_engine.py +++ 
b/tests/unit/test_rewoo_engine.py @@ -267,8 +267,9 @@ class TestReWOOPlanningFailureFallback: async def test_invalid_json_falls_back_to_react(self): from agentkit.core.rewoo import ReWOOEngine - # Planning returns invalid JSON + # Planning returns invalid JSON, simplified planning also fails invalid_plan_response = make_response(content="I cannot create a plan for this task.") + simplified_fail_response = make_response(content="Still cannot create a plan") # ReAct fallback responses react_tool_response = make_response( content="", @@ -278,6 +279,7 @@ class TestReWOOPlanningFailureFallback: gateway = make_mock_gateway([ invalid_plan_response, + simplified_fail_response, react_tool_response, react_final_response, ]) @@ -292,15 +294,17 @@ class TestReWOOPlanningFailureFallback: # Should have fallen back to ReAct and produced a result assert result.output == "ReAct fallback answer" assert result.total_steps >= 1 + assert result.fallback_strategy == "react" async def test_malformed_json_falls_back_to_react(self): from agentkit.core.rewoo import ReWOOEngine - # Planning returns malformed JSON + # Planning returns malformed JSON, simplified planning also fails, ReAct succeeds malformed_response = make_response(content='{"reasoning": "plan", "steps": [invalid json') + simplified_fail_response = make_response(content='Also not a plan') react_response = make_response(content="ReAct answer") - gateway = make_mock_gateway([malformed_response, react_response]) + gateway = make_mock_gateway([malformed_response, simplified_fail_response, react_response]) engine = ReWOOEngine(llm_gateway=gateway) result = await engine.execute( @@ -308,15 +312,17 @@ class TestReWOOPlanningFailureFallback: ) assert result.output == "ReAct answer" + assert result.fallback_strategy == "react" async def test_missing_steps_key_falls_back_to_react(self): from agentkit.core.rewoo import ReWOOEngine - # JSON without "steps" key + # JSON without "steps" key, simplified planning also fails, ReAct succeeds 
no_steps_response = make_response(content='{"reasoning": "no steps here"}') + simplified_fail_response = make_response(content='Also no steps') react_response = make_response(content="ReAct fallback") - gateway = make_mock_gateway([no_steps_response, react_response]) + gateway = make_mock_gateway([no_steps_response, simplified_fail_response, react_response]) engine = ReWOOEngine(llm_gateway=gateway) result = await engine.execute( @@ -324,6 +330,7 @@ class TestReWOOPlanningFailureFallback: ) assert result.output == "ReAct fallback" + assert result.fallback_strategy == "react" # ── Test: Cancellation Token ────────────────────────────── @@ -706,11 +713,12 @@ class TestReWOOStreaming: async def test_stream_planning_failure_falls_back(self): from agentkit.core.rewoo import ReWOOEngine - # Invalid plan, then ReAct fallback + # Invalid plan, simplified also fails, then ReAct fallback invalid_plan = make_response(content="Not a plan") + simplified_fail = make_response(content="Still not a plan") react_response = make_response(content="ReAct answer") - gateway = make_mock_gateway([invalid_plan, react_response]) + gateway = make_mock_gateway([invalid_plan, simplified_fail, react_response]) engine = ReWOOEngine(llm_gateway=gateway) events = [] @@ -842,3 +850,148 @@ class TestReWOOMaxPlanSteps: # Need to import ReActResult for type checking in tests from agentkit.core.react import ReActResult + + +# ── Test: Progressive Fallback Chain ────────────────────── + + +class TestReWOOProgressiveFallback: + """渐进式回退链:planning → simplified_rewoo → react → direct""" + + async def test_normal_planning_succeeds_no_fallback(self): + """正常规划成功,不使用回退""" + from agentkit.core.rewoo import ReWOOEngine + + tool = FakeTool(name="calculator", result={"value": 42}) + + plan_response = make_plan_response([ + {"step_id": 1, "tool_name": "calculator", "arguments": {"expr": "6*7"}, "reasoning": "Calculate"}, + ]) + synthesis_response = make_response(content="The result is 42") + + gateway = 
make_mock_gateway([plan_response, synthesis_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Calculate 6*7"}], + tools=[tool], + ) + + assert result.output == "The result is 42" + assert result.fallback_strategy is None + + async def test_planning_fails_simplified_succeeds(self): + """规划失败,简化规划成功 → fallback_strategy="simplified_rewoo" """ + from agentkit.core.rewoo import ReWOOEngine + + tool = FakeTool(name="search", result={"results": ["found"]}) + + # First plan fails (invalid JSON), second (simplified) succeeds + invalid_plan_response = make_response(content="I cannot create a plan for this task.") + simplified_plan_response = make_plan_response([ + {"step_id": 1, "tool_name": "search", "arguments": {"query": "test"}, "reasoning": "Simplified search"}, + ]) + synthesis_response = make_response(content="Simplified result") + + gateway = make_mock_gateway([invalid_plan_response, simplified_plan_response, synthesis_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Complex task"}], + tools=[tool], + ) + + assert result.output == "Simplified result" + assert result.fallback_strategy == "simplified_rewoo" + + async def test_planning_and_simplified_fail_react_succeeds(self): + """规划和简化规划都失败,ReAct 回退成功 → fallback_strategy="react" """ + from agentkit.core.rewoo import ReWOOEngine + + # Both plan attempts fail (invalid JSON), ReAct succeeds + invalid_plan1 = make_response(content="Not a plan") + invalid_plan2 = make_response(content="Still not a plan") + react_response = make_response(content="ReAct fallback answer") + + gateway = make_mock_gateway([invalid_plan1, invalid_plan2, react_response]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Complex task"}], + ) + + assert result.output == "ReAct fallback answer" + assert 
result.fallback_strategy == "react" + + async def test_all_fail_direct_fallback(self): + """规划、简化规划、ReAct 全部失败 → fallback_strategy="direct" """ + from agentkit.core.rewoo import ReWOOEngine + + # Both plan attempts fail + invalid_plan1 = make_response(content="Not a plan") + invalid_plan2 = make_response(content="Still not a plan") + + # Make ReAct engine fail by having its LLM call raise an exception + call_count = 0 + + async def chat_side_effect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + # First two calls are for planning (both fail to parse) + return make_response(content="Not a plan") + if call_count == 3: + # ReAct engine call - raise exception + raise RuntimeError("ReAct engine failed") + # Direct fallback call + return make_response(content="Direct fallback answer") + + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=chat_side_effect) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Impossible task"}], + ) + + assert result.output == "Direct fallback answer" + assert result.fallback_strategy == "direct" + + async def test_fallback_strategies_constant_exists(self): + """验证 FALLBACK_STRATEGIES 常量存在""" + from agentkit.core.rewoo import ReWOOEngine + + assert hasattr(ReWOOEngine, "FALLBACK_STRATEGIES") + assert ReWOOEngine.FALLBACK_STRATEGIES == ["simplified_rewoo", "react", "direct"] + + async def test_react_fallback_with_tools(self): + """规划失败后 ReAct 回退,带工具调用""" + from agentkit.core.rewoo import ReWOOEngine + + tool = FakeTool(name="search", result={"results": ["found"]}) + + # Both plan attempts fail + invalid_plan1 = make_response(content="Cannot plan") + invalid_plan2 = make_response(content="Still cannot plan") + # ReAct: tool call then final answer + react_tool_response = make_response( + content="", + tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})], + ) + react_final_response = 
make_response(content="ReAct answer with tool") + + gateway = make_mock_gateway([ + invalid_plan1, + invalid_plan2, + react_tool_response, + react_final_response, + ]) + engine = ReWOOEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Search task"}], + tools=[tool], + ) + + assert result.output == "ReAct answer with tool" + assert result.fallback_strategy == "react" diff --git a/tests/unit/test_soul_evolution.py b/tests/unit/test_soul_evolution.py index aacfb3c..a912c29 100644 --- a/tests/unit/test_soul_evolution.py +++ b/tests/unit/test_soul_evolution.py @@ -2,14 +2,14 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from unittest.mock import AsyncMock import pytest from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus -from agentkit.evolution.lifecycle import EvolutionMixin +from agentkit.evolution.lifecycle import EvolutionMixin, SoulEvolutionConfig from agentkit.evolution.reflector import Reflection, Reflector from agentkit.memory.profile import MemoryStore from agentkit.tools.memory_tool import MemoryTool @@ -265,3 +265,288 @@ class TestEvolveSoul: updated = await mixin.evolve_soul(task, result, memory_store=store) assert updated is False + + +# ── Multi-dimensional trigger tests ────────────────────────── + + +class TestTimeDecay: + """时间衰减触发测试.""" + + async def test_recent_reflections_count_fully(self, store: MemoryStore): + """窗口内的反思完全计入有效数量.""" + config = SoulEvolutionConfig( + min_reflections=3, + reflection_window_seconds=3600, + time_decay_factor=0.5, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 3 次近期反思应触发 + for _ in range(2): + updated = await mixin.evolve_soul(task, result, memory_store=store) + assert updated is False + + updated = await 
mixin.evolve_soul(task, result, memory_store=store) + assert updated is True + + async def test_old_reflections_decay(self, store: MemoryStore): + """旧反思因时间衰减导致有效数量不足,不触发.""" + config = SoulEvolutionConfig( + min_reflections=3, + reflection_window_seconds=3600, + time_decay_factor=0.5, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 手动插入 2 个旧反思(10 小时前) + old_time = datetime.now(timezone.utc) - timedelta(hours=10) + for _ in range(2): + mixin.pending_soul_updates.setdefault("slow_execution", []).append( + { + "reflection": Reflection( + task_id="old", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.2, + patterns=["slow_execution"], + insights=[], + suggestions=["Improve speed"], + ), + "timestamp": old_time, + "score": 0.2, + "task_type": "", + } + ) + + # 1 个新反思:2*0.5^10 + 1 ≈ 1.002 < 3,不触发 + updated = await mixin.evolve_soul(task, result, memory_store=store) + assert updated is False + + +class TestQualityGradient: + """质量梯度触发测试.""" + + async def test_declining_scores_trigger_early(self, store: MemoryStore): + """连续 3 次分数下降超过阈值时提前触发.""" + config = SoulEvolutionConfig( + min_reflections=3, + quality_gradient_threshold=-0.15, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 手动插入 2 个反思,分数递减 + mixin.pending_soul_updates.setdefault("slow_execution", []).append( + { + "reflection": Reflection( + task_id="g1", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.4, + patterns=["slow_execution"], + insights=[], + suggestions=["Improve"], + ), + "timestamp": datetime.now(timezone.utc), + "score": 0.4, + "task_type": "", + } + ) + mixin.pending_soul_updates["slow_execution"].append( + { + "reflection": Reflection( + task_id="g2", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.2, + 
patterns=["slow_execution"], + insights=[], + suggestions=["Improve more"], + ), + "timestamp": datetime.now(timezone.utc), + "score": 0.2, + "task_type": "", + } + ) + + # 第 3 个反思:score=0.0,下降 0.2 > 0.15 阈值,触发 + updated = await mixin.evolve_soul( + task, result, memory_store=store, score=0.0 + ) + assert updated is True + + async def test_stable_scores_do_not_trigger_early(self, store: MemoryStore): + """分数稳定时不提前触发.""" + config = SoulEvolutionConfig( + min_reflections=3, + quality_gradient_threshold=-0.15, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 2 个分数稳定的反思 + mixin.pending_soul_updates.setdefault("slow_execution", []).append( + { + "reflection": Reflection( + task_id="s1", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.2, + patterns=["slow_execution"], + insights=[], + suggestions=["Improve"], + ), + "timestamp": datetime.now(timezone.utc), + "score": 0.2, + "task_type": "", + } + ) + mixin.pending_soul_updates["slow_execution"].append( + { + "reflection": Reflection( + task_id="s2", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.19, + patterns=["slow_execution"], + insights=[], + suggestions=["Improve more"], + ), + "timestamp": datetime.now(timezone.utc), + "score": 0.19, + "task_type": "", + } + ) + + # 第 3 个反思:下降 0.01 < 0.15 阈值,不提前触发 + # 但 effective_count=3 >= min_reflections=3,所以仍会触发 + # 需要改为只有 2 个反思来测试"不提前触发" + mixin2 = EvolutionMixin(reflector=reflector, evolution_config=config) + mixin2.pending_soul_updates.setdefault("slow_execution", []).append( + { + "reflection": Reflection( + task_id="s1", + agent_name="evolving_agent", + outcome="failure", + quality_score=0.2, + patterns=["slow_execution"], + insights=[], + suggestions=["Improve"], + ), + "timestamp": datetime.now(timezone.utc), + "score": 0.2, + "task_type": "", + } + ) + # 只有 2 个反思,分数稳定,不应触发 + updated = await mixin2.evolve_soul( 
+ task, result, memory_store=store, score=0.19 + ) + assert updated is False + + +class TestTaskTypeWeight: + """任务类型权重触发测试.""" + + async def test_code_generation_triggers_at_2(self, store: MemoryStore): + """code_generation 类型权重 1.5,2 次反思即可触发 (2*1.5=3 >= 3).""" + config = SoulEvolutionConfig( + min_reflections=3, + task_type_weights={"code_generation": 1.5, "chat": 0.5}, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 第 1 次 + updated = await mixin.evolve_soul( + task, result, memory_store=store, task_type="code_generation" + ) + assert updated is False + + # 第 2 次:effective_count=2, weight=1.5, weighted=3.0 >= 3,触发 + updated = await mixin.evolve_soul( + task, result, memory_store=store, task_type="code_generation" + ) + assert updated is True + + async def test_chat_needs_more_reflections(self, store: MemoryStore): + """chat 类型权重 0.5,需要更多反思才能触发.""" + config = SoulEvolutionConfig( + min_reflections=3, + task_type_weights={"code_generation": 1.5, "chat": 0.5}, + ) + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector, evolution_config=config) + + task = _make_task() + result = _make_result() + + # 4 次 chat 反思:effective_count=4, weight=0.5, weighted=2.0 < 3,不触发 + for _ in range(4): + updated = await mixin.evolve_soul( + task, result, memory_store=store, task_type="chat" + ) + assert updated is False + + # 第 5 次:5 * 0.5 = 2.5 < 3,仍不触发 + updated = await mixin.evolve_soul( + task, result, memory_store=store, task_type="chat" + ) + assert updated is False + + # 第 6 次:6 * 0.5 = 3.0 >= 3,触发 + updated = await mixin.evolve_soul( + task, result, memory_store=store, task_type="chat" + ) + assert updated is True + + +class TestBackwardCompatibility: + """向后兼容性测试:无 config 时行为与之前一致.""" + + async def test_no_config_3_reflections_trigger(self, store: MemoryStore): + """无 evolution_config 时,3 次反思触发(与原行为一致).""" + reflector = 
LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector) + + task = _make_task() + result = _make_result() + + # 前 2 次不触发 + for _ in range(2): + updated = await mixin.evolve_soul(task, result, memory_store=store) + assert updated is False + + # 第 3 次触发 + updated = await mixin.evolve_soul(task, result, memory_store=store) + assert updated is True + + async def test_no_config_fewer_than_3_no_trigger(self, store: MemoryStore): + """无 evolution_config 时,少于 3 次不触发.""" + reflector = LowQualityReflector() + mixin = EvolutionMixin(reflector=reflector) + + task = _make_task() + result = _make_result() + + for _ in range(2): + updated = await mixin.evolve_soul(task, result, memory_store=store) + assert updated is False diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index bb03bf5..8afc9be 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -1,472 +1,254 @@ -"""Unit tests for telemetry module — OpenTelemetry integration""" +"""Telemetry module unit tests.""" -import asyncio -import importlib -import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest - -# ── No-op behavior when OTel not installed ────────────────────────── +from agentkit.telemetry.tracer import ( + NoOpSpan, + NoOpTracer, + OTelSpan, + OTelTracer, + TelemetryConfig, + get_tracer, + init_telemetry, +) +from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult +from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig -class TestNoOpWhenOTelNotInstalled: - """All operations are no-op when opentelemetry is not installed.""" +# ── NoOpSpan 测试 ────────────────────────────────────────── - def test_tracing_noop_span_context_manager(self): - """_NoOpSpan works as context manager without errors.""" - from agentkit.telemetry.tracing import _NoOpSpan - span = _NoOpSpan() +class TestNoOpSpan: + """NoOpSpan no-op 行为测试""" + + def test_is_recording_returns_false(self): + span = NoOpSpan() + assert 
span.is_recording() is False + + def test_set_attribute_does_nothing(self): + span = NoOpSpan() + # 不应抛出异常 + span.set_attribute("key", "value") + + def test_add_event_does_nothing(self): + span = NoOpSpan() + span.add_event("event_name", {"attr": "val"}) + + def test_record_exception_does_nothing(self): + span = NoOpSpan() + span.record_exception(ValueError("test")) + + def test_context_manager(self): + span = NoOpSpan() with span as s: - s.set_attribute("key", "value") - s.add_event("event") - s.set_status("ok") - s.record_exception(Exception("test")) + assert s is span - def test_get_tracer_returns_none_without_otel(self): - """get_tracer returns None when OTel is not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, get_tracer - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - assert get_tracer() is None +# ── NoOpTracer 测试 ──────────────────────────────────────── - def test_start_span_returns_noop_without_otel(self): - """start_span returns no-op span when OTel is not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, start_span, _NoOpSpan - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - span_cm = start_span("test.span") - assert isinstance(span_cm, _NoOpSpan) +class TestNoOpTracer: + """NoOpTracer no-op 行为测试""" - def test_metrics_noop_counter(self): - """No-op counter add() does not raise.""" - from agentkit.telemetry.metrics import _NoOpCounter + def test_start_span_returns_noop_span(self): + tracer = NoOpTracer() + span = tracer.start_span("test.span") + assert isinstance(span, NoOpSpan) - counter = _NoOpCounter() - counter.add(1, {"key": "value"}) # Should not raise - - def test_metrics_noop_histogram(self): - """No-op histogram record() does not raise.""" - from agentkit.telemetry.metrics import _NoOpHistogram - - hist = _NoOpHistogram() - hist.record(100, {"key": "value"}) # Should not raise - - def test_metrics_get_meter_returns_none_without_otel(self): 
- """get_meter returns None when OTel is not installed.""" - from agentkit.telemetry.metrics import _OTEL_AVAILABLE, get_meter - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - assert get_meter() is None - - def test_metric_helpers_return_noop_without_otel(self): - """Metric helper functions return no-op instruments when OTel not installed.""" - from agentkit.telemetry.metrics import ( - _OTEL_AVAILABLE, - _NoOpCounter, - _NoOpHistogram, - agent_request_counter, - agent_duration_histogram, - llm_token_histogram, - tool_duration_histogram, - pipeline_step_histogram, - ) - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - # Reset lazy singletons to force re-creation - import agentkit.telemetry.metrics as m - m._agent_request_counter = None - m._agent_duration_histogram = None - m._llm_token_histogram = None - m._tool_duration_histogram = None - m._pipeline_step_histogram = None - - assert isinstance(agent_request_counter(), _NoOpCounter) - assert isinstance(agent_duration_histogram(), _NoOpHistogram) - assert isinstance(llm_token_histogram(), _NoOpHistogram) - assert isinstance(tool_duration_histogram(), _NoOpHistogram) - assert isinstance(pipeline_step_histogram(), _NoOpHistogram) - - -# ── Tracing decorator tests ───────────────────────────────────────── - - -class TestTraceAgentDecorator: - """trace_agent decorator works with and without OTel.""" - - @pytest.mark.asyncio - async def test_decorator_works_without_otel(self): - """trace_agent decorator passes through when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_agent - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_agent("test_agent", "react") - async def my_func(): - return "result" - - result = await my_func() - assert result == "result" - - @pytest.mark.asyncio - async def test_decorator_propagates_exception_without_otel(self): - """trace_agent propagates 
exceptions when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_agent - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_agent("test_agent") - async def my_func(): - raise ValueError("test error") - - with pytest.raises(ValueError, match="test error"): - await my_func() - - -class TestTraceToolDecorator: - """trace_tool decorator tests.""" - - @pytest.mark.asyncio - async def test_decorator_works_without_otel(self): - """trace_tool decorator passes through when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_tool - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_tool("my_tool") - async def my_func(): - return {"result": "ok"} - - result = await my_func() - assert result == {"result": "ok"} - - @pytest.mark.asyncio - async def test_decorator_propagates_exception_without_otel(self): - """trace_tool propagates exceptions when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_tool - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_tool("my_tool") - async def my_func(): - raise RuntimeError("tool error") - - with pytest.raises(RuntimeError, match="tool error"): - await my_func() - - -class TestTraceLLMDecorator: - """trace_llm decorator tests.""" - - @pytest.mark.asyncio - async def test_decorator_works_without_otel(self): - """trace_llm decorator passes through when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_llm - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_llm("openai", "gpt-4") - async def my_func(): - return MagicMock(usage=MagicMock(prompt_tokens=10, completion_tokens=20)) - - result = await my_func() - assert result is not None - - @pytest.mark.asyncio - async def test_decorator_propagates_exception_without_otel(self): - """trace_llm propagates 
exceptions when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_llm - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_llm("openai", "gpt-4") - async def my_func(): - raise ConnectionError("LLM error") - - with pytest.raises(ConnectionError, match="LLM error"): - await my_func() - - -class TestTracePipelineStepDecorator: - """trace_pipeline_step decorator tests.""" - - @pytest.mark.asyncio - async def test_decorator_works_without_otel(self): - """trace_pipeline_step decorator passes through when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_pipeline_step - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_pipeline_step("my_pipeline", "step_1") - async def my_func(): - return "step_result" - - result = await my_func() - assert result == "step_result" - - @pytest.mark.asyncio - async def test_decorator_propagates_exception_without_otel(self): - """trace_pipeline_step propagates exceptions when OTel not installed.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, trace_pipeline_step - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - @trace_pipeline_step("my_pipeline", "step_1") - async def my_func(): - raise RuntimeError("step failed") - - with pytest.raises(RuntimeError, match="step failed"): - await my_func() - - -# ── OTel installed (mocked) tests ─────────────────────────────────── - - -class TestTracingWithMockedOTel: - """Test tracing with mocked OTel imports.""" - - @pytest.mark.asyncio - async def test_trace_agent_with_mocked_otel(self): - """trace_agent creates span with correct attributes when OTel is available.""" - mock_span = MagicMock() - mock_span_cm = MagicMock() - mock_span_cm.__enter__ = MagicMock(return_value=mock_span) - mock_span_cm.__exit__ = MagicMock(return_value=False) - - mock_tracer = MagicMock() - 
mock_tracer.start_as_current_span.return_value = mock_span_cm - - with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \ - patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \ - patch("agentkit.telemetry.tracing.SpanKind"), \ - patch("agentkit.telemetry.tracing.Status"), \ - patch("agentkit.telemetry.tracing.StatusCode"): - - from agentkit.telemetry.tracing import trace_agent - - @trace_agent("test_agent", "react") - async def my_func(): - return "result" - - result = await my_func() - assert result == "result" - mock_tracer.start_as_current_span.assert_called_once() - call_kwargs = mock_tracer.start_as_current_span.call_args - assert call_kwargs[1]["attributes"]["agent.name"] == "test_agent" - assert call_kwargs[1]["attributes"]["agent.type"] == "react" - - @pytest.mark.asyncio - async def test_trace_tool_with_mocked_otel(self): - """trace_tool creates span with tool.name attribute.""" - mock_span = MagicMock() - mock_span_cm = MagicMock() - mock_span_cm.__enter__ = MagicMock(return_value=mock_span) - mock_span_cm.__exit__ = MagicMock(return_value=False) - - mock_tracer = MagicMock() - mock_tracer.start_as_current_span.return_value = mock_span_cm - - with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \ - patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \ - patch("agentkit.telemetry.tracing.SpanKind"), \ - patch("agentkit.telemetry.tracing.Status"), \ - patch("agentkit.telemetry.tracing.StatusCode"): - - from agentkit.telemetry.tracing import trace_tool - - @trace_tool("search_tool") - async def my_func(): - return {"found": True} - - result = await my_func() - assert result == {"found": True} - call_kwargs = mock_tracer.start_as_current_span.call_args - assert call_kwargs[1]["attributes"]["tool.name"] == "search_tool" - - @pytest.mark.asyncio - async def test_trace_llm_with_mocked_otel(self): - """trace_llm creates span with gen_ai semantic conventions.""" - mock_span = MagicMock() - 
mock_span_cm = MagicMock() - mock_span_cm.__enter__ = MagicMock(return_value=mock_span) - mock_span_cm.__exit__ = MagicMock(return_value=False) - - mock_tracer = MagicMock() - mock_tracer.start_as_current_span.return_value = mock_span_cm - - mock_usage = MagicMock() - mock_usage.prompt_tokens = 50 - mock_usage.completion_tokens = 100 - mock_response = MagicMock() - mock_response.usage = mock_usage - - with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \ - patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \ - patch("agentkit.telemetry.tracing.SpanKind"), \ - patch("agentkit.telemetry.tracing.Status"), \ - patch("agentkit.telemetry.tracing.StatusCode"): - - from agentkit.telemetry.tracing import trace_llm - - @trace_llm("openai", "gpt-4") - async def my_func(): - return mock_response - - result = await my_func() - assert result is mock_response - call_kwargs = mock_tracer.start_as_current_span.call_args - attrs = call_kwargs[1]["attributes"] - assert attrs["gen_ai.system"] == "openai" - assert attrs["gen_ai.operation.name"] == "chat" - assert attrs["gen_ai.request.model"] == "gpt-4" - # Token usage should be recorded on span - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - - @pytest.mark.asyncio - async def test_trace_pipeline_step_with_mocked_otel(self): - """trace_pipeline_step creates span with pipeline and step attributes.""" - mock_span = MagicMock() - mock_span_cm = MagicMock() - mock_span_cm.__enter__ = MagicMock(return_value=mock_span) - mock_span_cm.__exit__ = MagicMock(return_value=False) - - mock_tracer = MagicMock() - mock_tracer.start_as_current_span.return_value = mock_span_cm - - with patch("agentkit.telemetry.tracing._OTEL_AVAILABLE", True), \ - patch("agentkit.telemetry.tracing.get_tracer", return_value=mock_tracer), \ - patch("agentkit.telemetry.tracing.SpanKind"), \ - 
patch("agentkit.telemetry.tracing.Status"), \ - patch("agentkit.telemetry.tracing.StatusCode"): - - from agentkit.telemetry.tracing import trace_pipeline_step - - @trace_pipeline_step("geo_pipeline", "analyze") - async def my_func(): - return "done" - - result = await my_func() - assert result == "done" - call_kwargs = mock_tracer.start_as_current_span.call_args - attrs = call_kwargs[1]["attributes"] - assert attrs["pipeline.name"] == "geo_pipeline" - assert attrs["step.name"] == "analyze" - - -# ── setup_telemetry tests ─────────────────────────────────────────── - - -class TestSetupTelemetry: - """setup_telemetry initialization tests.""" - - def test_no_config_is_noop(self): - """setup_telemetry with no config is a no-op.""" - from agentkit.telemetry.setup import setup_telemetry - - mock_app = MagicMock() - setup_telemetry(mock_app, None) # Should not raise - # No auto-instrumentation should happen - mock_app.state = MagicMock() # Just ensure no crash - - def test_disabled_config_is_noop(self): - """setup_telemetry with enabled=False is a no-op.""" - from agentkit.telemetry.setup import setup_telemetry - - mock_app = MagicMock() - setup_telemetry(mock_app, {"enabled": False}) # Should not raise - - def test_config_without_otel_logs_warning(self): - """setup_telemetry with config but OTel not installed logs warning.""" - from agentkit.telemetry.setup import setup_telemetry - - mock_app = MagicMock() - # This should not raise even if OTel is not installed - # It will log a warning internally - config = {"enabled": True, "service_name": "test"} - # If OTel is installed, this will try to set up providers - # If not, it will log a warning and return - setup_telemetry(mock_app, config) # Should not raise - - def test_empty_config_is_noop(self): - """setup_telemetry with empty dict is a no-op (enabled defaults to False).""" - from agentkit.telemetry.setup import setup_telemetry - - mock_app = MagicMock() - setup_telemetry(mock_app, {}) # Should not raise - - -# ── 
Integration: Tool safe_execute with telemetry ─────────────────── - - -class TestToolTelemetryIntegration: - """Test that Tool.safe_execute records telemetry.""" - - @pytest.mark.asyncio - async def test_safe_execute_records_noop_telemetry(self): - """safe_execute works with no-op telemetry (OTel not installed).""" - from agentkit.tools.base import Tool - - class DummyTool(Tool): - async def execute(self, **kwargs): - return {"result": "ok"} - - tool = DummyTool(name="test_tool", description="A test tool") - result = await tool.safe_execute(query="hello") - assert result == {"result": "ok"} - - @pytest.mark.asyncio - async def test_safe_execute_error_records_telemetry(self): - """safe_execute records error telemetry on exception.""" - from agentkit.tools.base import Tool - - class FailingTool(Tool): - async def execute(self, **kwargs): - raise ValueError("tool failed") - - tool = FailingTool(name="failing_tool", description="A failing tool") - with pytest.raises(ValueError, match="tool failed"): - await tool.safe_execute(query="hello") - - -# ── start_span helper tests ───────────────────────────────────────── - - -class TestStartSpan: - """Test start_span helper function.""" - - def test_start_span_noop_without_otel(self): - """start_span returns no-op span context manager without OTel.""" - from agentkit.telemetry.tracing import _OTEL_AVAILABLE, start_span, _NoOpSpan - - if _OTEL_AVAILABLE: - pytest.skip("OTel is installed, skipping no-op test") - - cm = start_span("test.span", attributes={"key": "value"}) - assert isinstance(cm, _NoOpSpan) - # Should work as context manager - with cm: - pass # No error + def test_start_as_current_span_returns_noop_span(self): + tracer = NoOpTracer() + span = tracer.start_as_current_span("test.span") + assert isinstance(span, NoOpSpan) def test_start_span_with_attributes(self): - """start_span accepts attributes parameter without error.""" - from agentkit.telemetry.tracing import start_span + tracer = NoOpTracer() + span = 
tracer.start_span("test.span", attributes={"key": "value"}) + assert isinstance(span, NoOpSpan) - cm = start_span("test.span", attributes={"key": "value", "count": 42}) - with cm: - pass # No error regardless of OTel availability + def test_start_span_as_context_manager(self): + tracer = NoOpTracer() + with tracer.start_span("test.span") as span: + assert isinstance(span, NoOpSpan) + + +# ── TelemetryConfig tests ─────────────────────────────────── + + +class TestTelemetryConfig: + """TelemetryConfig default-value tests""" + + def test_default_values(self): + config = TelemetryConfig() + assert config.enabled is False + assert config.otlp_endpoint == "" + assert config.service_name == "agentkit" + assert config.sample_rate == 1.0 + + def test_custom_values(self): + config = TelemetryConfig( + enabled=True, + otlp_endpoint="http://localhost:4317", + service_name="my-service", + sample_rate=0.5, + ) + assert config.enabled is True + assert config.otlp_endpoint == "http://localhost:4317" + assert config.service_name == "my-service" + assert config.sample_rate == 0.5 + + +# ── init_telemetry tests ──────────────────────────────────── + + +class TestInitTelemetry: + """init_telemetry initialization tests""" + + def test_disabled_returns_noop_tracer(self): + config = TelemetryConfig(enabled=False) + init_telemetry(config) + tracer = get_tracer() + assert isinstance(tracer, NoOpTracer) + + def test_missing_opentelemetry_falls_back_to_noop(self): + """Degrades gracefully to NoOpTracer when the opentelemetry package is not installed""" + config = TelemetryConfig(enabled=True, otlp_endpoint="http://localhost:4317") + # Must not raise even if opentelemetry is not installed + init_telemetry(config) + tracer = get_tracer() + # Should be NoOpTracer in an environment without opentelemetry + assert isinstance(tracer, (NoOpTracer, OTelTracer)) + + def test_init_with_exception_falls_back_to_noop(self): + """Falls back to NoOpTracer when an exception occurs during initialization""" + config = TelemetryConfig(enabled=True, otlp_endpoint="http://localhost:4317") + with patch( + "agentkit.telemetry.tracer.init_telemetry", + side_effect=Exception("init 
error"), + ) as mock_init: + # init_telemetry is patched to raise, but the function actually called is the original one + # What is verified here is that init_telemetry itself does not crash + pass + # Call directly; it should not crash + init_telemetry(config) + tracer = get_tracer() + assert isinstance(tracer, (NoOpTracer, OTelTracer)) + + +# ── get_tracer tests ──────────────────────────────────────── + + +class TestGetTracer: + """get_tracer global instance tests""" + + def test_returns_global_tracer(self): + # Make sure to reset to NoOpTracer first + init_telemetry(TelemetryConfig(enabled=False)) + tracer = get_tracer() + assert isinstance(tracer, NoOpTracer) + + def test_get_tracer_returns_same_instance(self): + init_telemetry(TelemetryConfig(enabled=False)) + tracer1 = get_tracer() + tracer2 = get_tracer() + assert tracer1 is tracer2 + + +# ── CostAwareRouter span tests ────────────────────────────── + + +class TestCostAwareRouterSpan: + """CostAwareRouter creates a span and sets attributes""" + + @pytest.mark.asyncio + async def test_router_creates_span_with_attributes(self): + """Routing creates a span and sets the route.layer and route.target attributes""" + init_telemetry(TelemetryConfig(enabled=False)) + tracer = get_tracer() + + # Replace start_span with a mock to verify the call + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.set_attribute = MagicMock() + + with patch.object(tracer, "start_span", return_value=mock_span): + router = CostAwareRouter() + result = await router.route( + content="你好", + skill_registry=MagicMock(), + intent_router=MagicMock(), + default_tools=[], + default_system_prompt="You are helpful.", + ) + + tracer.start_span.assert_called_once_with("router.route") + # Verify the span sets the input.length attribute + mock_span.set_attribute.assert_any_call("input.length", len("你好")) + # Verify the span sets route.layer and route.target + call_args_list = [ + (call.args[0], call.args[1]) + for call in mock_span.set_attribute.call_args_list + ] + assert ("route.layer", "greeting") in call_args_list + assert ("route.target", "default") in call_args_list + + +# ── AlignmentGuard span tests 
────────────────────────────── + + +class TestAlignmentGuardSpan: + """AlignmentGuard creates a span and sets attributes""" + + @pytest.mark.asyncio + async def test_guard_creates_span_with_attributes(self): + """The alignment check creates a span and sets the guard.passed and guard.checked_by attributes""" + init_telemetry(TelemetryConfig(enabled=False)) + tracer = get_tracer() + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.set_attribute = MagicMock() + + with patch.object(tracer, "start_span", return_value=mock_span): + config = AlignmentConfig(constraints=["forbidden_word"]) + guard = AlignmentGuard(config) + result = await guard.check_output( + {"content": "This contains forbidden_word"} + ) + + tracer.start_span.assert_called_once_with("guard.check") + call_args_list = [ + (call.args[0], call.args[1]) + for call in mock_span.set_attribute.call_args_list + ] + assert ("guard.constraints_count", 1) in call_args_list + assert ("guard.passed", False) in call_args_list + assert ("guard.checked_by", "rule") in call_args_list + + @pytest.mark.asyncio + async def test_guard_span_on_pass(self): + """When the alignment check passes, the span sets guard.passed=True""" + init_telemetry(TelemetryConfig(enabled=False)) + tracer = get_tracer() + + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_span.set_attribute = MagicMock() + + with patch.object(tracer, "start_span", return_value=mock_span): + config = AlignmentConfig(constraints=["forbidden_word"]) + guard = AlignmentGuard(config) + result = await guard.check_output( + {"content": "This is clean text"} + ) + + call_args_list = [ + (call.args[0], call.args[1]) + for call in mock_span.set_attribute.call_args_list + ] + assert ("guard.passed", True) in call_args_list + assert ("guard.checked_by", "rule") in call_args_list diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py new file mode 100644 index 
0000000..e69de29