feat(marketplace): add Phase B/C - CostAwareRouter, OrganizationContext, AlignmentGuard, Soul Evolution, Auction, Server Integration
Phase B: - U1: CostAwareRouter with 3-layer routing (rule/LLM/capability matching) - U6: OrganizationContext with agent profiles and capability-based discovery - U7: AlignmentGuard with constraint injection and cascade detection Phase C: - U8: Soul dynamic evolution with version tracking and reflection-triggered updates - U9: Auction mechanism as optional advanced routing mode - U10: Server integration + end-to-end integration tests 250 new tests passing across all units.
This commit is contained in:
parent
5b42487d8a
commit
8713636d50
|
|
@ -6,6 +6,7 @@ and prompt assembly into a single module used by both chat routes.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@ -42,6 +43,9 @@ class SkillRoutingResult:
|
|||
matched: bool = False
|
||||
match_method: str | None = None
|
||||
match_confidence: float = 0.0
|
||||
transparency_level: str = "SILENT"
|
||||
execution_trace: list[dict] = field(default_factory=list)
|
||||
complexity: float = 0.0
|
||||
|
||||
|
||||
def parse_skill_prefix(content: str) -> tuple[str | None, str]:
|
||||
|
|
@ -166,3 +170,323 @@ async def resolve_skill_routing(
|
|||
result.agent_name = default_agent_name
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CostAwareRouter - 三层成本感知路由
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GREETING_RE = re.compile(
|
||||
r"^(你好|hi|hello|hey|嗨|哈喽|早上好|下午好|晚上好|good morning|good afternoon|good evening)\s*[!!.。??]*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_CHAT_MODE_RE = re.compile(
|
||||
r"^(谢谢|感谢|thanks|thank you|ok|好的|嗯|对|是|不是|没关系|再见|bye|goodbye)\s*[!!.。??]*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_COMPLEXITY_CLASSIFY_PROMPT = (
|
||||
"Assess the complexity of the following user request on a scale of 0.0 to 1.0.\n"
|
||||
"0.0 = trivial greeting / simple chat\n"
|
||||
"0.3 = single-skill task (e.g. search, translate)\n"
|
||||
"0.7 = multi-step or cross-domain task (e.g. market research + competitor analysis)\n"
|
||||
"1.0 = highly complex, multi-agent collaboration needed\n\n"
|
||||
'User request: "{content}"\n\n'
|
||||
'Respond ONLY with a JSON object: {{"complexity": <float>}}'
|
||||
)
|
||||
|
||||
|
||||
class CostAwareRouter:
|
||||
"""三层成本感知路由器。
|
||||
|
||||
Layer 0: 规则匹配(零成本)— @skill: 前缀 / 问候 / 简单对话
|
||||
Layer 1: LLM 快速分类(~100 tokens)— 复杂度评估 + IntentRouter
|
||||
Layer 2: 能力匹配 / 拍卖(可选)— 高复杂度任务委派给最佳 Agent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_gateway: Any = None,
|
||||
model: str = "default",
|
||||
org_context: Any = None,
|
||||
auction_enabled: bool = False,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._model = model
|
||||
self._org_context = org_context
|
||||
self._auction_enabled = auction_enabled
|
||||
|
||||
# -- Layer 0: Rule-based (zero cost) ------------------------------------
|
||||
|
||||
def _match_layer0(self, content: str) -> tuple[str | None, str]:
|
||||
"""Layer 0 规则匹配。
|
||||
|
||||
Returns:
|
||||
(match_type, clean_content) — match_type 为 None 表示未命中。
|
||||
"""
|
||||
# @skill: 显式前缀
|
||||
explicit_skill, clean = parse_skill_prefix(content)
|
||||
if explicit_skill:
|
||||
return "explicit_skill", clean
|
||||
|
||||
# 问候模式
|
||||
stripped = content.strip()
|
||||
if _GREETING_RE.match(stripped):
|
||||
return "greeting", stripped
|
||||
|
||||
# 简单对话模式
|
||||
if _CHAT_MODE_RE.match(stripped):
|
||||
return "chat_mode", stripped
|
||||
|
||||
return None, stripped
|
||||
|
||||
# -- Layer 1: LLM quick classify (~100 tokens) -------------------------
|
||||
|
||||
async def quick_classify(self, content: str) -> float:
|
||||
"""使用 LLM 快速评估用户请求的复杂度 (0.0-1.0)。
|
||||
|
||||
当 LLM Gateway 不可用或解析失败时,返回默认中等复杂度 0.5。
|
||||
"""
|
||||
if self._llm_gateway is None:
|
||||
return 0.5
|
||||
|
||||
prompt = _COMPLEXITY_CLASSIFY_PROMPT.format(content=content)
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._model,
|
||||
)
|
||||
data = json.loads(response.content.strip())
|
||||
complexity = float(data.get("complexity", 0.5))
|
||||
return max(0.0, min(1.0, complexity))
|
||||
except Exception as e:
|
||||
logger.warning(f"CostAwareRouter quick_classify failed: {e}")
|
||||
return 0.5
|
||||
|
||||
# -- Layer 2: Capability matching / Auction (optional) -----------------
|
||||
|
||||
async def _route_layer2(
|
||||
self,
|
||||
content: str,
|
||||
skill_registry: Any,
|
||||
intent_router: Any,
|
||||
default_tools: list,
|
||||
default_system_prompt: str | None,
|
||||
default_model: str,
|
||||
default_agent_name: str,
|
||||
agent_tool_registry: Any = None,
|
||||
session_id: str = "",
|
||||
complexity: float = 0.0,
|
||||
trace: list[dict] | None = None,
|
||||
) -> SkillRoutingResult:
|
||||
"""Layer 2: 高复杂度任务通过 org_context.find_best_agent 路由。"""
|
||||
if self._org_context is not None and hasattr(self._org_context, "find_best_agent"):
|
||||
try:
|
||||
best_agent = await self._org_context.find_best_agent(content)
|
||||
if best_agent is not None:
|
||||
agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent))
|
||||
result = SkillRoutingResult(
|
||||
clean_content=content,
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
agent_name=agent_name,
|
||||
model=default_model,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
complexity=complexity,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "capability",
|
||||
"agent_name": agent_name,
|
||||
"complexity": complexity,
|
||||
})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"CostAwareRouter Layer 2 org_context.find_best_agent failed: {e}")
|
||||
|
||||
# Fallback: 使用 IntentRouter
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = complexity
|
||||
if trace is not None:
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "intent_router_fallback",
|
||||
"complexity": complexity,
|
||||
})
|
||||
return result
|
||||
|
||||
# -- Main entry point ---------------------------------------------------
|
||||
|
||||
async def route(
|
||||
self,
|
||||
content: str,
|
||||
skill_registry: Any,
|
||||
intent_router: Any,
|
||||
default_tools: list,
|
||||
default_system_prompt: str | None,
|
||||
default_model: str = "default",
|
||||
default_agent_name: str = "default",
|
||||
agent_tool_registry: Any = None,
|
||||
session_id: str = "",
|
||||
transparency: str = "SILENT",
|
||||
) -> SkillRoutingResult:
|
||||
"""三层成本感知路由主入口。
|
||||
|
||||
Args:
|
||||
content: 用户输入内容
|
||||
skill_registry: Skill 注册表
|
||||
intent_router: IntentRouter 实例
|
||||
default_tools: 默认工具列表
|
||||
default_system_prompt: 默认系统提示词
|
||||
default_model: 默认模型
|
||||
default_agent_name: 默认 Agent 名称
|
||||
agent_tool_registry: Agent 工具注册表
|
||||
session_id: 会话 ID
|
||||
transparency: 透明度级别 (SILENT / VERBOSE / TRACE)
|
||||
|
||||
Returns:
|
||||
SkillRoutingResult 包含路由结果和追踪信息
|
||||
"""
|
||||
trace: list[dict] = []
|
||||
|
||||
# ---- Layer 0: Rule-based (zero cost) ----
|
||||
match_type, clean_content = self._match_layer0(content)
|
||||
|
||||
if match_type == "explicit_skill":
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.match_method = result.match_method or "explicit_skill"
|
||||
result.complexity = 0.0
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": "explicit_skill",
|
||||
"matched": result.matched,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
if match_type in ("greeting", "chat_mode"):
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method=match_type,
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
"method": match_type,
|
||||
"matched": False,
|
||||
"cost": "zero",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
# ---- Layer 1: LLM quick classify (~100 tokens) ----
|
||||
complexity = await self.quick_classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "quick_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
|
||||
# Low complexity → default agent
|
||||
if complexity < 0.3:
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="low_complexity",
|
||||
match_confidence=1.0 - complexity,
|
||||
complexity=complexity,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "low_complexity",
|
||||
"complexity": complexity,
|
||||
"routed_to": "default",
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
# Medium complexity → IntentRouter via resolve_skill_routing
|
||||
if complexity <= 0.7:
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = complexity
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "intent_router",
|
||||
"complexity": complexity,
|
||||
"matched": result.matched,
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
||||
# ---- Layer 2: Capability matching / Auction (high complexity) ----
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "capability_or_auction",
|
||||
"complexity": complexity,
|
||||
"auction_enabled": self._auction_enabled,
|
||||
})
|
||||
result = await self._route_layer2(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
complexity=complexity,
|
||||
trace=trace,
|
||||
)
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from agentkit.evolution.prompt_optimizer import (
|
|||
)
|
||||
from agentkit.evolution.reflector import Reflection, Reflector, RuleBasedReflector
|
||||
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||
from agentkit.memory.profile import MemoryStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -77,6 +78,7 @@ class EvolutionMixin:
|
|||
self._evolution_log: list[EvolutionLogEntry] = []
|
||||
self._current_module: Module | None = None
|
||||
self._strategy_tuning_enabled = strategy_tuning_enabled
|
||||
self.pending_soul_updates: dict[str, list] = {}
|
||||
|
||||
@staticmethod
|
||||
def _create_reflector(
|
||||
|
|
@ -111,16 +113,22 @@ class EvolutionMixin:
|
|||
|
||||
return RuleBasedReflector()
|
||||
|
||||
async def evolve_after_task(self, task: TaskMessage, result: TaskResult) -> EvolutionLogEntry:
|
||||
async def evolve_after_task(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: TaskResult,
|
||||
memory_store: MemoryStore | None = None,
|
||||
) -> EvolutionLogEntry:
|
||||
"""任务完成后执行进化流程。
|
||||
|
||||
流程:
|
||||
1. Reflector 反思 → 得到 Reflection
|
||||
2. 如果 Reflection 有改进建议 → PromptOptimizer 优化
|
||||
3. 如果优化产生了新 Prompt → ABTester 验证
|
||||
4. 如果 AB 测试通过 → EvolutionStore 应用变更
|
||||
5. 如果 AB 测试失败 → 回滚
|
||||
6. 如果策略调优启用 → StrategyTuner 调优
|
||||
2. Soul 进化检查(如果 memory_store 可用)
|
||||
3. 如果 Reflection 有改进建议 → PromptOptimizer 优化
|
||||
4. 如果优化产生了新 Prompt → ABTester 验证
|
||||
5. 如果 AB 测试通过 → EvolutionStore 应用变更
|
||||
6. 如果 AB 测试失败 → 回滚
|
||||
7. 如果策略调优启用 → StrategyTuner 调优
|
||||
"""
|
||||
log_entry = EvolutionLogEntry(task_id=task.task_id)
|
||||
|
||||
|
|
@ -139,7 +147,11 @@ class EvolutionMixin:
|
|||
f"suggestions={len(reflection.suggestions)}"
|
||||
)
|
||||
|
||||
# Step 2: 如果有改进建议,触发 Prompt 优化
|
||||
# Step 2: Soul 进化检查
|
||||
if memory_store is not None:
|
||||
await self.evolve_soul(task, result, memory_store)
|
||||
|
||||
# Step 3: 如果有改进建议,触发 Prompt 优化
|
||||
if not reflection.suggestions:
|
||||
logger.debug("No improvement suggestions, skipping optimization")
|
||||
self._evolution_log.append(log_entry)
|
||||
|
|
@ -360,3 +372,68 @@ class EvolutionMixin:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to rollback evolution change: {e}")
|
||||
return False
|
||||
|
||||
async def evolve_soul(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: TaskResult,
|
||||
memory_store: MemoryStore | None = None,
|
||||
) -> bool:
|
||||
"""Check if soul should be updated based on accumulated reflections.
|
||||
|
||||
Conditions for soul update:
|
||||
- Same category reflection appears >= 3 times
|
||||
- Reflection quality_score < 0.5 (indicating consistent issues)
|
||||
- Reflection has actionable suggestions
|
||||
"""
|
||||
if memory_store is None:
|
||||
return False
|
||||
|
||||
if self._reflector is None:
|
||||
return False
|
||||
|
||||
reflection = await self._reflector.reflect(task, result)
|
||||
|
||||
# 只关注低质量且有建议的反思
|
||||
if reflection.quality_score >= 0.5:
|
||||
return False
|
||||
|
||||
if not reflection.suggestions:
|
||||
return False
|
||||
|
||||
# 按 pattern 分类累积反思
|
||||
for pattern in reflection.patterns:
|
||||
if pattern not in self.pending_soul_updates:
|
||||
self.pending_soul_updates[pattern] = []
|
||||
self.pending_soul_updates[pattern].append(reflection)
|
||||
|
||||
# 检查是否有同一类别累积 >= 3 次反思
|
||||
for category, reflections in self.pending_soul_updates.items():
|
||||
if len(reflections) >= 3:
|
||||
# 触发 soul 更新
|
||||
from agentkit.tools.memory_tool import MemoryTool
|
||||
|
||||
tool = MemoryTool(memory_store)
|
||||
# 使用第一个建议作为更新内容
|
||||
section = category
|
||||
content = "; ".join(reflections[-1].suggestions[:2])
|
||||
reason = f"连续{len(reflections)}次低质量反思 (category: {category})"
|
||||
|
||||
update_result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section=section,
|
||||
content=content,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if update_result.get("success"):
|
||||
logger.info(
|
||||
f"Soul evolved: category={category}, "
|
||||
f"version={update_result.get('version')}"
|
||||
)
|
||||
# 清除已处理的类别
|
||||
del self.pending_soul_updates[category]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
"""AgentKit Marketplace - 拍卖机制与财富追踪"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agentkit.marketplace.auction import AuctionHouse, AuctionResult, Bid
|
||||
from agentkit.marketplace.wealth import WealthTracker
|
||||
|
||||
__all__ = [
|
||||
"Bid",
|
||||
"AuctionResult",
|
||||
"AuctionHouse",
|
||||
"WealthTracker",
|
||||
]
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""AuctionHouse - 拍卖机制,基于竞价选择 Agent"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.marketplace.wealth import WealthTracker
|
||||
|
||||
|
||||
@dataclass
|
||||
class Bid:
|
||||
"""Agent 竞价信息"""
|
||||
|
||||
agent_name: str
|
||||
architecture: str # "react", "rewoo", "plan_exec", "reflexion", "direct"
|
||||
estimated_steps: int
|
||||
estimated_cost: float # estimated token cost
|
||||
confidence: float # 0.0-1.0 confidence in completing the task
|
||||
payment_offer: float # how much the agent "charges"
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuctionResult:
|
||||
"""拍卖结果"""
|
||||
|
||||
winner: Bid | None
|
||||
all_bids: list[Bid]
|
||||
selection_reason: str
|
||||
total_bidders: int
|
||||
|
||||
|
||||
class AuctionHouse:
|
||||
"""Auction-based agent selection mechanism.
|
||||
|
||||
Default disabled. Enable via marketplace.auction_enabled: true in config.
|
||||
When enabled, Layer 2 routing uses auction instead of capability matching.
|
||||
"""
|
||||
|
||||
def __init__(self, wealth_tracker: WealthTracker | None = None) -> None:
|
||||
self._wealth_tracker = wealth_tracker or WealthTracker()
|
||||
|
||||
async def run_auction(self, task_description: str, bidders: list[Bid]) -> AuctionResult:
|
||||
"""Run auction among bidders, select winner.
|
||||
|
||||
Scoring formula:
|
||||
score = (confidence / max(estimated_cost, 0.001)) * wealth_factor
|
||||
|
||||
wealth_factor = 1.0 + (wealth / 1000.0) # wealth bonus, diminishing returns
|
||||
"""
|
||||
if not bidders:
|
||||
return AuctionResult(
|
||||
winner=None,
|
||||
all_bids=[],
|
||||
selection_reason="No bidders participated",
|
||||
total_bidders=0,
|
||||
)
|
||||
|
||||
# Filter out bankrupt agents
|
||||
eligible = [
|
||||
b for b in bidders
|
||||
if not self._wealth_tracker.is_bankrupt(b.agent_name)
|
||||
]
|
||||
|
||||
if not eligible:
|
||||
return AuctionResult(
|
||||
winner=None,
|
||||
all_bids=bidders,
|
||||
selection_reason="All bidders are bankrupt",
|
||||
total_bidders=len(bidders),
|
||||
)
|
||||
|
||||
# Score each bid
|
||||
scored: list[tuple[Bid, float]] = []
|
||||
for bid in eligible:
|
||||
score = self.score_bid(bid)
|
||||
scored.append((bid, score))
|
||||
|
||||
# Select highest score
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
winner, winner_score = scored[0]
|
||||
|
||||
return AuctionResult(
|
||||
winner=winner,
|
||||
all_bids=bidders,
|
||||
selection_reason=(
|
||||
f"Agent '{winner.agent_name}' won with score {winner_score:.4f} "
|
||||
f"(confidence={winner.confidence}, cost={winner.estimated_cost}, "
|
||||
f"wealth_factor={self._wealth_tracker.get_wealth_factor(winner.agent_name):.4f})"
|
||||
),
|
||||
total_bidders=len(bidders),
|
||||
)
|
||||
|
||||
def score_bid(self, bid: Bid) -> float:
|
||||
"""Calculate bid score without running full auction"""
|
||||
wealth_factor = self._wealth_tracker.get_wealth_factor(bid.agent_name)
|
||||
score = (bid.confidence / max(bid.estimated_cost, 0.001)) * wealth_factor
|
||||
return score
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""WealthTracker - Agent 财富追踪,用于拍卖机制"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class WealthTracker:
|
||||
"""Track agent wealth for auction mechanism.
|
||||
|
||||
Agents earn wealth by completing tasks successfully.
|
||||
Agents lose wealth by failing tasks.
|
||||
Bankrupt agents (wealth <= -100) are excluded from auctions.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_wealth: float = 100.0) -> None:
|
||||
self._balances: dict[str, float] = {}
|
||||
self._initial_wealth = initial_wealth
|
||||
|
||||
def get_wealth(self, agent_name: str) -> float:
|
||||
"""Get agent's current wealth, defaulting to initial_wealth"""
|
||||
return self._balances.get(agent_name, self._initial_wealth)
|
||||
|
||||
def reward(self, agent_name: str, amount: float) -> None:
|
||||
"""Reward agent for successful task completion"""
|
||||
current = self.get_wealth(agent_name)
|
||||
self._balances[agent_name] = current + amount
|
||||
|
||||
def penalize(self, agent_name: str, amount: float) -> None:
|
||||
"""Penalize agent for task failure"""
|
||||
current = self.get_wealth(agent_name)
|
||||
self._balances[agent_name] = current - amount
|
||||
|
||||
def is_bankrupt(self, agent_name: str) -> bool:
|
||||
"""Check if agent is bankrupt (wealth <= -100)"""
|
||||
return self.get_wealth(agent_name) <= -100
|
||||
|
||||
def reset(self, agent_name: str) -> None:
|
||||
"""Reset agent's wealth to initial value"""
|
||||
self._balances[agent_name] = self._initial_wealth
|
||||
|
||||
def get_rankings(self) -> list[tuple[str, float]]:
|
||||
"""Get wealth rankings sorted by wealth descending"""
|
||||
all_agents = [
|
||||
(name, wealth) for name, wealth in self._balances.items()
|
||||
]
|
||||
all_agents.sort(key=lambda x: x[1], reverse=True)
|
||||
return all_agents
|
||||
|
||||
def get_wealth_factor(self, agent_name: str) -> float:
|
||||
"""Get wealth factor for scoring: 1.0 + (wealth / 1000.0)"""
|
||||
return 1.0 + (self.get_wealth(agent_name) / 1000.0)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""OrganizationContext - 组织上下文与 Agent 发现"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||||
from agentkit.org.discovery import AgentDiscovery
|
||||
|
||||
__all__ = [
|
||||
"AgentProfile",
|
||||
"OrganizationContext",
|
||||
"AgentDiscovery",
|
||||
]
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""OrganizationContext - 组织上下文,管理 AgentProfile 与能力矩阵"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentProfile:
|
||||
"""Agent 档案 - 描述组织中一个 Agent 的能力与状态"""
|
||||
|
||||
name: str
|
||||
agent_type: str # "react", "rewoo", "plan_exec", "reflexion", "direct"
|
||||
capabilities: list[str] # capability tag strings
|
||||
skills: list[str] # skill names
|
||||
current_load: int = 0 # number of active tasks
|
||||
max_concurrency: int = 1
|
||||
availability: bool = True
|
||||
specializations: list[str] = field(default_factory=list)
|
||||
model: str = "default"
|
||||
execution_mode: str = "react"
|
||||
|
||||
|
||||
class OrganizationContext:
|
||||
"""组织上下文 - 管理 Agent 档案与能力矩阵,支持基于能力的 Agent 发现"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._agents: dict[str, AgentProfile] = {}
|
||||
self._capability_matrix: dict[str, list[str]] = {} # capability -> [agent_names]
|
||||
|
||||
def register_agent(self, profile: AgentProfile) -> None:
|
||||
"""注册 Agent 档案"""
|
||||
self._agents[profile.name] = profile
|
||||
# 更新能力矩阵
|
||||
for cap in profile.capabilities:
|
||||
cap_lower = cap.lower()
|
||||
if cap_lower not in self._capability_matrix:
|
||||
self._capability_matrix[cap_lower] = []
|
||||
if profile.name not in self._capability_matrix[cap_lower]:
|
||||
self._capability_matrix[cap_lower].append(profile.name)
|
||||
logger.info(f"Agent profile '{profile.name}' registered")
|
||||
|
||||
def unregister_agent(self, name: str) -> None:
|
||||
"""注销 Agent 档案"""
|
||||
profile = self._agents.pop(name, None)
|
||||
if profile is None:
|
||||
return
|
||||
# 清理能力矩阵
|
||||
for cap in profile.capabilities:
|
||||
cap_lower = cap.lower()
|
||||
if cap_lower in self._capability_matrix:
|
||||
self._capability_matrix[cap_lower] = [
|
||||
n for n in self._capability_matrix[cap_lower] if n != name
|
||||
]
|
||||
if not self._capability_matrix[cap_lower]:
|
||||
del self._capability_matrix[cap_lower]
|
||||
logger.info(f"Agent profile '{name}' unregistered")
|
||||
|
||||
def get_agent_profile(self, name: str) -> AgentProfile | None:
|
||||
"""获取 Agent 档案"""
|
||||
return self._agents.get(name)
|
||||
|
||||
def list_agents(self) -> list[AgentProfile]:
|
||||
"""列出所有 Agent 档案"""
|
||||
return list(self._agents.values())
|
||||
|
||||
def find_best_agent(
|
||||
self,
|
||||
required_capabilities: list[str],
|
||||
exclude: list[str] | None = None,
|
||||
) -> AgentProfile | None:
|
||||
"""根据能力需求找到最佳 Agent
|
||||
|
||||
逻辑:
|
||||
1. 找到拥有所有所需能力的 Agent
|
||||
2. 在匹配的 Agent 中,优先选择 current_load 较低的
|
||||
3. 排除 exclude 列表中的 Agent
|
||||
4. 排除不可用的 Agent
|
||||
5. 没有匹配则返回 None
|
||||
"""
|
||||
exclude_set = set(exclude or [])
|
||||
|
||||
# 对每个所需能力,查找拥有该能力的 Agent 名称集合
|
||||
candidate_names: set[str] | None = None
|
||||
for cap in required_capabilities:
|
||||
cap_lower = cap.lower()
|
||||
agents_with_cap = set(self._capability_matrix.get(cap_lower, []))
|
||||
if candidate_names is None:
|
||||
candidate_names = agents_with_cap
|
||||
else:
|
||||
candidate_names &= agents_with_cap
|
||||
|
||||
if not candidate_names:
|
||||
return None
|
||||
|
||||
# 过滤排除和不可用的 Agent,按 load 排序
|
||||
candidates = [
|
||||
self._agents[name]
|
||||
for name in candidate_names
|
||||
if name not in exclude_set
|
||||
and name in self._agents
|
||||
and self._agents[name].availability
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
candidates.sort(key=lambda p: p.current_load)
|
||||
return candidates[0]
|
||||
|
||||
def update_load(self, name: str, delta: int) -> None:
|
||||
"""更新 Agent 负载"""
|
||||
profile = self._agents.get(name)
|
||||
if profile is not None:
|
||||
profile.current_load = max(0, profile.current_load + delta)
|
||||
|
||||
def set_availability(self, name: str, available: bool) -> None:
|
||||
"""设置 Agent 可用性"""
|
||||
profile = self._agents.get(name)
|
||||
if profile is not None:
|
||||
profile.availability = available
|
||||
|
||||
@classmethod
|
||||
def from_agent_pool(cls, agent_pool, skill_registry) -> OrganizationContext:
|
||||
"""从 AgentPool 和 SkillRegistry 构建 OrganizationContext
|
||||
|
||||
Args:
|
||||
agent_pool: AgentPool 实例,提供运行时 Agent 列表
|
||||
skill_registry: SkillRegistry 实例,提供 Skill 配置查询
|
||||
"""
|
||||
ctx = cls()
|
||||
|
||||
if agent_pool is None or skill_registry is None:
|
||||
return ctx
|
||||
|
||||
for agent_info in agent_pool.list_agents():
|
||||
agent_name = agent_info["name"]
|
||||
agent_type = agent_info.get("agent_type", "react")
|
||||
|
||||
# 尝试从 skill_registry 获取 SkillConfig
|
||||
capabilities: list[str] = []
|
||||
skills: list[str] = []
|
||||
execution_mode = "react"
|
||||
model = "default"
|
||||
max_concurrency = 1
|
||||
|
||||
try:
|
||||
skill = skill_registry.get(agent_name)
|
||||
config = skill.config
|
||||
capabilities = [cap.tag for cap in config.capabilities]
|
||||
execution_mode = config.execution_mode
|
||||
model = config.llm.get("model", "default") if config.llm else "default"
|
||||
max_concurrency = config.max_concurrency
|
||||
skills = [agent_name]
|
||||
except Exception:
|
||||
# Agent 不在 skill_registry 中,使用默认值
|
||||
skills = [agent_name]
|
||||
|
||||
profile = AgentProfile(
|
||||
name=agent_name,
|
||||
agent_type=agent_type,
|
||||
capabilities=capabilities,
|
||||
skills=skills,
|
||||
execution_mode=execution_mode,
|
||||
model=model,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
ctx.register_agent(profile)
|
||||
|
||||
return ctx
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""AgentDiscovery - 基于 OrganizationContext 的 Agent 发现与推荐"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentDiscovery:
|
||||
"""Agent 发现 - 提供多种维度的 Agent 查询与推荐"""
|
||||
|
||||
def __init__(self, org_context: OrganizationContext) -> None:
|
||||
self._org = org_context
|
||||
|
||||
def discover_by_capability(self, required_capabilities: list[str]) -> list[AgentProfile]:
|
||||
"""按能力标签发现 Agent(需满足所有指定能力)"""
|
||||
result: list[AgentProfile] = []
|
||||
for profile in self._org.list_agents():
|
||||
profile_caps_lower = {c.lower() for c in profile.capabilities}
|
||||
if all(cap.lower() in profile_caps_lower for cap in required_capabilities):
|
||||
result.append(profile)
|
||||
return result
|
||||
|
||||
def discover_by_execution_mode(self, mode: str) -> list[AgentProfile]:
|
||||
"""按执行模式发现 Agent"""
|
||||
return [
|
||||
p for p in self._org.list_agents()
|
||||
if p.execution_mode == mode
|
||||
]
|
||||
|
||||
def discover_available(self) -> list[AgentProfile]:
|
||||
"""发现所有可用的 Agent"""
|
||||
return [p for p in self._org.list_agents() if p.availability]
|
||||
|
||||
def recommend_agent(
|
||||
self,
|
||||
required_capabilities: list[str],
|
||||
preferred_mode: str | None = None,
|
||||
) -> AgentProfile | None:
|
||||
"""推荐最佳 Agent
|
||||
|
||||
逻辑:
|
||||
1. 如果指定了 preferred_mode,先按 execution_mode 过滤
|
||||
2. 然后按能力匹配 + 负载均衡找到最佳 Agent
|
||||
3. 如果没有能力匹配的,回退到任何可用 Agent
|
||||
"""
|
||||
# 按能力发现候选
|
||||
candidates = self.discover_by_capability(required_capabilities)
|
||||
|
||||
# 过滤不可用
|
||||
candidates = [c for c in candidates if c.availability]
|
||||
|
||||
# 如果指定了 preferred_mode,优先匹配
|
||||
if preferred_mode is not None:
|
||||
mode_matched = [c for c in candidates if c.execution_mode == preferred_mode]
|
||||
if mode_matched:
|
||||
mode_matched.sort(key=lambda p: p.current_load)
|
||||
return mode_matched[0]
|
||||
|
||||
# 按负载排序返回最佳
|
||||
if candidates:
|
||||
candidates.sort(key=lambda p: p.current_load)
|
||||
return candidates[0]
|
||||
|
||||
# 回退:返回任何可用 Agent
|
||||
available = self.discover_available()
|
||||
if available:
|
||||
available.sort(key=lambda p: p.current_load)
|
||||
return available[0]
|
||||
|
||||
return None
|
||||
|
|
@ -1,5 +1,13 @@
|
|||
"""Quality Gate & Output Standardizer"""
|
||||
|
||||
from agentkit.quality.alignment import (
|
||||
AlignmentCheckResult,
|
||||
AlignmentConfig,
|
||||
AlignmentGuard,
|
||||
CascadeAlert,
|
||||
ConstraintInjector,
|
||||
)
|
||||
from agentkit.quality.cascade_detector import CascadeDetector
|
||||
from agentkit.quality.gate import QualityCheck, QualityGate, QualityResult
|
||||
from agentkit.quality.output import OutputMetadata, OutputStandardizer, StandardOutput
|
||||
|
||||
|
|
@ -10,4 +18,10 @@ __all__ = [
|
|||
"OutputStandardizer",
|
||||
"StandardOutput",
|
||||
"OutputMetadata",
|
||||
"AlignmentConfig",
|
||||
"AlignmentGuard",
|
||||
"AlignmentCheckResult",
|
||||
"CascadeAlert",
|
||||
"ConstraintInjector",
|
||||
"CascadeDetector",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,206 @@
|
|||
"""AlignmentGuard - 对齐守卫:约束注入 + 级联故障检测"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlignmentConfig:
|
||||
"""对齐守卫配置"""
|
||||
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
cascade_max_interactions: int = 10
|
||||
cascade_max_depth: int = 3
|
||||
audit_enabled: bool = False
|
||||
audit_model: str = "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlignmentCheckResult:
|
||||
"""对齐检查结果"""
|
||||
|
||||
passed: bool
|
||||
violations: list[str] = field(default_factory=list)
|
||||
checked_by: str = "" # "rule" or "llm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CascadeAlert:
|
||||
"""级联故障告警"""
|
||||
|
||||
session_id: str
|
||||
alert_type: str # "interaction_limit" or "loop_depth"
|
||||
current_value: int
|
||||
threshold: int
|
||||
message: str
|
||||
|
||||
|
||||
class ConstraintInjector:
|
||||
"""将全局约束注入到任务 input_data 中"""
|
||||
|
||||
def __init__(self, config: AlignmentConfig):
|
||||
self._config = config
|
||||
|
||||
def inject(self, input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""注入约束指令到 input_data
|
||||
|
||||
在 input_data 中添加 'alignment_constraints' 键,值为约束列表。
|
||||
不修改原始 dict,返回新 dict。
|
||||
"""
|
||||
result = {**input_data, "alignment_constraints": list(self._config.constraints)}
|
||||
return result
|
||||
|
||||
|
||||
class AlignmentGuard:
|
||||
"""对齐守卫 — 扩展 QualityGate,增加约束注入和级联检测"""
|
||||
|
||||
def __init__(self, config: AlignmentConfig, llm_gateway=None):
|
||||
self._config = config
|
||||
self._injector = ConstraintInjector(config)
|
||||
self._llm_gateway = llm_gateway
|
||||
self._interaction_counts: dict[str, int] = {}
|
||||
self._loop_depths: dict[str, int] = {}
|
||||
|
||||
def inject_constraints(self, input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""委托给 ConstraintInjector"""
|
||||
return self._injector.inject(input_data)
|
||||
|
||||
async def check_output(
|
||||
self,
|
||||
output: dict[str, Any],
|
||||
constraints: list[str] | None = None,
|
||||
) -> AlignmentCheckResult:
|
||||
"""检查输出是否符合约束
|
||||
|
||||
- 系统级约束:基于规则的检查(关键词 + 正则匹配)
|
||||
- 组织级约束:LLM 语义检查(仅当 audit_enabled=True)
|
||||
"""
|
||||
effective_constraints = constraints if constraints is not None else self._config.constraints
|
||||
if not effective_constraints:
|
||||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
|
||||
# 1. 基于规则的检查:关键词/子串匹配
|
||||
violations = self._rule_check(output, effective_constraints)
|
||||
if violations:
|
||||
return AlignmentCheckResult(
|
||||
passed=False,
|
||||
violations=violations,
|
||||
checked_by="rule",
|
||||
)
|
||||
|
||||
# 2. LLM 语义检查(仅当 audit_enabled=True 且有 llm_gateway)
|
||||
if self._config.audit_enabled and self._llm_gateway is not None:
|
||||
return await self._llm_check(output, effective_constraints)
|
||||
|
||||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
|
||||
def _rule_check(
|
||||
self, output: dict[str, Any], constraints: list[str]
|
||||
) -> list[str]:
|
||||
"""基于规则的约束检查:将 output 内容拼接后做关键词/子串匹配"""
|
||||
content = self._extract_text(output)
|
||||
violations: list[str] = []
|
||||
for constraint in constraints:
|
||||
# 简单子串匹配:约束关键词出现在输出中即视为违规
|
||||
if constraint.lower() in content.lower():
|
||||
violations.append(constraint)
|
||||
return violations
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(output: dict[str, Any]) -> str:
|
||||
"""从 output dict 中提取所有文本内容"""
|
||||
parts: list[str] = []
|
||||
for value in output.values():
|
||||
if isinstance(value, str):
|
||||
parts.append(value)
|
||||
else:
|
||||
parts.append(str(value))
|
||||
return " ".join(parts)
|
||||
|
||||
async def _llm_check(
|
||||
self, output: dict[str, Any], constraints: list[str]
|
||||
) -> AlignmentCheckResult:
|
||||
"""LLM 语义检查"""
|
||||
content = self._extract_text(output)
|
||||
constraint_text = "\n".join(f"- {c}" for c in constraints)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an alignment auditor. Check if the following output "
|
||||
"violates any of the listed constraints. "
|
||||
"Reply with 'PASS' if no violations, or list the violated constraints."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Constraints:\n{constraint_text}\n\nOutput:\n{content}"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=messages, model=self._config.audit_model
|
||||
)
|
||||
reply = response.content.strip()
|
||||
if reply.upper().startswith("PASS"):
|
||||
return AlignmentCheckResult(passed=True, checked_by="llm")
|
||||
else:
|
||||
return AlignmentCheckResult(
|
||||
passed=False,
|
||||
violations=[reply],
|
||||
checked_by="llm",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM audit failed: {e}")
|
||||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||||
|
||||
def record_interaction(self, session_id: str) -> CascadeAlert | None:
|
||||
"""记录一次 agent 间交互,超过阈值返回 CascadeAlert"""
|
||||
self._interaction_counts[session_id] = (
|
||||
self._interaction_counts.get(session_id, 0) + 1
|
||||
)
|
||||
count = self._interaction_counts[session_id]
|
||||
if count > self._config.cascade_max_interactions:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
alert_type="interaction_limit",
|
||||
current_value=count,
|
||||
threshold=self._config.cascade_max_interactions,
|
||||
message=(
|
||||
f"Session {session_id} exceeded max interactions: "
|
||||
f"{count} > {self._config.cascade_max_interactions}"
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def record_loop_depth(self, session_id: str, depth: int) -> CascadeAlert | None:
|
||||
"""记录循环深度,超过阈值返回 CascadeAlert"""
|
||||
self._loop_depths[session_id] = depth
|
||||
if depth > self._config.cascade_max_depth:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
alert_type="loop_depth",
|
||||
current_value=depth,
|
||||
threshold=self._config.cascade_max_depth,
|
||||
message=(
|
||||
f"Session {session_id} exceeded max loop depth: "
|
||||
f"{depth} > {self._config.cascade_max_depth}"
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def reset_session(self, session_id: str) -> None:
|
||||
"""重置某个 session 的交互计数"""
|
||||
self._interaction_counts.pop(session_id, None)
|
||||
self._loop_depths.pop(session_id, None)
|
||||
|
||||
def get_interaction_count(self, session_id: str) -> int:
|
||||
"""获取某个 session 的当前交互计数"""
|
||||
return self._interaction_counts.get(session_id, 0)
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""CascadeDetector - 独立的级联故障检测工具"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CascadeAlert:
|
||||
"""级联故障告警"""
|
||||
|
||||
session_id: str
|
||||
alert_type: str # "interaction_limit" or "loop_depth"
|
||||
current_value: int
|
||||
threshold: int
|
||||
message: str
|
||||
|
||||
|
||||
class CascadeDetector:
|
||||
"""检测多 agent 交互中的级联故障"""
|
||||
|
||||
def __init__(self, max_interactions: int = 10, max_depth: int = 3):
|
||||
self._max_interactions = max_interactions
|
||||
self._max_depth = max_depth
|
||||
self._interaction_counts: dict[str, int] = {}
|
||||
self._loop_depths: dict[str, int] = {}
|
||||
|
||||
def check_interaction(self, session_id: str) -> CascadeAlert | None:
|
||||
"""递增并检查交互计数"""
|
||||
self._interaction_counts[session_id] = (
|
||||
self._interaction_counts.get(session_id, 0) + 1
|
||||
)
|
||||
count = self._interaction_counts[session_id]
|
||||
if count > self._max_interactions:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
alert_type="interaction_limit",
|
||||
current_value=count,
|
||||
threshold=self._max_interactions,
|
||||
message=(
|
||||
f"Session {session_id} exceeded max interactions: "
|
||||
f"{count} > {self._max_interactions}"
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def check_depth(self, session_id: str, depth: int) -> CascadeAlert | None:
|
||||
"""检查循环深度"""
|
||||
self._loop_depths[session_id] = depth
|
||||
if depth > self._max_depth:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
alert_type="loop_depth",
|
||||
current_value=depth,
|
||||
threshold=self._max_depth,
|
||||
message=(
|
||||
f"Session {session_id} exceeded max loop depth: "
|
||||
f"{depth} > {self._max_depth}"
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(self, session_id: str) -> None:
|
||||
"""重置某个 session 的计数器"""
|
||||
self._interaction_counts.pop(session_id, None)
|
||||
self._loop_depths.pop(session_id, None)
|
||||
|
||||
def get_stats(self, session_id: str) -> dict[str, int]:
|
||||
"""获取某个 session 的当前统计"""
|
||||
return {
|
||||
"interactions": self._interaction_counts.get(session_id, 0),
|
||||
"depth": self._loop_depths.get(session_id, 0),
|
||||
}
|
||||
|
|
@ -438,6 +438,35 @@ def create_app(
|
|||
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
|
||||
app.state.quality_gate = QualityGate()
|
||||
app.state.output_standardizer = OutputStandardizer()
|
||||
|
||||
# Initialize OrganizationContext from AgentPool + SkillRegistry
|
||||
from agentkit.org.context import OrganizationContext
|
||||
org_context = OrganizationContext.from_agent_pool(
|
||||
agent_pool=app.state.agent_pool,
|
||||
skill_registry=app.state.skill_registry,
|
||||
)
|
||||
app.state.org_context = org_context
|
||||
|
||||
# Initialize AlignmentGuard from config
|
||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
||||
alignment_config_data = {}
|
||||
if server_config and hasattr(server_config, "alignment") and server_config.alignment:
|
||||
alignment_config_data = server_config.alignment
|
||||
alignment_config = AlignmentConfig(**alignment_config_data)
|
||||
alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway)
|
||||
app.state.alignment_guard = alignment_guard
|
||||
|
||||
# Initialize CostAwareRouter
|
||||
from agentkit.chat.skill_routing import CostAwareRouter
|
||||
auction_enabled = False
|
||||
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
|
||||
auction_enabled = server_config.marketplace.get("auction_enabled", False)
|
||||
cost_aware_router = CostAwareRouter(
|
||||
llm_gateway=app.state.llm_gateway,
|
||||
org_context=org_context,
|
||||
auction_enabled=auction_enabled,
|
||||
)
|
||||
app.state.cost_aware_router = cost_aware_router
|
||||
# Initialize task store from config
|
||||
ts_config = server_config.task_store if server_config else {}
|
||||
# Merge CLI overrides from AGENTKIT_TASK_STORE env var
|
||||
|
|
|
|||
|
|
@ -108,6 +108,8 @@ class ServerConfig:
|
|||
compression: dict[str, Any] | None = None,
|
||||
session: dict[str, Any] | None = None,
|
||||
bus: dict[str, Any] | None = None,
|
||||
marketplace: dict[str, Any] | None = None,
|
||||
alignment: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -128,6 +130,8 @@ class ServerConfig:
|
|||
self.compression = compression or {}
|
||||
self.session = session or {}
|
||||
self.bus = bus or {}
|
||||
self.marketplace = marketplace or {}
|
||||
self.alignment = alignment or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -186,6 +190,12 @@ class ServerConfig:
|
|||
# Session config
|
||||
session_data = data.get("session", {})
|
||||
|
||||
# Marketplace config
|
||||
marketplace_data = data.get("marketplace", {})
|
||||
|
||||
# Alignment config
|
||||
alignment_data = data.get("alignment", {})
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
port=server.get("port", 8001),
|
||||
|
|
@ -205,6 +215,8 @@ class ServerConfig:
|
|||
compression=compression_data,
|
||||
session=session_data,
|
||||
bus=server.get("bus"),
|
||||
marketplace=marketplace_data,
|
||||
alignment=alignment_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -397,6 +409,8 @@ class ServerConfig:
|
|||
self.telemetry = new_config.telemetry
|
||||
self.compression = new_config.compression
|
||||
self.session = new_config.session
|
||||
self.marketplace = new_config.marketplace
|
||||
self.alignment = new_config.alignment
|
||||
self._last_mtime = new_config._last_mtime
|
||||
|
||||
logger.info(f"Config reloaded from {path}")
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class SkillConfig(AgentConfig):
|
|||
# v4 新增字段:依赖声明、能力标签
|
||||
dependencies: list[dict[str, Any] | DependencyDecl] | None = None,
|
||||
capabilities: list[str | dict[str, Any] | CapabilityTag] | None = None,
|
||||
# v5 新增字段:对齐守卫
|
||||
alignment: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
|
|
@ -111,6 +113,9 @@ class SkillConfig(AgentConfig):
|
|||
# v4: 解析依赖和能力标签
|
||||
self.dependencies = self._parse_dependencies(dependencies or [])
|
||||
self.capabilities = self._parse_capabilities(capabilities or [])
|
||||
# v5: 对齐守卫配置
|
||||
from agentkit.quality.alignment import AlignmentConfig
|
||||
self.alignment = AlignmentConfig(**(alignment or {}))
|
||||
self._validate_v2()
|
||||
|
||||
def _validate_v2(self) -> None:
|
||||
|
|
@ -184,6 +189,7 @@ class SkillConfig(AgentConfig):
|
|||
disclosure_level=data.get("disclosure_level", 0),
|
||||
dependencies=data.get("dependencies"),
|
||||
capabilities=data.get("capabilities"),
|
||||
alignment=data.get("alignment"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -244,6 +250,14 @@ class SkillConfig(AgentConfig):
|
|||
{"tag": cap.tag, "description": cap.description}
|
||||
for cap in self.capabilities
|
||||
]
|
||||
# v5: 对齐守卫
|
||||
d["alignment"] = {
|
||||
"constraints": self.alignment.constraints,
|
||||
"cascade_max_interactions": self.alignment.cascade_max_interactions,
|
||||
"cascade_max_depth": self.alignment.cascade_max_depth,
|
||||
"audit_enabled": self.alignment.audit_enabled,
|
||||
"audit_model": self.alignment.audit_model,
|
||||
}
|
||||
return d
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,20 +5,23 @@
|
|||
- replace: 替换 section 内的文本
|
||||
- remove: 删除整个 section
|
||||
- read: 读取文件内容
|
||||
- update_soul: 动态更新 SOUL 文件(带版本追踪)
|
||||
|
||||
file 参数: soul | user | memory | daily
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.profile import MemoryStore
|
||||
from agentkit.memory.profile import MemoryFile, MemoryStore
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
VALID_FILES = {"soul", "user", "memory", "daily"}
|
||||
VALID_ACTIONS = {"add", "replace", "remove", "read"}
|
||||
VALID_ACTIONS = {"add", "replace", "remove", "read", "update_soul"}
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
|
|
@ -37,7 +40,7 @@ class MemoryTool(Tool):
|
|||
"action": {
|
||||
"type": "string",
|
||||
"enum": list(VALID_ACTIONS),
|
||||
"description": "Operation: add, replace, remove, read",
|
||||
"description": "Operation: add, replace, remove, read, update_soul",
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
|
|
@ -60,6 +63,10 @@ class MemoryTool(Tool):
|
|||
"type": "string",
|
||||
"description": "Replacement text for replace action",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Reason for update_soul action (stored in version history)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "file"],
|
||||
},
|
||||
|
|
@ -111,7 +118,68 @@ class MemoryTool(Tool):
|
|||
mf.remove_section(section)
|
||||
return {"success": True, "message": f"Removed {file_key}/{section}"}
|
||||
|
||||
elif action == "update_soul":
|
||||
section = kwargs.get("section", "")
|
||||
content = kwargs.get("content", "")
|
||||
reason = kwargs.get("reason", "")
|
||||
if not section:
|
||||
return {"success": False, "error": "section is required for update_soul action"}
|
||||
if not content:
|
||||
return {"success": False, "error": "content is required for update_soul action"}
|
||||
return await self._update_soul(mf, section, content, reason)
|
||||
|
||||
return {"success": False, "error": f"Unhandled action: {action}"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _update_soul(
|
||||
self, mf: MemoryFile, section: str, content: str, reason: str
|
||||
) -> dict[str, Any]:
|
||||
"""执行 SOUL 动态更新,带版本追踪和更新历史."""
|
||||
# 解析当前版本号
|
||||
version = 1
|
||||
version_content = mf.read_section("版本")
|
||||
if version_content:
|
||||
match = re.search(r"版本:\s*(\d+)", version_content)
|
||||
if match:
|
||||
version = int(match.group(1))
|
||||
|
||||
new_version = version + 1
|
||||
now = datetime.now(timezone.utc)
|
||||
timestamp = now.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
date_str = now.strftime("%Y-%m-%d")
|
||||
|
||||
# 更新目标 section
|
||||
if section in mf.list_sections():
|
||||
mf.remove_section(section)
|
||||
mf.add_section(section, content)
|
||||
|
||||
# 更新版本 section
|
||||
version_text = f"版本: {new_version}\n更新时间: {timestamp}"
|
||||
if "版本" in mf.list_sections():
|
||||
mf.remove_section("版本")
|
||||
mf.add_section("版本", version_text)
|
||||
|
||||
# 更新更新历史 section
|
||||
history_entry = f"- v{new_version} ({date_str}): 更新了{section}" + (f" - {reason}" if reason else "")
|
||||
|
||||
history_lines: list[str] = []
|
||||
history_content = mf.read_section("更新历史")
|
||||
if history_content:
|
||||
history_lines = [line for line in history_content.strip().split("\n") if line.strip()]
|
||||
|
||||
history_lines.append(history_entry)
|
||||
# 最多保留 10 条
|
||||
if len(history_lines) > 10:
|
||||
history_lines = history_lines[-10:]
|
||||
|
||||
if "更新历史" in mf.list_sections():
|
||||
mf.remove_section("更新历史")
|
||||
mf.add_section("更新历史", "\n".join(history_lines))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Updated soul/{section} to v{new_version}",
|
||||
"version": new_version,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,583 @@
|
|||
"""Marketplace E2E 集成测试 - 多 Agent 市场架构端到端流程"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.org.context import OrganizationContext, AgentProfile
|
||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig, CascadeAlert, ConstraintInjector
|
||||
from agentkit.marketplace.auction import AuctionHouse, Bid, AuctionResult
|
||||
from agentkit.marketplace.wealth import WealthTracker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_gateway():
|
||||
"""Mock LLMGateway for CostAwareRouter Layer 1 classification."""
|
||||
gw = AsyncMock()
|
||||
response = MagicMock()
|
||||
response.content = '{"complexity": 0.5}'
|
||||
gw.chat = AsyncMock(return_value=response)
|
||||
return gw
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_skill_registry():
|
||||
"""Mock SkillRegistry with no skills by default."""
|
||||
registry = MagicMock()
|
||||
registry.list_skills.return_value = []
|
||||
registry.get.side_effect = KeyError("not found")
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_intent_router():
|
||||
"""Mock IntentRouter that returns no match by default."""
|
||||
router = AsyncMock()
|
||||
router.route = AsyncMock(return_value=None)
|
||||
return router
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Simple chat routes to default agent (Layer 0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSimpleChatRoutesToDefault:
|
||||
"""简单对话走 Layer 0 规则匹配,路由到默认 Agent"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_routes_to_default(self, mock_skill_registry, mock_intent_router):
|
||||
router = CostAwareRouter(llm_gateway=None, org_context=None)
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
assert result.agent_name == "default"
|
||||
assert result.complexity == 0.0
|
||||
assert result.matched is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_mode_routes_to_default(self, mock_skill_registry, mock_intent_router):
|
||||
router = CostAwareRouter(llm_gateway=None, org_context=None)
|
||||
result = await router.route(
|
||||
content="谢谢",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
assert result.match_method == "chat_mode"
|
||||
assert result.agent_name == "default"
|
||||
assert result.complexity == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Complex task routes via capability matching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCapabilityMatching:
|
||||
"""高复杂度任务通过 OrganizationContext 能力匹配路由"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_task_routes_via_capability(self, mock_llm_gateway, mock_skill_registry, mock_intent_router):
|
||||
# Set up LLM to return high complexity
|
||||
high_response = MagicMock()
|
||||
high_response.content = '{"complexity": 0.9}'
|
||||
mock_llm_gateway.chat = AsyncMock(return_value=high_response)
|
||||
|
||||
# Set up org_context with a capable agent
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="research_agent",
|
||||
agent_type="react",
|
||||
capabilities=["research", "analysis"],
|
||||
skills=["research"],
|
||||
))
|
||||
|
||||
# Mock find_best_agent to return the research agent
|
||||
org_context.find_best_agent = AsyncMock(
|
||||
return_value=org_context.get_agent_profile("research_agent")
|
||||
)
|
||||
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=mock_llm_gateway,
|
||||
org_context=org_context,
|
||||
)
|
||||
result = await router.route(
|
||||
content="请对市场趋势进行深度分析并给出投资建议",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.match_method == "capability"
|
||||
assert result.agent_name == "research_agent"
|
||||
assert result.complexity >= 0.7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Alignment guard detects cascade risk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAlignmentCascadeDetection:
|
||||
"""AlignmentGuard 检测级联故障风险"""
|
||||
|
||||
def test_cascade_alert_on_excessive_interactions(self):
|
||||
config = AlignmentConfig(cascade_max_interactions=3)
|
||||
guard = AlignmentGuard(config=config)
|
||||
|
||||
# Record interactions below threshold
|
||||
for _ in range(3):
|
||||
alert = guard.record_interaction("session-1")
|
||||
assert alert is None
|
||||
|
||||
# Next interaction should trigger alert
|
||||
alert = guard.record_interaction("session-1")
|
||||
assert alert is not None
|
||||
assert isinstance(alert, CascadeAlert)
|
||||
assert alert.alert_type == "interaction_limit"
|
||||
assert alert.current_value == 4
|
||||
assert alert.threshold == 3
|
||||
|
||||
def test_cascade_alert_on_loop_depth(self):
|
||||
config = AlignmentConfig(cascade_max_depth=2)
|
||||
guard = AlignmentGuard(config=config)
|
||||
|
||||
# Depth within threshold
|
||||
alert = guard.record_loop_depth("session-1", 2)
|
||||
assert alert is None
|
||||
|
||||
# Depth exceeds threshold
|
||||
alert = guard.record_loop_depth("session-1", 3)
|
||||
assert alert is not None
|
||||
assert alert.alert_type == "loop_depth"
|
||||
assert alert.current_value == 3
|
||||
assert alert.threshold == 2
|
||||
|
||||
def test_reset_session_clears_counts(self):
|
||||
config = AlignmentConfig(cascade_max_interactions=2)
|
||||
guard = AlignmentGuard(config=config)
|
||||
|
||||
guard.record_interaction("session-1")
|
||||
guard.record_interaction("session-1")
|
||||
guard.record_interaction("session-1") # triggers alert
|
||||
assert guard.get_interaction_count("session-1") == 3
|
||||
|
||||
guard.reset_session("session-1")
|
||||
assert guard.get_interaction_count("session-1") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Transparency TRACE mode returns execution trace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTransparencyTraceMode:
|
||||
"""TRACE 透明度模式返回执行追踪"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_mode_populates_execution_trace(self, mock_skill_registry, mock_intent_router):
|
||||
router = CostAwareRouter(llm_gateway=None, org_context=None)
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
transparency="TRACE",
|
||||
)
|
||||
assert result.transparency_level == "TRACE"
|
||||
assert len(result.execution_trace) > 0
|
||||
assert result.execution_trace[0]["layer"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_silent_mode_no_trace(self, mock_skill_registry, mock_intent_router):
|
||||
router = CostAwareRouter(llm_gateway=None, org_context=None)
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
transparency="SILENT",
|
||||
)
|
||||
assert result.transparency_level == "SILENT"
|
||||
assert result.execution_trace == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Auction mode routes via auction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuctionMode:
|
||||
"""拍卖模式通过 AuctionHouse 选择 Agent"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auction_selects_best_bidder(self):
|
||||
wealth_tracker = WealthTracker(initial_wealth=100.0)
|
||||
wealth_tracker.reward("agent_a", 50.0) # agent_a is richer
|
||||
|
||||
auction_house = AuctionHouse(wealth_tracker=wealth_tracker)
|
||||
|
||||
bids = [
|
||||
Bid(
|
||||
agent_name="agent_a",
|
||||
architecture="react",
|
||||
estimated_steps=3,
|
||||
estimated_cost=0.5,
|
||||
confidence=0.9,
|
||||
payment_offer=1.0,
|
||||
capabilities=["research"],
|
||||
),
|
||||
Bid(
|
||||
agent_name="agent_b",
|
||||
architecture="rewoo",
|
||||
estimated_steps=5,
|
||||
estimated_cost=0.8,
|
||||
confidence=0.7,
|
||||
payment_offer=0.5,
|
||||
capabilities=["research"],
|
||||
),
|
||||
]
|
||||
|
||||
result = await auction_house.run_auction("research task", bids)
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "agent_a"
|
||||
assert result.total_bidders == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auction_no_bidders(self):
|
||||
auction_house = AuctionHouse()
|
||||
result = await auction_house.run_auction("task", [])
|
||||
assert result.winner is None
|
||||
assert result.total_bidders == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bankrupt_agent_excluded(self):
|
||||
wealth_tracker = WealthTracker(initial_wealth=-150.0)
|
||||
auction_house = AuctionHouse(wealth_tracker=wealth_tracker)
|
||||
|
||||
bids = [
|
||||
Bid(
|
||||
agent_name="bankrupt_agent",
|
||||
architecture="react",
|
||||
estimated_steps=1,
|
||||
estimated_cost=0.1,
|
||||
confidence=0.9,
|
||||
payment_offer=1.0,
|
||||
),
|
||||
]
|
||||
|
||||
result = await auction_house.run_auction("task", bids)
|
||||
assert result.winner is None
|
||||
assert "bankrupt" in result.selection_reason.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: Constraint injection works end-to-end
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstraintInjection:
|
||||
"""约束注入端到端测试"""
|
||||
|
||||
def test_inject_constraints_into_input_data(self):
|
||||
config = AlignmentConfig(constraints=["不得泄露用户隐私", "禁止生成有害内容"])
|
||||
guard = AlignmentGuard(config=config)
|
||||
|
||||
input_data = {"content": "请帮我写一篇文章"}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
|
||||
assert "alignment_constraints" in injected
|
||||
assert "不得泄露用户隐私" in injected["alignment_constraints"]
|
||||
assert "禁止生成有害内容" in injected["alignment_constraints"]
|
||||
# Original data preserved
|
||||
assert injected["content"] == "请帮我写一篇文章"
|
||||
|
||||
def test_inject_does_not_mutate_original(self):
|
||||
config = AlignmentConfig(constraints=["constraint_1"])
|
||||
guard = AlignmentGuard(config=config)
|
||||
|
||||
input_data = {"key": "value"}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
|
||||
assert "alignment_constraints" not in input_data
|
||||
assert "alignment_constraints" in injected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 7: OrganizationContext builds from AgentPool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOrganizationContextFromAgentPool:
|
||||
"""OrganizationContext 从 AgentPool 构建"""
|
||||
|
||||
def test_build_from_agent_pool_with_skills(self):
|
||||
# Mock AgentPool
|
||||
agent_pool = MagicMock()
|
||||
agent_pool.list_agents.return_value = [
|
||||
{"name": "writer", "agent_type": "react"},
|
||||
{"name": "analyst", "agent_type": "plan_exec"},
|
||||
]
|
||||
|
||||
# Mock SkillRegistry — writer has a skill, analyst does not
|
||||
skill_registry = MagicMock()
|
||||
|
||||
writer_skill = MagicMock()
|
||||
writer_config = MagicMock()
|
||||
writer_config.capabilities = [MagicMock(tag="writing"), MagicMock(tag="creative")]
|
||||
writer_config.execution_mode = "react"
|
||||
writer_config.llm = {"model": "gpt-4"}
|
||||
writer_config.max_concurrency = 2
|
||||
writer_skill.config = writer_config
|
||||
|
||||
def get_skill(name):
|
||||
if name == "writer":
|
||||
return writer_skill
|
||||
raise KeyError(name)
|
||||
|
||||
skill_registry.get = MagicMock(side_effect=get_skill)
|
||||
|
||||
org_context = OrganizationContext.from_agent_pool(
|
||||
agent_pool=agent_pool,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
|
||||
profiles = org_context.list_agents()
|
||||
assert len(profiles) == 2
|
||||
|
||||
writer_profile = org_context.get_agent_profile("writer")
|
||||
assert writer_profile is not None
|
||||
assert writer_profile.agent_type == "react"
|
||||
assert "writing" in writer_profile.capabilities
|
||||
assert "creative" in writer_profile.capabilities
|
||||
assert writer_profile.model == "gpt-4"
|
||||
assert writer_profile.max_concurrency == 2
|
||||
|
||||
analyst_profile = org_context.get_agent_profile("analyst")
|
||||
assert analyst_profile is not None
|
||||
assert analyst_profile.agent_type == "plan_exec"
|
||||
# No skill found → default values
|
||||
assert analyst_profile.capabilities == []
|
||||
assert analyst_profile.model == "default"
|
||||
|
||||
def test_build_from_empty_agent_pool(self):
|
||||
agent_pool = MagicMock()
|
||||
agent_pool.list_agents.return_value = []
|
||||
skill_registry = MagicMock()
|
||||
|
||||
org_context = OrganizationContext.from_agent_pool(
|
||||
agent_pool=agent_pool,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
|
||||
assert org_context.list_agents() == []
|
||||
|
||||
def test_find_best_agent_by_capability(self):
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="researcher",
|
||||
agent_type="react",
|
||||
capabilities=["research", "analysis"],
|
||||
skills=["research"],
|
||||
current_load=0,
|
||||
))
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="writer",
|
||||
agent_type="react",
|
||||
capabilities=["writing", "creative"],
|
||||
skills=["writing"],
|
||||
current_load=2,
|
||||
))
|
||||
|
||||
# Find agent with research capability
|
||||
best = org_context.find_best_agent(["research"])
|
||||
assert best is not None
|
||||
assert best.name == "researcher"
|
||||
|
||||
# Find agent with both research and analysis
|
||||
best = org_context.find_best_agent(["research", "analysis"])
|
||||
assert best is not None
|
||||
assert best.name == "researcher"
|
||||
|
||||
# No agent with unknown capability
|
||||
best = org_context.find_best_agent(["coding"])
|
||||
assert best is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 8: Full pipeline: Chat → Router → Agent → AlignmentGuard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFullPipeline:
|
||||
"""完整流水线: 用户消息 → CostAwareRouter → 技能匹配 → 约束注入 → 对齐检查"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_pipeline_greeting(self):
|
||||
"""简单问候走完整流水线"""
|
||||
# Setup
|
||||
org_context = OrganizationContext()
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["不得包含敏感信息"],
|
||||
cascade_max_interactions=10,
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
router = CostAwareRouter(llm_gateway=None, org_context=org_context)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
# Step 1: Route the message
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
assert result.agent_name == "default"
|
||||
|
||||
# Step 2: Inject constraints
|
||||
input_data = {"content": result.clean_content}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
assert "alignment_constraints" in injected
|
||||
|
||||
# Step 3: Check alignment on simulated output
|
||||
output = {"result": "你好!有什么我可以帮助你的吗?"}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is True
|
||||
|
||||
# Step 4: Record interaction (no cascade)
|
||||
alert = guard.record_interaction("session-1")
|
||||
assert alert is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_pipeline_with_constraint_violation(self):
|
||||
"""输出违反约束时被检测到"""
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["password", "secret_key"],
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
# Output containing a constraint keyword
|
||||
output = {"result": "Your password is 123456"}
|
||||
check_result = await guard.check_output(output)
|
||||
assert check_result.passed is False
|
||||
assert len(check_result.violations) > 0
|
||||
assert check_result.checked_by == "rule"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_pipeline_complex_task_with_alignment(self):
|
||||
"""复杂任务走完整流水线:路由 → 能力匹配 → 约束注入 → 对齐检查"""
|
||||
# Setup LLM gateway returning high complexity
|
||||
mock_llm = AsyncMock()
|
||||
high_response = MagicMock()
|
||||
high_response.content = '{"complexity": 0.85}'
|
||||
mock_llm.chat = AsyncMock(return_value=high_response)
|
||||
|
||||
# Setup org context with capable agent
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="analyst",
|
||||
agent_type="react",
|
||||
capabilities=["analysis", "market_research"],
|
||||
skills=["market_analysis"],
|
||||
current_load=0,
|
||||
))
|
||||
org_context.find_best_agent = AsyncMock(
|
||||
return_value=org_context.get_agent_profile("analyst")
|
||||
)
|
||||
|
||||
alignment_config = AlignmentConfig(
|
||||
constraints=["不得提供具体投资建议"],
|
||||
cascade_max_interactions=5,
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config, llm_gateway=mock_llm)
|
||||
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=mock_llm,
|
||||
org_context=org_context,
|
||||
)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
# Step 1: Route complex task
|
||||
result = await router.route(
|
||||
content="请分析当前AI行业的市场趋势",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are a market analyst",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
transparency="TRACE",
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.match_method == "capability"
|
||||
assert result.agent_name == "analyst"
|
||||
assert result.complexity >= 0.7
|
||||
assert len(result.execution_trace) > 0
|
||||
|
||||
# Step 2: Inject constraints
|
||||
input_data = {"content": result.clean_content}
|
||||
injected = guard.inject_constraints(input_data)
|
||||
assert "alignment_constraints" in injected
|
||||
|
||||
# Step 3: Simulate agent output and check alignment
|
||||
safe_output = {"result": "AI行业目前呈现稳步增长趋势,主要驱动力来自大模型技术的突破。"}
|
||||
check_result = await guard.check_output(safe_output)
|
||||
assert check_result.passed is True
|
||||
|
||||
# Step 4: Record interaction
|
||||
alert = guard.record_interaction("session-complex")
|
||||
assert alert is None # Under threshold
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_pipeline_cascade_alert(self):
|
||||
"""级联故障检测在完整流水线中触发"""
|
||||
alignment_config = AlignmentConfig(
|
||||
cascade_max_interactions=2,
|
||||
)
|
||||
guard = AlignmentGuard(config=alignment_config)
|
||||
|
||||
# Simulate multiple interactions
|
||||
guard.record_interaction("session-cascade")
|
||||
guard.record_interaction("session-cascade")
|
||||
alert = guard.record_interaction("session-cascade")
|
||||
|
||||
assert alert is not None
|
||||
assert alert.alert_type == "interaction_limit"
|
||||
assert alert.current_value == 3
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
"""AlignmentGuard 单元测试"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.quality.alignment import (
|
||||
AlignmentCheckResult,
|
||||
AlignmentConfig,
|
||||
AlignmentGuard,
|
||||
CascadeAlert,
|
||||
ConstraintInjector,
|
||||
)
|
||||
from agentkit.quality.cascade_detector import CascadeDetector
|
||||
from agentkit.skills.base import SkillConfig
|
||||
|
||||
|
||||
# ── AlignmentConfig 测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestAlignmentConfig:
|
||||
"""AlignmentConfig 默认值测试"""
|
||||
|
||||
def test_default_values(self):
|
||||
config = AlignmentConfig()
|
||||
assert config.constraints == []
|
||||
assert config.cascade_max_interactions == 10
|
||||
assert config.cascade_max_depth == 3
|
||||
assert config.audit_enabled is False
|
||||
assert config.audit_model == "default"
|
||||
|
||||
def test_custom_values(self):
|
||||
config = AlignmentConfig(
|
||||
constraints=["no_harm", "be_honest"],
|
||||
cascade_max_interactions=5,
|
||||
cascade_max_depth=2,
|
||||
audit_enabled=True,
|
||||
audit_model="gpt-4",
|
||||
)
|
||||
assert config.constraints == ["no_harm", "be_honest"]
|
||||
assert config.cascade_max_interactions == 5
|
||||
assert config.cascade_max_depth == 2
|
||||
assert config.audit_enabled is True
|
||||
assert config.audit_model == "gpt-4"
|
||||
|
||||
|
||||
# ── ConstraintInjector 测试 ───────────────────────────────
|
||||
|
||||
|
||||
class TestConstraintInjector:
|
||||
"""ConstraintInjector 约束注入测试"""
|
||||
|
||||
def test_inject_constraints_into_input_data(self):
|
||||
config = AlignmentConfig(constraints=["no_harm", "be_honest"])
|
||||
injector = ConstraintInjector(config)
|
||||
result = injector.inject({"task": "translate"})
|
||||
assert "alignment_constraints" in result
|
||||
assert result["alignment_constraints"] == ["no_harm", "be_honest"]
|
||||
assert result["task"] == "translate"
|
||||
|
||||
def test_does_not_modify_original_dict(self):
|
||||
config = AlignmentConfig(constraints=["no_harm"])
|
||||
injector = ConstraintInjector(config)
|
||||
original = {"task": "translate"}
|
||||
result = injector.inject(original)
|
||||
assert "alignment_constraints" not in original
|
||||
assert "alignment_constraints" in result
|
||||
|
||||
def test_empty_constraints(self):
|
||||
config = AlignmentConfig(constraints=[])
|
||||
injector = ConstraintInjector(config)
|
||||
result = injector.inject({"task": "translate"})
|
||||
assert result["alignment_constraints"] == []
|
||||
|
||||
|
||||
# ── AlignmentGuard.check_output 测试 ──────────────────────
|
||||
|
||||
|
||||
class TestAlignmentGuardCheckOutput:
|
||||
"""AlignmentGuard.check_output 对齐检查"""
|
||||
|
||||
async def test_rule_check_violation_keyword_match(self):
|
||||
config = AlignmentConfig(constraints=["forbidden_word"])
|
||||
guard = AlignmentGuard(config)
|
||||
output = {"content": "This contains forbidden_word in text"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.passed is False
|
||||
assert "forbidden_word" in result.violations
|
||||
assert result.checked_by == "rule"
|
||||
|
||||
async def test_rule_check_passes_no_violations(self):
|
||||
config = AlignmentConfig(constraints=["forbidden_word"])
|
||||
guard = AlignmentGuard(config)
|
||||
output = {"content": "This is clean text"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.passed is True
|
||||
assert result.violations == []
|
||||
assert result.checked_by == "rule"
|
||||
|
||||
async def test_no_constraints_passes(self):
|
||||
config = AlignmentConfig(constraints=[])
|
||||
guard = AlignmentGuard(config)
|
||||
result = await guard.check_output({"content": "anything"})
|
||||
assert result.passed is True
|
||||
assert result.checked_by == "rule"
|
||||
|
||||
async def test_audit_disabled_does_not_call_llm(self):
|
||||
config = AlignmentConfig(
|
||||
constraints=["no_harm"], audit_enabled=False
|
||||
)
|
||||
mock_gateway = AsyncMock()
|
||||
guard = AlignmentGuard(config, llm_gateway=mock_gateway)
|
||||
output = {"content": "This is safe"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.checked_by == "rule"
|
||||
mock_gateway.chat.assert_not_called()
|
||||
|
||||
async def test_audit_enabled_calls_llm_for_semantic_check(self):
|
||||
config = AlignmentConfig(
|
||||
constraints=["be_respectful"], audit_enabled=True, audit_model="gpt-4"
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "PASS"
|
||||
mock_gateway = AsyncMock()
|
||||
mock_gateway.chat.return_value = mock_response
|
||||
guard = AlignmentGuard(config, llm_gateway=mock_gateway)
|
||||
output = {"content": "This is respectful text"}
|
||||
# Rule check passes first (no keyword match), then LLM audit
|
||||
result = await guard.check_output(output)
|
||||
assert result.checked_by == "llm"
|
||||
mock_gateway.chat.assert_called_once()
|
||||
|
||||
async def test_audit_enabled_llm_detects_violation(self):
|
||||
config = AlignmentConfig(
|
||||
constraints=["be_respectful"], audit_enabled=True
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "VIOLATION: Output is disrespectful"
|
||||
mock_gateway = AsyncMock()
|
||||
mock_gateway.chat.return_value = mock_response
|
||||
guard = AlignmentGuard(config, llm_gateway=mock_gateway)
|
||||
output = {"content": "This is borderline text"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.passed is False
|
||||
assert result.checked_by == "llm"
|
||||
|
||||
async def test_audit_enabled_no_llm_gateway_skips_llm(self):
|
||||
config = AlignmentConfig(
|
||||
constraints=["be_respectful"], audit_enabled=True
|
||||
)
|
||||
guard = AlignmentGuard(config, llm_gateway=None)
|
||||
output = {"content": "This is safe"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.checked_by == "rule"
|
||||
|
||||
async def test_custom_constraints_override_config(self):
|
||||
config = AlignmentConfig(constraints=["default_constraint"])
|
||||
guard = AlignmentGuard(config)
|
||||
output = {"content": "This has custom_violation in it"}
|
||||
result = await guard.check_output(output, constraints=["custom_violation"])
|
||||
assert result.passed is False
|
||||
assert "custom_violation" in result.violations
|
||||
|
||||
async def test_case_insensitive_matching(self):
|
||||
config = AlignmentConfig(constraints=["ForBiDdEn"])
|
||||
guard = AlignmentGuard(config)
|
||||
output = {"content": "This has forbidden in it"}
|
||||
result = await guard.check_output(output)
|
||||
assert result.passed is False
|
||||
|
||||
|
||||
# ── AlignmentGuard 级联检测测试 ───────────────────────────
|
||||
|
||||
|
||||
class TestAlignmentGuardCascade:
|
||||
"""AlignmentGuard 级联故障检测"""
|
||||
|
||||
def test_record_interaction_returns_alert_when_exceeded(self):
|
||||
config = AlignmentConfig(cascade_max_interactions=3)
|
||||
guard = AlignmentGuard(config)
|
||||
# 前 3 次不触发
|
||||
assert guard.record_interaction("s1") is None
|
||||
assert guard.record_interaction("s1") is None
|
||||
assert guard.record_interaction("s1") is None
|
||||
# 第 4 次触发
|
||||
alert = guard.record_interaction("s1")
|
||||
assert alert is not None
|
||||
assert alert.session_id == "s1"
|
||||
assert alert.alert_type == "interaction_limit"
|
||||
assert alert.current_value == 4
|
||||
assert alert.threshold == 3
|
||||
|
||||
def test_record_interaction_below_threshold_returns_none(self):
|
||||
config = AlignmentConfig(cascade_max_interactions=10)
|
||||
guard = AlignmentGuard(config)
|
||||
assert guard.record_interaction("s1") is None
|
||||
|
||||
def test_record_loop_depth_returns_alert_when_exceeded(self):
|
||||
config = AlignmentConfig(cascade_max_depth=2)
|
||||
guard = AlignmentGuard(config)
|
||||
assert guard.record_loop_depth("s1", 2) is None
|
||||
alert = guard.record_loop_depth("s1", 3)
|
||||
assert alert is not None
|
||||
assert alert.alert_type == "loop_depth"
|
||||
assert alert.current_value == 3
|
||||
|
||||
def test_reset_session_clears_counters(self):
|
||||
config = AlignmentConfig(cascade_max_interactions=5)
|
||||
guard = AlignmentGuard(config)
|
||||
guard.record_interaction("s1")
|
||||
guard.record_interaction("s1")
|
||||
assert guard.get_interaction_count("s1") == 2
|
||||
guard.reset_session("s1")
|
||||
assert guard.get_interaction_count("s1") == 0
|
||||
|
||||
def test_get_interaction_count_default_zero(self):
|
||||
config = AlignmentConfig()
|
||||
guard = AlignmentGuard(config)
|
||||
assert guard.get_interaction_count("unknown") == 0
|
||||
|
||||
def test_inject_constraints_delegates_to_injector(self):
|
||||
config = AlignmentConfig(constraints=["no_harm"])
|
||||
guard = AlignmentGuard(config)
|
||||
result = guard.inject_constraints({"task": "test"})
|
||||
assert result["alignment_constraints"] == ["no_harm"]
|
||||
|
||||
|
||||
# ── CascadeDetector 测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestCascadeDetector:
|
||||
"""CascadeDetector 独立级联检测测试"""
|
||||
|
||||
def test_interaction_exceeds_threshold_triggers_alert(self):
|
||||
detector = CascadeDetector(max_interactions=3)
|
||||
assert detector.check_interaction("s1") is None
|
||||
assert detector.check_interaction("s1") is None
|
||||
assert detector.check_interaction("s1") is None
|
||||
alert = detector.check_interaction("s1")
|
||||
assert alert is not None
|
||||
assert alert.alert_type == "interaction_limit"
|
||||
assert alert.current_value == 4
|
||||
assert alert.threshold == 3
|
||||
|
||||
def test_interaction_below_threshold_returns_none(self):
|
||||
detector = CascadeDetector(max_interactions=10)
|
||||
assert detector.check_interaction("s1") is None
|
||||
|
||||
def test_loop_depth_exceeds_threshold_triggers_alert(self):
|
||||
detector = CascadeDetector(max_depth=3)
|
||||
assert detector.check_depth("s1", 3) is None
|
||||
alert = detector.check_depth("s1", 4)
|
||||
assert alert is not None
|
||||
assert alert.alert_type == "loop_depth"
|
||||
assert alert.current_value == 4
|
||||
|
||||
def test_reset_clears_counters(self):
|
||||
detector = CascadeDetector(max_interactions=2)
|
||||
detector.check_interaction("s1")
|
||||
detector.check_interaction("s1")
|
||||
detector.reset("s1")
|
||||
stats = detector.get_stats("s1")
|
||||
assert stats["interactions"] == 0
|
||||
assert stats["depth"] == 0
|
||||
|
||||
def test_get_stats_returns_current_values(self):
|
||||
detector = CascadeDetector()
|
||||
detector.check_interaction("s1")
|
||||
detector.check_interaction("s1")
|
||||
detector.check_depth("s1", 5)
|
||||
stats = detector.get_stats("s1")
|
||||
assert stats["interactions"] == 2
|
||||
assert stats["depth"] == 5
|
||||
|
||||
def test_get_stats_unknown_session(self):
|
||||
detector = CascadeDetector()
|
||||
stats = detector.get_stats("unknown")
|
||||
assert stats["interactions"] == 0
|
||||
assert stats["depth"] == 0
|
||||
|
||||
|
||||
# ── SkillConfig alignment 字段测试 ────────────────────────
|
||||
|
||||
|
||||
class TestSkillConfigAlignment:
|
||||
"""SkillConfig alignment 字段测试"""
|
||||
|
||||
def test_default_alignment(self):
|
||||
config = SkillConfig(name="test", agent_type="test", prompt={"identity": "test"})
|
||||
assert config.alignment.constraints == []
|
||||
assert config.alignment.cascade_max_interactions == 10
|
||||
assert config.alignment.cascade_max_depth == 3
|
||||
assert config.alignment.audit_enabled is False
|
||||
assert config.alignment.audit_model == "default"
|
||||
|
||||
def test_alignment_from_dict(self):
|
||||
config = SkillConfig.from_dict({
|
||||
"name": "test",
|
||||
"agent_type": "test",
|
||||
"prompt": {"identity": "test"},
|
||||
"alignment": {
|
||||
"constraints": ["no_harm"],
|
||||
"cascade_max_interactions": 5,
|
||||
"cascade_max_depth": 2,
|
||||
"audit_enabled": True,
|
||||
"audit_model": "gpt-4",
|
||||
},
|
||||
})
|
||||
assert config.alignment.constraints == ["no_harm"]
|
||||
assert config.alignment.cascade_max_interactions == 5
|
||||
assert config.alignment.cascade_max_depth == 2
|
||||
assert config.alignment.audit_enabled is True
|
||||
assert config.alignment.audit_model == "gpt-4"
|
||||
|
||||
def test_alignment_to_dict(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
prompt={"identity": "test"},
|
||||
alignment={"constraints": ["no_harm"], "audit_enabled": True},
|
||||
)
|
||||
d = config.to_dict()
|
||||
assert "alignment" in d
|
||||
assert d["alignment"]["constraints"] == ["no_harm"]
|
||||
assert d["alignment"]["audit_enabled"] is True
|
||||
|
||||
def test_backward_compatibility_no_alignment(self):
|
||||
config = SkillConfig.from_dict({
|
||||
"name": "test",
|
||||
"agent_type": "test",
|
||||
"prompt": {"identity": "test"},
|
||||
})
|
||||
assert config.alignment.constraints == []
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
"""AuctionHouse 与 WealthTracker 单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.marketplace.auction import AuctionHouse, AuctionResult, Bid
|
||||
from agentkit.marketplace.wealth import WealthTracker
|
||||
|
||||
|
||||
# ---- Fixtures ----
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wealth_tracker():
|
||||
return WealthTracker()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auction_house():
|
||||
return AuctionHouse()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auction_house_with_tracker():
|
||||
tracker = WealthTracker()
|
||||
return AuctionHouse(wealth_tracker=tracker), tracker
|
||||
|
||||
|
||||
def make_bid(
|
||||
agent_name: str = "agent_a",
|
||||
architecture: str = "react",
|
||||
estimated_steps: int = 5,
|
||||
estimated_cost: float = 10.0,
|
||||
confidence: float = 0.8,
|
||||
payment_offer: float = 1.0,
|
||||
capabilities: list[str] | None = None,
|
||||
) -> Bid:
|
||||
return Bid(
|
||||
agent_name=agent_name,
|
||||
architecture=architecture,
|
||||
estimated_steps=estimated_steps,
|
||||
estimated_cost=estimated_cost,
|
||||
confidence=confidence,
|
||||
payment_offer=payment_offer,
|
||||
capabilities=capabilities or [],
|
||||
)
|
||||
|
||||
|
||||
# ---- AuctionHouse 测试 ----
|
||||
|
||||
|
||||
class TestAuctionHouseSingleBidder:
|
||||
"""单一竞价者自动获胜"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_bidder_wins(self, auction_house):
|
||||
bid = make_bid(agent_name="solo_agent")
|
||||
result = await auction_house.run_auction("do something", [bid])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "solo_agent"
|
||||
assert result.total_bidders == 1
|
||||
|
||||
|
||||
class TestAuctionHouseMultipleBidders:
|
||||
"""多竞价者,最高分获胜"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_highest_score_wins(self, auction_house):
|
||||
bid_low = make_bid(
|
||||
agent_name="low_agent",
|
||||
confidence=0.5,
|
||||
estimated_cost=10.0,
|
||||
)
|
||||
bid_high = make_bid(
|
||||
agent_name="high_agent",
|
||||
confidence=0.9,
|
||||
estimated_cost=10.0,
|
||||
)
|
||||
result = await auction_house.run_auction("do something", [bid_low, bid_high])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "high_agent"
|
||||
|
||||
|
||||
class TestAuctionHouseNoBidders:
|
||||
"""无竞价者返回 None winner"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_bidders_returns_none(self, auction_house):
|
||||
result = await auction_house.run_auction("do something", [])
|
||||
assert result.winner is None
|
||||
assert result.total_bidders == 0
|
||||
assert result.all_bids == []
|
||||
|
||||
|
||||
class TestAuctionHouseWealthFactor:
|
||||
"""财富因子影响评分"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wealth_factor_affects_scoring(self):
|
||||
tracker = WealthTracker()
|
||||
# Give agent_rich more wealth
|
||||
tracker.reward("agent_rich", 500.0)
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
|
||||
# Same confidence and cost, but different wealth
|
||||
bid_rich = make_bid(agent_name="agent_rich", confidence=0.8, estimated_cost=10.0)
|
||||
bid_poor = make_bid(agent_name="agent_poor", confidence=0.8, estimated_cost=10.0)
|
||||
|
||||
result = await house.run_auction("do something", [bid_rich, bid_poor])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "agent_rich"
|
||||
|
||||
|
||||
class TestAuctionHouseZeroCost:
|
||||
"""零 estimated_cost 处理(max 与 0.001)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_estimated_cost_handled(self, auction_house):
|
||||
bid = make_bid(agent_name="zero_cost_agent", confidence=0.8, estimated_cost=0.0)
|
||||
result = await auction_house.run_auction("do something", [bid])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "zero_cost_agent"
|
||||
|
||||
def test_score_bid_zero_cost(self, auction_house):
|
||||
bid = make_bid(agent_name="zero_cost_agent", confidence=0.8, estimated_cost=0.0)
|
||||
score = auction_house.score_bid(bid)
|
||||
# score = (0.8 / max(0.0, 0.001)) * 1.1 = (0.8 / 0.001) * 1.1 = 880.0
|
||||
expected = (0.8 / 0.001) * 1.1
|
||||
assert abs(score - expected) < 0.01
|
||||
|
||||
|
||||
class TestBidScoringFormula:
|
||||
"""竞价评分公式验证"""
|
||||
|
||||
def test_score_formula(self):
|
||||
tracker = WealthTracker()
|
||||
# Default wealth = 100, so wealth_factor = 1.0 + (100 / 1000.0) = 1.1
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
|
||||
bid = make_bid(agent_name="test_agent", confidence=0.9, estimated_cost=5.0)
|
||||
score = house.score_bid(bid)
|
||||
|
||||
wealth_factor = 1.0 + (100.0 / 1000.0) # 1.1
|
||||
expected = (0.9 / 5.0) * wealth_factor
|
||||
assert abs(score - expected) < 0.0001
|
||||
|
||||
def test_score_formula_with_custom_wealth(self):
|
||||
tracker = WealthTracker(initial_wealth=200.0)
|
||||
tracker.reward("rich_agent", 300.0)
|
||||
# wealth = 500, factor = 1.0 + 500/1000 = 1.5
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
|
||||
bid = make_bid(agent_name="rich_agent", confidence=0.6, estimated_cost=3.0)
|
||||
score = house.score_bid(bid)
|
||||
|
||||
wealth_factor = 1.0 + (500.0 / 1000.0) # 1.5
|
||||
expected = (0.6 / 3.0) * wealth_factor
|
||||
assert abs(score - expected) < 0.0001
|
||||
|
||||
|
||||
# ---- WealthTracker 测试 ----
|
||||
|
||||
|
||||
class TestWealthTrackerInitialWealth:
|
||||
"""初始财富默认值"""
|
||||
|
||||
def test_default_initial_wealth(self):
|
||||
tracker = WealthTracker()
|
||||
assert tracker.get_wealth("unknown_agent") == 100.0
|
||||
|
||||
def test_custom_initial_wealth(self):
|
||||
tracker = WealthTracker(initial_wealth=50.0)
|
||||
assert tracker.get_wealth("unknown_agent") == 50.0
|
||||
|
||||
|
||||
class TestWealthTrackerReward:
|
||||
"""奖励增加财富"""
|
||||
|
||||
def test_reward_increases_wealth(self, wealth_tracker):
|
||||
wealth_tracker.reward("agent_a", 50.0)
|
||||
assert wealth_tracker.get_wealth("agent_a") == 150.0
|
||||
|
||||
def test_reward_multiple_times(self, wealth_tracker):
|
||||
wealth_tracker.reward("agent_a", 30.0)
|
||||
wealth_tracker.reward("agent_a", 20.0)
|
||||
assert wealth_tracker.get_wealth("agent_a") == 150.0
|
||||
|
||||
|
||||
class TestWealthTrackerPenalize:
|
||||
"""惩罚减少财富"""
|
||||
|
||||
def test_penalize_decreases_wealth(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 30.0)
|
||||
assert wealth_tracker.get_wealth("agent_a") == 70.0
|
||||
|
||||
def test_penalize_below_zero(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 150.0)
|
||||
assert wealth_tracker.get_wealth("agent_a") == -50.0
|
||||
|
||||
|
||||
class TestWealthTrackerBankrupt:
|
||||
"""破产检查(wealth <= -100)"""
|
||||
|
||||
def test_bankrupt_at_negative_100(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 200.0)
|
||||
assert wealth_tracker.get_wealth("agent_a") == -100.0
|
||||
assert wealth_tracker.is_bankrupt("agent_a") is True
|
||||
|
||||
def test_bankrupt_below_negative_100(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 250.0)
|
||||
assert wealth_tracker.is_bankrupt("agent_a") is True
|
||||
|
||||
def test_not_bankrupt_above_negative_100(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 150.0)
|
||||
# wealth = -50, which is > -100
|
||||
assert wealth_tracker.is_bankrupt("agent_a") is False
|
||||
|
||||
def test_not_bankrupt_at_default(self, wealth_tracker):
|
||||
assert wealth_tracker.is_bankrupt("agent_a") is False
|
||||
|
||||
|
||||
class TestWealthTrackerReset:
|
||||
"""重置恢复初始财富"""
|
||||
|
||||
def test_reset_restores_initial_wealth(self, wealth_tracker):
|
||||
wealth_tracker.reward("agent_a", 500.0)
|
||||
wealth_tracker.reset("agent_a")
|
||||
assert wealth_tracker.get_wealth("agent_a") == 100.0
|
||||
|
||||
def test_reset_with_custom_initial(self):
|
||||
tracker = WealthTracker(initial_wealth=200.0)
|
||||
tracker.penalize("agent_a", 50.0)
|
||||
tracker.reset("agent_a")
|
||||
assert tracker.get_wealth("agent_a") == 200.0
|
||||
|
||||
|
||||
class TestWealthTrackerRankings:
|
||||
"""排名按财富降序"""
|
||||
|
||||
def test_rankings_sorted_descending(self, wealth_tracker):
|
||||
wealth_tracker.reward("agent_a", 100.0) # 200
|
||||
wealth_tracker.reward("agent_b", 300.0) # 400
|
||||
wealth_tracker.penalize("agent_c", 50.0) # 50
|
||||
|
||||
rankings = wealth_tracker.get_rankings()
|
||||
assert rankings[0][0] == "agent_b"
|
||||
assert rankings[1][0] == "agent_a"
|
||||
assert rankings[2][0] == "agent_c"
|
||||
|
||||
def test_rankings_empty(self, wealth_tracker):
|
||||
assert wealth_tracker.get_rankings() == []
|
||||
|
||||
|
||||
class TestWealthTrackerWealthFactor:
|
||||
"""财富因子计算"""
|
||||
|
||||
def test_wealth_factor_default(self, wealth_tracker):
|
||||
# wealth = 100, factor = 1.0 + 100/1000 = 1.1
|
||||
factor = wealth_tracker.get_wealth_factor("agent_a")
|
||||
assert abs(factor - 1.1) < 0.0001
|
||||
|
||||
def test_wealth_factor_with_wealth(self, wealth_tracker):
|
||||
wealth_tracker.reward("agent_a", 400.0) # wealth = 500
|
||||
factor = wealth_tracker.get_wealth_factor("agent_a")
|
||||
# factor = 1.0 + 500/1000 = 1.5
|
||||
assert abs(factor - 1.5) < 0.0001
|
||||
|
||||
def test_wealth_factor_negative_wealth(self, wealth_tracker):
|
||||
wealth_tracker.penalize("agent_a", 150.0) # wealth = -50
|
||||
factor = wealth_tracker.get_wealth_factor("agent_a")
|
||||
# factor = 1.0 + (-50)/1000 = 0.95
|
||||
assert abs(factor - 0.95) < 0.0001
|
||||
|
||||
|
||||
# ---- Auction 默认禁用验证 ----
|
||||
|
||||
|
||||
class TestAuctionDefaultDisabled:
|
||||
"""拍卖机制默认禁用"""
|
||||
|
||||
def test_auction_not_in_default_config(self):
|
||||
"""验证默认配置中不包含 auction_enabled"""
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig()
|
||||
# marketplace section should not exist or auction_enabled should be False
|
||||
marketplace_cfg = getattr(config, "marketplace", None)
|
||||
if marketplace_cfg is not None:
|
||||
auction_enabled = getattr(marketplace_cfg, "auction_enabled", False)
|
||||
assert auction_enabled is False
|
||||
# If marketplace doesn't exist at all, auction is implicitly disabled
|
||||
|
|
@ -0,0 +1,468 @@
|
|||
"""CostAwareRouter 单元测试 - 三层成本感知路由"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_skill(
|
||||
name: str,
|
||||
keywords: list[str] | None = None,
|
||||
description: str = "",
|
||||
examples: list[str] | None = None,
|
||||
) -> Skill:
|
||||
"""快速构造一个带 intent 配置的 Skill"""
|
||||
config = SkillConfig(
|
||||
name=name,
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"system": f"You are a {name} skill."},
|
||||
intent={
|
||||
"keywords": keywords or [],
|
||||
"description": description,
|
||||
"examples": examples or [],
|
||||
},
|
||||
)
|
||||
return Skill(config=config)
|
||||
|
||||
|
||||
def _make_llm_gateway(response_content: str) -> MagicMock:
|
||||
"""构造一个 mock LLMGateway,chat 返回指定 content"""
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=response_content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
|
||||
"""构造一个 mock SkillRegistry"""
|
||||
registry = MagicMock()
|
||||
_skills = skills or []
|
||||
registry.list_skills.return_value = _skills
|
||||
|
||||
def _get(name: str):
|
||||
for s in _skills:
|
||||
if s.name == name:
|
||||
return s
|
||||
raise KeyError(f"Skill '{name}' not found")
|
||||
|
||||
registry.get = MagicMock(side_effect=_get)
|
||||
return registry
|
||||
|
||||
|
||||
def _make_intent_router() -> IntentRouter:
|
||||
"""构造一个无 LLM 的 IntentRouter(仅关键词匹配)"""
|
||||
return IntentRouter(llm_gateway=None, model="default")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer 0: Rule-based (zero cost)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLayer0Greeting:
|
||||
"""Layer 0: 问候模式匹配"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chinese_greeting_hits_layer0(self):
|
||||
"""'你好' 命中 Layer 0 问候规则,零 token 成本"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
assert result.complexity == 0.0
|
||||
assert result.agent_name == "default"
|
||||
assert result.matched is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_english_greeting_hits_layer0(self):
|
||||
"""'hello' 命中 Layer 0 问候规则"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="hello",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
assert result.complexity == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_with_punctuation(self):
|
||||
"""'你好!' 带标点也命中 Layer 0"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好!",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
|
||||
|
||||
class TestLayer0ChatMode:
|
||||
"""Layer 0: 简单对话模式"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thanks_hits_chat_mode(self):
|
||||
"""'谢谢' 命中 Layer 0 简单对话模式"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="谢谢",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "chat_mode"
|
||||
assert result.complexity == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ok_hits_chat_mode(self):
|
||||
"""'好的' 命中 Layer 0 简单对话模式"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="好的",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "chat_mode"
|
||||
|
||||
|
||||
class TestLayer0ExplicitSkill:
|
||||
"""Layer 0: @skill: 显式前缀"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_prefix_hits_layer0(self):
|
||||
"""'@skill:search 搜索XX' 命中 Layer 0 显式 Skill 规则,零 token 成本"""
|
||||
search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
# 需要 IntentRouter 支持 LLM fallback
|
||||
gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
|
||||
intent_router = IntentRouter(llm_gateway=gateway, model="default")
|
||||
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="@skill:search 搜索XX",
|
||||
skill_registry=registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.skill_name == "search"
|
||||
assert result.complexity == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer 1: LLM quick classify (~100 tokens)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLayer1Classification:
|
||||
"""Layer 1: LLM 快速分类"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_medium_complexity_routes_via_intent_router(self):
|
||||
"""'分析下这个数据' 经过 Layer 1 LLM 分类,中等复杂度走 IntentRouter"""
|
||||
# LLM 返回中等复杂度
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||
search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
|
||||
# IntentRouter 也需要 LLM
|
||||
intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
|
||||
intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default")
|
||||
result = await router.route(
|
||||
content="分析下这个数据",
|
||||
skill_registry=registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert 0.3 <= result.complexity <= 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_complexity_routes_to_default(self):
|
||||
"""低复杂度 (<0.3) 路由到默认 Agent"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default")
|
||||
result = await router.route(
|
||||
content="随便聊聊",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.complexity < 0.3
|
||||
assert result.match_method == "low_complexity"
|
||||
assert result.agent_name == "default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_llm_gateway_defaults_to_medium(self):
|
||||
"""无 LLM Gateway 时 quick_classify 返回 0.5(中等复杂度)"""
|
||||
router = CostAwareRouter(llm_gateway=None)
|
||||
complexity = await router.quick_classify("分析下这个数据")
|
||||
assert complexity == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_malformed_response_defaults_to_medium(self):
|
||||
"""LLM 返回非 JSON 时 quick_classify 返回 0.5"""
|
||||
gateway = _make_llm_gateway("这不是JSON")
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default")
|
||||
complexity = await router.quick_classify("分析下这个数据")
|
||||
assert complexity == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complexity_clamped_to_0_1(self):
|
||||
"""复杂度值被限制在 [0.0, 1.0] 范围"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 1.5}))
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default")
|
||||
complexity = await router.quick_classify("超级复杂任务")
|
||||
assert complexity == 1.0
|
||||
|
||||
gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5}))
|
||||
router2 = CostAwareRouter(llm_gateway=gateway2, model="default")
|
||||
complexity2 = await router2.quick_classify("简单任务")
|
||||
assert complexity2 == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer 2: Capability matching / Auction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLayer2CapabilityMatching:
|
||||
"""Layer 2: 能力匹配 / 拍卖"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_complexity_triggers_capability_matching(self):
|
||||
"""'做市场调研+竞品分析' 复杂度 > 0.7,触发能力匹配"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value="market-researcher")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.complexity > 0.7
|
||||
assert result.match_method == "capability"
|
||||
assert result.agent_name == "market-researcher"
|
||||
assert result.matched is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_layer2_with_org_context_object(self):
|
||||
"""org_context.find_best_agent 返回对象时提取 name 属性"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
|
||||
agent_obj = MagicMock()
|
||||
agent_obj.name = "analyst-agent"
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value=agent_obj)
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.agent_name == "analyst-agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_layer2_without_org_context_falls_back_to_intent_router(self):
|
||||
"""无 org_context 时 Layer 2 回退到 IntentRouter"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
|
||||
search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
|
||||
intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
|
||||
intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.complexity > 0.7
|
||||
# 回退到 IntentRouter,可能匹配到 skill 或走 default
|
||||
assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_layer2_org_context_find_best_agent_returns_none(self):
|
||||
"""org_context.find_best_agent 返回 None 时回退到 IntentRouter"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value=None)
|
||||
|
||||
search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
|
||||
intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.complexity > 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auction_disabled_by_default(self):
|
||||
"""拍卖模式默认禁用"""
|
||||
router = CostAwareRouter()
|
||||
assert router._auction_enabled is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auction_can_be_enabled(self):
|
||||
"""拍卖模式可手动启用"""
|
||||
router = CostAwareRouter(auction_enabled=True)
|
||||
assert router._auction_enabled is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transparency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTransparency:
|
||||
"""透明度级别切换"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_silent_mode_no_trace(self):
|
||||
"""SILENT 模式不暴露路由追踪"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
transparency="SILENT",
|
||||
)
|
||||
assert result.execution_trace == []
|
||||
assert result.transparency_level == "SILENT"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_shows_trace(self):
|
||||
"""VERBOSE 模式显示路由追踪"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
transparency="VERBOSE",
|
||||
)
|
||||
assert len(result.execution_trace) > 0
|
||||
assert result.execution_trace[0]["layer"] == 0
|
||||
assert result.execution_trace[0]["method"] == "greeting"
|
||||
assert result.transparency_level == "VERBOSE"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_mode_shows_full_trace(self):
|
||||
"""TRACE 模式显示完整路由追踪"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value="analyst")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
transparency="TRACE",
|
||||
)
|
||||
assert len(result.execution_trace) > 0
|
||||
# 应包含 Layer 1 quick_classify 和 Layer 2 的记录
|
||||
layers = [t["layer"] for t in result.execution_trace]
|
||||
assert 1 in layers # Layer 1 quick_classify
|
||||
assert 2 in layers # Layer 2 capability matching
|
||||
assert result.transparency_level == "TRACE"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_transparency_is_silent(self):
|
||||
"""默认透明度为 SILENT"""
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.transparency_level == "SILENT"
|
||||
assert result.execution_trace == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillRoutingResult 新字段
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSkillRoutingResultNewFields:
|
||||
"""SkillRoutingResult 新字段验证"""
|
||||
|
||||
def test_default_transparency_level(self):
|
||||
result = SkillRoutingResult()
|
||||
assert result.transparency_level == "SILENT"
|
||||
|
||||
def test_default_execution_trace(self):
|
||||
result = SkillRoutingResult()
|
||||
assert result.execution_trace == []
|
||||
|
||||
def test_default_complexity(self):
|
||||
result = SkillRoutingResult()
|
||||
assert result.complexity == 0.0
|
||||
|
||||
def test_new_fields_backward_compatible(self):
|
||||
"""新字段不影响旧代码创建 SkillRoutingResult"""
|
||||
result = SkillRoutingResult(
|
||||
skill_name="test",
|
||||
matched=True,
|
||||
match_method="keyword",
|
||||
)
|
||||
assert result.transparency_level == "SILENT"
|
||||
assert result.execution_trace == []
|
||||
assert result.complexity == 0.0
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
"""OrganizationContext 与 AgentDiscovery 单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||||
from agentkit.org.discovery import AgentDiscovery
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---- Fixtures ----
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_context():
|
||||
return OrganizationContext()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def profile_rag():
|
||||
return AgentProfile(
|
||||
name="rag_agent",
|
||||
agent_type="react",
|
||||
capabilities=["rag", "search"],
|
||||
skills=["rag_skill"],
|
||||
execution_mode="react",
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def profile_terminal():
|
||||
return AgentProfile(
|
||||
name="terminal_agent",
|
||||
agent_type="react",
|
||||
capabilities=["terminal", "shell"],
|
||||
skills=["terminal_skill"],
|
||||
execution_mode="react",
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def profile_coder():
|
||||
return AgentProfile(
|
||||
name="coder_agent",
|
||||
agent_type="rewoo",
|
||||
capabilities=["rag", "terminal", "code_gen"],
|
||||
skills=["coder_skill"],
|
||||
execution_mode="rewoo",
|
||||
model="claude-3",
|
||||
max_concurrency=3,
|
||||
)
|
||||
|
||||
|
||||
# ---- OrganizationContext: 注册与注销 ----
|
||||
|
||||
|
||||
class TestOrganizationContextRegister:
|
||||
"""注册与注销 Agent 档案"""
|
||||
|
||||
def test_register_agent(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
assert org_context.get_agent_profile("rag_agent") is profile_rag
|
||||
|
||||
def test_unregister_agent(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.unregister_agent("rag_agent")
|
||||
assert org_context.get_agent_profile("rag_agent") is None
|
||||
|
||||
def test_unregister_nonexistent_no_error(self, org_context):
|
||||
org_context.unregister_agent("nonexistent") # should not raise
|
||||
|
||||
def test_register_overwrites_existing(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
updated = AgentProfile(
|
||||
name="rag_agent",
|
||||
agent_type="react",
|
||||
capabilities=["rag", "search", "summarize"],
|
||||
skills=["rag_skill"],
|
||||
)
|
||||
org_context.register_agent(updated)
|
||||
profile = org_context.get_agent_profile("rag_agent")
|
||||
assert profile is updated
|
||||
assert "summarize" in profile.capabilities
|
||||
|
||||
def test_list_agents(self, org_context, profile_rag, profile_terminal):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_terminal)
|
||||
agents = org_context.list_agents()
|
||||
assert len(agents) == 2
|
||||
names = {a.name for a in agents}
|
||||
assert names == {"rag_agent", "terminal_agent"}
|
||||
|
||||
def test_list_agents_empty(self, org_context):
|
||||
assert org_context.list_agents() == []
|
||||
|
||||
|
||||
# ---- OrganizationContext: 能力查找 ----
|
||||
|
||||
|
||||
class TestOrganizationContextFind:
|
||||
"""find_best_agent() 测试"""
|
||||
|
||||
def test_find_by_required_capabilities(self, org_context, profile_rag, profile_terminal):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_terminal)
|
||||
result = org_context.find_best_agent(["rag"])
|
||||
assert result is not None
|
||||
assert result.name == "rag_agent"
|
||||
|
||||
def test_find_exact_capability_match(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
# 两者都有 rag,但 coder 还有 terminal
|
||||
result = org_context.find_best_agent(["rag", "terminal"])
|
||||
assert result is not None
|
||||
assert result.name == "coder_agent"
|
||||
|
||||
def test_find_no_match_returns_none(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
result = org_context.find_best_agent(["nonexistent_capability"])
|
||||
assert result is None
|
||||
|
||||
def test_find_excluded_agents_skipped(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
result = org_context.find_best_agent(["rag"], exclude=["coder_agent"])
|
||||
assert result is not None
|
||||
assert result.name == "rag_agent"
|
||||
|
||||
def test_find_unavailable_agents_skipped(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
org_context.set_availability("coder_agent", False)
|
||||
result = org_context.find_best_agent(["rag", "terminal"])
|
||||
assert result is None # coder is unavailable, rag doesn't have terminal
|
||||
|
||||
def test_find_best_agent_with_load_balancing(self, org_context):
|
||||
low_load = AgentProfile(
|
||||
name="low_load_agent",
|
||||
agent_type="react",
|
||||
capabilities=["rag"],
|
||||
skills=["rag_skill"],
|
||||
current_load=0,
|
||||
)
|
||||
high_load = AgentProfile(
|
||||
name="high_load_agent",
|
||||
agent_type="react",
|
||||
capabilities=["rag"],
|
||||
skills=["rag_skill"],
|
||||
current_load=5,
|
||||
)
|
||||
org_context.register_agent(low_load)
|
||||
org_context.register_agent(high_load)
|
||||
result = org_context.find_best_agent(["rag"])
|
||||
assert result is not None
|
||||
assert result.name == "low_load_agent"
|
||||
|
||||
def test_find_capability_case_insensitive(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
result = org_context.find_best_agent(["RAG"])
|
||||
assert result is not None
|
||||
assert result.name == "rag_agent"
|
||||
|
||||
|
||||
# ---- OrganizationContext: 负载与可用性 ----
|
||||
|
||||
|
||||
class TestOrganizationContextLoadAvailability:
|
||||
"""update_load() 和 set_availability() 测试"""
|
||||
|
||||
def test_update_load_increase(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.update_load("rag_agent", 3)
|
||||
assert org_context.get_agent_profile("rag_agent").current_load == 3
|
||||
|
||||
def test_update_load_decrease(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.update_load("rag_agent", 5)
|
||||
org_context.update_load("rag_agent", -2)
|
||||
assert org_context.get_agent_profile("rag_agent").current_load == 3
|
||||
|
||||
def test_update_load_never_below_zero(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.update_load("rag_agent", -10)
|
||||
assert org_context.get_agent_profile("rag_agent").current_load == 0
|
||||
|
||||
def test_update_load_nonexistent_no_error(self, org_context):
|
||||
org_context.update_load("nonexistent", 1) # should not raise
|
||||
|
||||
def test_set_availability(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.set_availability("rag_agent", False)
|
||||
assert org_context.get_agent_profile("rag_agent").availability is False
|
||||
org_context.set_availability("rag_agent", True)
|
||||
assert org_context.get_agent_profile("rag_agent").availability is True
|
||||
|
||||
def test_set_availability_nonexistent_no_error(self, org_context):
|
||||
org_context.set_availability("nonexistent", False) # should not raise
|
||||
|
||||
|
||||
# ---- OrganizationContext: from_agent_pool ----
|
||||
|
||||
|
||||
class TestOrganizationContextFromPool:
|
||||
"""from_agent_pool() 测试"""
|
||||
|
||||
def test_from_agent_pool_builds_context(self):
|
||||
"""从 AgentPool + SkillRegistry 构建 OrganizationContext"""
|
||||
skill_registry = SkillRegistry()
|
||||
skill_config = SkillConfig(
|
||||
name="my_skill",
|
||||
agent_type="react",
|
||||
capabilities=["rag", "search"],
|
||||
execution_mode="react",
|
||||
llm={"model": "gpt-4"},
|
||||
max_concurrency=2,
|
||||
prompt={"identity": "Test"},
|
||||
)
|
||||
skill = Skill(config=skill_config)
|
||||
skill_registry.register(skill)
|
||||
|
||||
# Mock agent_pool
|
||||
class FakeAgentPool:
|
||||
def list_agents(self):
|
||||
return [{"name": "my_skill", "agent_type": "react"}]
|
||||
|
||||
ctx = OrganizationContext.from_agent_pool(FakeAgentPool(), skill_registry)
|
||||
profile = ctx.get_agent_profile("my_skill")
|
||||
assert profile is not None
|
||||
assert profile.agent_type == "react"
|
||||
assert "rag" in profile.capabilities
|
||||
assert "search" in profile.capabilities
|
||||
assert profile.execution_mode == "react"
|
||||
assert profile.model == "gpt-4"
|
||||
assert profile.max_concurrency == 2
|
||||
|
||||
def test_from_agent_pool_none_graceful(self):
|
||||
"""agent_pool 或 skill_registry 为 None 时返回空上下文"""
|
||||
ctx = OrganizationContext.from_agent_pool(None, SkillRegistry())
|
||||
assert ctx.list_agents() == []
|
||||
|
||||
class FakePool:
|
||||
def list_agents(self):
|
||||
return []
|
||||
|
||||
ctx = OrganizationContext.from_agent_pool(FakePool(), None)
|
||||
assert ctx.list_agents() == []
|
||||
|
||||
def test_from_agent_pool_agent_not_in_registry(self):
|
||||
"""Agent 不在 skill_registry 中时使用默认值"""
|
||||
skill_registry = SkillRegistry()
|
||||
|
||||
class FakeAgentPool:
|
||||
def list_agents(self):
|
||||
return [{"name": "unknown_agent", "agent_type": "direct"}]
|
||||
|
||||
ctx = OrganizationContext.from_agent_pool(FakeAgentPool(), skill_registry)
|
||||
profile = ctx.get_agent_profile("unknown_agent")
|
||||
assert profile is not None
|
||||
assert profile.agent_type == "direct"
|
||||
assert profile.capabilities == []
|
||||
assert profile.execution_mode == "react" # default
|
||||
assert profile.model == "default"
|
||||
|
||||
|
||||
# ---- AgentDiscovery ----
|
||||
|
||||
|
||||
class TestAgentDiscoveryByCapability:
|
||||
"""discover_by_capability() 测试"""
|
||||
|
||||
def test_discover_by_capability(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.discover_by_capability(["rag"])
|
||||
names = {p.name for p in result}
|
||||
assert names == {"rag_agent", "coder_agent"}
|
||||
|
||||
def test_discover_by_capability_no_match(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.discover_by_capability(["nonexistent"])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestAgentDiscoveryByMode:
|
||||
"""discover_by_execution_mode() 测试"""
|
||||
|
||||
def test_discover_by_execution_mode(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.discover_by_execution_mode("rewoo")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "coder_agent"
|
||||
|
||||
def test_discover_by_execution_mode_no_match(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.discover_by_execution_mode("plan_exec")
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestAgentDiscoveryAvailable:
|
||||
"""discover_available() 测试"""
|
||||
|
||||
def test_discover_available(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
org_context.set_availability("coder_agent", False)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.discover_available()
|
||||
names = {p.name for p in result}
|
||||
assert names == {"rag_agent"}
|
||||
|
||||
|
||||
class TestAgentDiscoveryRecommend:
|
||||
"""recommend_agent() 测试"""
|
||||
|
||||
def test_recommend_with_preferred_mode(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.recommend_agent(["rag"], preferred_mode="rewoo")
|
||||
assert result is not None
|
||||
assert result.name == "coder_agent"
|
||||
|
||||
def test_recommend_without_preferred_mode(self, org_context, profile_rag, profile_coder):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.register_agent(profile_coder)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.recommend_agent(["rag"])
|
||||
assert result is not None
|
||||
# Both have rag, should pick lower load
|
||||
assert result.current_load == 0
|
||||
|
||||
def test_recommend_fallback_when_no_capability_match(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.recommend_agent(["nonexistent"])
|
||||
# Falls back to any available agent
|
||||
assert result is not None
|
||||
assert result.name == "rag_agent"
|
||||
|
||||
def test_recommend_returns_none_when_no_available(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
org_context.set_availability("rag_agent", False)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
result = discovery.recommend_agent(["rag"])
|
||||
assert result is None
|
||||
|
||||
def test_recommend_preferred_mode_no_match_uses_any_match(self, org_context, profile_rag):
|
||||
org_context.register_agent(profile_rag)
|
||||
discovery = AgentDiscovery(org_context)
|
||||
# rag_agent has react mode, but we prefer plan_exec
|
||||
result = discovery.recommend_agent(["rag"], preferred_mode="plan_exec")
|
||||
# No plan_exec match, but still has capability match
|
||||
assert result is not None
|
||||
assert result.name == "rag_agent"
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
"""Tests for U8: Soul Dynamic Evolution — SOUL 动态进化与版本追踪."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||||
from agentkit.evolution.reflector import Reflection, Reflector
|
||||
from agentkit.memory.profile import MemoryStore
|
||||
from agentkit.tools.memory_tool import MemoryTool
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path: Path) -> MemoryStore:
|
||||
return MemoryStore(base_dir=tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool(store: MemoryStore) -> MemoryTool:
|
||||
return MemoryTool(memory_store=store)
|
||||
|
||||
|
||||
def _make_task(task_id: str = "test-001") -> TaskMessage:
|
||||
return TaskMessage(
|
||||
task_id=task_id,
|
||||
agent_name="evolving_agent",
|
||||
task_type="echo",
|
||||
priority=0,
|
||||
input_data={"query": "hello"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
|
||||
return TaskResult(
|
||||
task_id="test-001",
|
||||
agent_name="evolving_agent",
|
||||
status=status,
|
||||
output_data={"key": "value"},
|
||||
error_message=None,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={"elapsed_seconds": 5.0},
|
||||
)
|
||||
|
||||
|
||||
class LowQualityReflector(Reflector):
|
||||
"""总是产生低质量结果和改进建议的 Reflector."""
|
||||
|
||||
async def reflect(self, task, result):
|
||||
return Reflection(
|
||||
task_id=task.task_id,
|
||||
agent_name=result.agent_name,
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=["Low quality score indicates potential issues"],
|
||||
suggestions=["Consider prompt optimization for this task type"],
|
||||
)
|
||||
|
||||
|
||||
class HighQualityReflector(Reflector):
|
||||
"""总是产生高质量结果的 Reflector."""
|
||||
|
||||
async def reflect(self, task, result):
|
||||
return Reflection(
|
||||
task_id=task.task_id,
|
||||
agent_name=result.agent_name,
|
||||
outcome="success",
|
||||
quality_score=0.8,
|
||||
patterns=["fast_execution"],
|
||||
insights=[],
|
||||
suggestions=[],
|
||||
)
|
||||
|
||||
|
||||
class LowQualityNoSuggestionsReflector(Reflector):
|
||||
"""低质量但没有建议的 Reflector."""
|
||||
|
||||
async def reflect(self, task, result):
|
||||
return Reflection(
|
||||
task_id=task.task_id,
|
||||
agent_name=result.agent_name,
|
||||
outcome="failure",
|
||||
quality_score=0.2,
|
||||
patterns=["slow_execution"],
|
||||
insights=["Low quality"],
|
||||
suggestions=[],
|
||||
)
|
||||
|
||||
|
||||
# ── MemoryTool update_soul action 测试 ──────────────────────
|
||||
|
||||
|
||||
class TestMemoryToolUpdateSoul:
|
||||
"""MemoryTool update_soul 操作测试."""
|
||||
|
||||
async def test_basic_update_increments_version(self, tool: MemoryTool, store: MemoryStore):
|
||||
"""基本更新会递增版本号."""
|
||||
# 初始化 SOUL
|
||||
store.get_file("soul").write("## 身份\n我是助手")
|
||||
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section="性格",
|
||||
content="更加耐心",
|
||||
)
|
||||
assert result["success"] is True
|
||||
assert result["version"] == 2
|
||||
|
||||
# 验证版本 section
|
||||
version_content = store.get_file("soul").read_section("版本")
|
||||
assert "版本: 2" in version_content
|
||||
|
||||
async def test_creates_version_section_if_missing(self, tool: MemoryTool, store: MemoryStore):
|
||||
"""如果不存在版本 section 则创建."""
|
||||
store.get_file("soul").write("## 身份\n我是助手")
|
||||
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section="性格",
|
||||
content="友好",
|
||||
)
|
||||
assert result["success"] is True
|
||||
assert result["version"] == 2
|
||||
|
||||
# 版本 section 应该存在
|
||||
sections = store.get_file("soul").list_sections()
|
||||
assert "版本" in sections
|
||||
|
||||
async def test_adds_update_history_entry(self, tool: MemoryTool, store: MemoryStore):
|
||||
"""更新历史条目被正确添加."""
|
||||
store.get_file("soul").write("## 身份\n我是助手")
|
||||
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section="性格",
|
||||
content="更加耐心",
|
||||
reason="用户反馈需要更耐心",
|
||||
)
|
||||
assert result["success"] is True
|
||||
|
||||
history_content = store.get_file("soul").read_section("更新历史")
|
||||
assert "v2" in history_content
|
||||
assert "性格" in history_content
|
||||
assert "用户反馈需要更耐心" in history_content
|
||||
|
||||
async def test_history_limited_to_10_entries(self, tool: MemoryTool, store: MemoryStore):
|
||||
"""更新历史最多保留 10 条."""
|
||||
store.get_file("soul").write("## 身份\n我是助手")
|
||||
|
||||
# 执行 12 次更新
|
||||
for i in range(12):
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section=f"section_{i}",
|
||||
content=f"content_{i}",
|
||||
)
|
||||
assert result["success"] is True
|
||||
|
||||
history_content = store.get_file("soul").read_section("更新历史")
|
||||
lines = [line for line in history_content.strip().split("\n") if line.strip()]
|
||||
assert len(lines) <= 10
|
||||
|
||||
async def test_requires_section_and_content(self, tool: MemoryTool, store: MemoryStore):
|
||||
"""缺少 section 或 content 时返回错误."""
|
||||
store.get_file("soul").write("## 身份\n我是助手")
|
||||
|
||||
# 缺少 section
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
content="内容",
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "section" in result.get("error", "").lower()
|
||||
|
||||
# 缺少 content
|
||||
result = await tool.execute(
|
||||
action="update_soul",
|
||||
file="soul",
|
||||
section="性格",
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "content" in result.get("error", "").lower()
|
||||
|
||||
async def test_invalid_action_still_rejected(self, tool: MemoryTool):
|
||||
"""无效 action 仍然被拒绝."""
|
||||
result = await tool.execute(action="delete_everything", file="soul")
|
||||
assert result["success"] is False
|
||||
assert "Unknown action" in result.get("error", "")
|
||||
|
||||
|
||||
# ── EvolutionMixin.evolve_soul 测试 ──────────────────────────
|
||||
|
||||
|
||||
class TestEvolveSoul:
|
||||
"""EvolutionMixin.evolve_soul 测试."""
|
||||
|
||||
async def test_no_update_when_fewer_than_3_reflections(self, store: MemoryStore):
|
||||
"""少于 3 次同类反思时不触发 soul 更新."""
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 只调用 2 次,不够 3 次阈值
|
||||
for _ in range(2):
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
async def test_triggers_update_when_3_same_category_reflections(self, store: MemoryStore):
|
||||
"""同类反思累积 >= 3 次时触发 soul 更新."""
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
# 前 2 次不触发
|
||||
for _ in range(2):
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
|
||||
# 第 3 次触发
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is True
|
||||
|
||||
# 验证 SOUL 被更新了
|
||||
soul_content = store.get_file("soul").read()
|
||||
assert "slow_execution" in soul_content
|
||||
|
||||
async def test_no_update_without_memory_store(self):
|
||||
"""没有 memory_store 时不触发更新."""
|
||||
reflector = LowQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=None)
|
||||
assert updated is False
|
||||
|
||||
async def test_no_update_when_quality_score_above_threshold(self, store: MemoryStore):
|
||||
"""quality_score >= 0.5 时不触发更新."""
|
||||
reflector = HighQualityReflector()
|
||||
mixin = EvolutionMixin(reflector=reflector)
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
updated = await mixin.evolve_soul(task, result, memory_store=store)
|
||||
assert updated is False
|
||||
Loading…
Reference in New Issue