"""Shared skill routing logic for GUI and CLI chat.

Extracts the duplicated skill routing, @skill: prefix parsing,
and prompt assembly into a single module used by both chat routes.
Also provides CostAwareRouter, a three-layer cost-aware router built on
top of resolve_skill_routing.
"""

from __future__ import annotations

import json
import logging
import math
import re
from dataclasses import dataclass, field
from typing import Any

logger = logging.getLogger(__name__)

# Strict validation: 1-64 characters, starting with [a-z0-9] and continuing
# with lowercase letters, digits, underscores or hyphens.
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")


def validate_skill_name(name: str) -> str:
    """Return *name* stripped and lower-cased if it is a legal skill name.

    Raises:
        ValueError: if the normalized name does not match ``_SKILL_NAME_RE``.
    """
    candidate = name.strip().lower()
    if _SKILL_NAME_RE.match(candidate):
        return candidate
    message = f"Invalid skill name '{name}': must match [a-z0-9][a-z0-9_-]{{0,63}}"
    raise ValueError(message)


@dataclass
class SkillRoutingResult:
    """Result of skill routing for a user message.

    Built by ``resolve_skill_routing`` and ``CostAwareRouter.route``; it holds
    both the routing decision and the execution parameters (system prompt,
    tools, model, agent name) the caller should run with.
    """

    # Name of the matched skill; None when no skill matched.
    skill_name: str | None = None
    # Config object of the matched skill (this module reads .prompt and .llm from it).
    skill_config: Any = None
    # Tools contributed by the matched skill.
    skill_tools: list = field(default_factory=list)
    # User message after "@skill:" prefix parsing (Layer-0 rule paths store it stripped).
    clean_content: str = ""
    # System prompt to run with: the skill's prompt or the caller's default.
    system_prompt: str | None = None
    # Tools to run with; on a skill match, skill tools merged with agent tools.
    tools: list = field(default_factory=list)
    # Model to run with.
    model: str = "default"
    # Agent to run as: the matched skill name, a Layer-2 agent, or the default agent.
    agent_name: str | None = None
    # True when a skill (or a Layer-2 capability agent) was selected.
    matched: bool = False
    # How the decision was made, e.g. "explicit", the IntentRouter's method,
    # "explicit_skill", "greeting", "chat_mode", "low_complexity", "capability".
    match_method: str | None = None
    # Confidence of the decision (1.0 for explicit and rule-based decisions).
    match_confidence: float = 0.0
    # Transparency level requested by the caller ("SILENT", "VERBOSE" or "TRACE").
    transparency_level: str = "SILENT"
    # Per-layer routing trace entries; CostAwareRouter leaves this empty for "SILENT".
    execution_trace: list[dict] = field(default_factory=list)
    # Complexity score in [0.0, 1.0] from CostAwareRouter (0.0 for Layer-0 decisions).
    complexity: float = 0.0


def parse_skill_prefix(content: str) -> tuple[str | None, str]:
    """Parse an ``@skill:<name>`` prefix from a user message.

    The skill name runs from just after ``@skill:`` up to the first whitespace
    character of any kind (space, tab, newline, U+3000 ideographic space, ...);
    the remainder, stripped, is the cleaned message.

    Returns:
        ``(None, content)`` when *content* does not start with ``@skill:``;
        otherwise ``(skill_name, clean_content)``. ``skill_name`` is ``""`` when
        whitespace immediately follows the prefix (e.g. ``"@skill: foo"``).
    """
    prefix = "@skill:"
    if not content.startswith(prefix):
        return None, content

    # Split on the first whitespace of ANY kind, not just " ": otherwise a name
    # followed by a newline, tab or full-width space swallows the whole message.
    skill_name, *remainder = re.split(r"\s", content[len(prefix):], maxsplit=1)
    clean = remainder[0].strip() if remainder else ""
    return skill_name, clean


def build_skill_system_prompt(skill_config) -> str | None:
    """Assemble a system prompt from the ``prompt`` mapping of a skill config.

    Sections are read in the fixed order identity, context, instructions,
    constraints, output_format. Missing or empty sections are skipped and the
    remaining ones are joined with a blank line.

    Returns:
        The assembled prompt, or ``None`` when there is no config, no prompt
        mapping, or every section is empty.
    """
    if not (skill_config and skill_config.prompt):
        return None

    section_order = ("identity", "context", "instructions", "constraints", "output_format")
    sections = [skill_config.prompt.get(section) for section in section_order]
    filled = [text for text in sections if text]
    if not filled:
        return None
    return "\n\n".join(filled)


# Minimum IntentRouter confidence required to accept its suggested skill.
_MIN_ROUTER_CONFIDENCE = 0.5


def _lookup_skill(skill_registry: Any, skill_name: str) -> tuple[Any, list]:
    """Return ``(config, tools)`` for *skill_name* from *skill_registry*.

    Every attribute is read before returning, so a failure leaves the caller's
    state untouched. Whatever ``skill_registry.get`` raises propagates; a
    ``None`` return surfaces as ``AttributeError``.
    """
    matched_skill = skill_registry.get(skill_name)
    # Copy so later list operations never alias the registry's own list.
    return matched_skill.config, list(matched_skill.tools or [])


def _mark_matched(
    result: SkillRoutingResult,
    skill_name: str,
    skill_config: Any,
    skill_tools: list,
    method: str | None,
    confidence: float,
) -> None:
    """Record a successful skill match on *result*."""
    result.skill_name = skill_name
    result.skill_config = skill_config
    result.skill_tools = skill_tools
    result.matched = True
    result.match_method = method
    result.match_confidence = confidence


def _is_routable(skill: Any) -> bool:
    """Return True if *skill* declares intent keywords for the IntentRouter.

    A skill whose config lacks ``intent.keywords`` is treated as not routable
    instead of aborting routing for every other skill.
    """
    try:
        return bool(skill.config.intent.keywords)
    except AttributeError:
        return False


def _merge_tools(primary: Any, secondary: Any) -> list:
    """Concatenate two tool sequences, keeping the first tool seen for each ``.name``."""
    seen_names: set = set()
    merged = []
    # Unpacking accepts any iterable (e.g. a tuple from list_tools()), unlike list "+".
    for tool in [*primary, *secondary]:
        if tool.name in seen_names:
            continue
        seen_names.add(tool.name)
        merged.append(tool)
    return merged


async def _route_via_intent(
    result: SkillRoutingResult,
    clean_content: str,
    skill_registry: Any,
    intent_router: Any,
    session_id: str,
) -> None:
    """Ask *intent_router* for a skill and record it on *result* on success.

    Best-effort: every failure (registry listing, routing, lookup) is logged and
    leaves *result* unmatched.
    """
    try:
        routable_skills = [s for s in skill_registry.list_skills() if _is_routable(s)]
        if not routable_skills:
            return
        routing_result = await intent_router.route(
            input_data={"content": clean_content},
            skills=routable_skills,
        )
        # "not (x >= t)" rather than "x < t" so a NaN confidence is rejected.
        if not routing_result or not routing_result.confidence >= _MIN_ROUTER_CONFIDENCE:
            return
        skill_name = routing_result.matched_skill
        method = routing_result.method
        confidence = routing_result.confidence
    except Exception as e:
        logger.warning(f"Skill routing failed for session {session_id}: {e}")
        return

    try:
        skill_config, skill_tools = _lookup_skill(skill_registry, skill_name)
    except Exception as e:
        logger.warning(f"Session {session_id}: skill '{skill_name}' found by router but not in registry: {e}")
        return

    _mark_matched(result, skill_name, skill_config, skill_tools, method, confidence)
    logger.info(
        f"Session {session_id}: routed to skill '{skill_name}' "
        f"via {method} (confidence={confidence})"
    )


async def resolve_skill_routing(
    content: str,
    skill_registry: Any,
    intent_router: Any,
    default_tools: list,
    default_system_prompt: str | None,
    default_model: str = "default",
    default_agent_name: str = "default",
    agent_tool_registry: Any = None,
    session_id: str = "",
) -> SkillRoutingResult:
    """Resolve skill routing for a user message.

    This is the shared entry point used by both GUI WebSocket chat and CLI chat.
    Resolution order:

    1. An explicit ``@skill:<name>`` prefix, looked up in *skill_registry*.
    2. Otherwise *intent_router*, over the skills that declare intent keywords;
       its suggestion is accepted at confidence >= 0.5.
    3. Otherwise the caller's defaults.

    On a match, the result uses the skill's prompt (falling back to
    *default_system_prompt*), the skill's tools merged with the agent tools
    (``agent_tool_registry.list_tools()`` if given, else *default_tools*;
    deduplicated by name, skill tools first), the model from the skill's
    ``llm`` config (else *default_model*), and the skill name as agent name.

    Args:
        content: Raw user message, possibly starting with ``@skill:<name>``.
        skill_registry: Registry providing ``get(name)`` and ``list_skills()``.
        intent_router: Router with an async ``route(input_data=..., skills=...)``.
        default_tools: Tools used when no skill matches.
        default_system_prompt: Prompt used when no skill (or skill prompt) applies.
        default_model: Model used when no skill model applies.
        default_agent_name: Agent name used when no skill matches.
        agent_tool_registry: Optional registry whose ``list_tools()`` replaces
            *default_tools* when merging with skill tools.
        session_id: Only used in log messages.

    Returns:
        A SkillRoutingResult with all execution parameters set.
    """
    result = SkillRoutingResult()

    explicit_skill, clean_content = parse_skill_prefix(content)
    result.clean_content = clean_content

    # 1. Explicit "@skill:<name>" reference
    if explicit_skill:
        logger.info(f"Session {session_id}: explicit skill reference: {explicit_skill}")
        if skill_registry:
            try:
                skill_config, skill_tools = _lookup_skill(skill_registry, explicit_skill)
            except Exception as e:
                logger.warning(f"Session {session_id}: explicit skill '{explicit_skill}' not found: {e}")
            else:
                _mark_matched(result, explicit_skill, skill_config, skill_tools, "explicit", 1.0)
                logger.info(f"Session {session_id}: using explicit skill '{explicit_skill}'")

    # 2. IntentRouter when nothing was matched explicitly
    if not result.matched and skill_registry and intent_router:
        await _route_via_intent(result, clean_content, skill_registry, intent_router, session_id)

    # 3. Execution parameters
    if result.matched and result.skill_config:
        result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt
        agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools
        result.tools = _merge_tools(result.skill_tools, agent_tools)
        llm_config = result.skill_config.llm
        result.model = llm_config.get("model", default_model) if llm_config else default_model
        result.agent_name = result.skill_name
    else:
        result.system_prompt = default_system_prompt
        result.tools = default_tools
        result.model = default_model
        result.agent_name = default_agent_name

    return result


# ---------------------------------------------------------------------------
# CostAwareRouter - three-layer cost-aware routing
# ---------------------------------------------------------------------------

# Layer-0 greeting detector: the whole (already stripped) message is a greeting,
# optionally followed by ASCII or full-width (CJK) punctuation such as "！" / "？".
_GREETING_RE = re.compile(
    r"^(你好|hi|hello|hey|嗨|哈喽|早上好|下午好|晚上好|good morning|good afternoon|good evening)\s*[!！.。?？]*$",
    re.IGNORECASE,
)

# Layer-0 small-talk detector (thanks, acknowledgements, goodbyes), with the
# same optional trailing ASCII / full-width punctuation.
_CHAT_MODE_RE = re.compile(
    r"^(谢谢|感谢|thanks|thank you|ok|好的|嗯|对|是|不是|没关系|再见|bye|goodbye)\s*[!！.。?？]*$",
    re.IGNORECASE,
)


# Prompt for CostAwareRouter.quick_classify. It is filled in with
# str.format(content=...), so the literal braces of the JSON example are
# doubled ("{{" / "}}").
# NOTE(review): the user message is inserted verbatim between double quotes;
# quotes inside it are not escaped.
_COMPLEXITY_CLASSIFY_PROMPT = (
    "Assess the complexity of the following user request on a scale of 0.0 to 1.0.\n"
    "0.0 = trivial greeting / simple chat\n"
    "0.3 = single-skill task (e.g. search, translate)\n"
    "0.7 = multi-step or cross-domain task (e.g. market research + competitor analysis)\n"
    "1.0 = highly complex, multi-agent collaboration needed\n\n"
    'User request: "{content}"\n\n'
    'Respond ONLY with a JSON object: {{"complexity": <float>}}'
)


class CostAwareRouter:
    """Three-layer, cost-aware message router.

    Layer 0: rule matching (zero cost) - ``@skill:`` prefix, greetings, small talk.
    Layer 1: quick LLM classification (~100 tokens) - complexity score; medium
             scores are routed through the IntentRouter.
    Layer 2: capability matching / auction (optional) - high-complexity requests
             are delegated to the best agent reported by ``org_context``.

    NOTE: ``auction_enabled`` is only recorded in the Layer 2 trace entry; this
    class contains no auction logic.
    """

    # Complexity assumed when no LLM gateway is configured or classification fails.
    _DEFAULT_COMPLEXITY = 0.5
    # Locates a JSON object embedded in extra text, e.g. a ```json ... ``` fence.
    _JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)

    def __init__(
        self,
        llm_gateway: Any = None,
        model: str = "default",
        org_context: Any = None,
        auction_enabled: bool = False,
        *,
        low_complexity_threshold: float = 0.3,
        high_complexity_threshold: float = 0.7,
    ):
        """Create a router.

        Args:
            llm_gateway: Object with an async ``chat(messages=..., model=...)``
                method used by Layer 1; ``None`` disables classification.
            model: Model name passed to ``llm_gateway.chat``.
            org_context: Optional object with an async
                ``find_best_agent(content)`` method used by Layer 2.
            auction_enabled: Recorded in the Layer 2 trace entry.
            low_complexity_threshold: Scores below this go to the default agent.
            high_complexity_threshold: Scores up to and including this go to the
                IntentRouter; higher scores go to Layer 2.
        """
        self._llm_gateway = llm_gateway
        self._model = model
        self._org_context = org_context
        self._auction_enabled = auction_enabled
        self._low_complexity_threshold = low_complexity_threshold
        self._high_complexity_threshold = high_complexity_threshold

    # -- Shared helpers -------------------------------------------------------

    @staticmethod
    def _finalize(result: Any, trace: list[dict], transparency: str) -> Any:
        """Attach *trace* (hidden when transparency is "SILENT") and return *result*."""
        result.execution_trace = trace if transparency != "SILENT" else []
        result.transparency_level = transparency
        return result

    @staticmethod
    def _default_result(
        clean_content: str,
        default_system_prompt: str | None,
        default_tools: list,
        default_model: str,
        default_agent_name: str,
        *,
        match_method: str,
        match_confidence: float,
        complexity: float,
    ) -> SkillRoutingResult:
        """Build an unmatched result that runs with the caller's defaults."""
        return SkillRoutingResult(
            clean_content=clean_content,
            system_prompt=default_system_prompt,
            tools=default_tools,
            model=default_model,
            agent_name=default_agent_name,
            matched=False,
            match_method=match_method,
            match_confidence=match_confidence,
            complexity=complexity,
        )

    # -- Layer 0: rule-based (zero cost) --------------------------------------

    def _match_layer0(self, content: str) -> tuple[str | None, str]:
        """Apply the Layer 0 rules to *content*.

        Returns:
            ``(match_type, clean_content)`` where match_type is
            ``"explicit_skill"``, ``"greeting"``, ``"chat_mode"`` or ``None``
            (no rule matched). For an explicit skill clean_content has the
            prefix removed; otherwise it is *content* stripped of surrounding
            whitespace.
        """
        # Explicit "@skill:<name>" prefix
        explicit_skill, clean = parse_skill_prefix(content)
        if explicit_skill:
            return "explicit_skill", clean

        stripped = content.strip()
        # Greeting pattern
        if _GREETING_RE.match(stripped):
            return "greeting", stripped
        # Simple small-talk pattern
        if _CHAT_MODE_RE.match(stripped):
            return "chat_mode", stripped
        return None, stripped

    # -- Layer 1: LLM quick classify (~100 tokens) ----------------------------

    @classmethod
    def _parse_complexity(cls, text: Any) -> float:
        """Parse the classifier's reply into a complexity score clamped to [0.0, 1.0].

        Accepts a bare JSON object or one wrapped in extra text such as a
        Markdown code fence. A reply object without a ``complexity`` key yields
        the default score.

        Raises:
            ValueError: no JSON object could be parsed (``JSONDecodeError`` is a
                subclass), the JSON is not an object, or the score is NaN.
            TypeError: the score is not numeric.
        """
        raw = (text or "").strip()
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            # LLMs frequently wrap the object in a ```json fence or add commentary.
            match = cls._JSON_OBJECT_RE.search(raw)
            if match is None:
                raise
            data = json.loads(match.group(0))
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        complexity = float(data.get("complexity", cls._DEFAULT_COMPLEXITY))
        # json.loads accepts NaN; clamping would silently turn it into 1.0.
        if math.isnan(complexity):
            raise ValueError("complexity is NaN")
        return max(0.0, min(1.0, complexity))

    async def quick_classify(self, content: str) -> float:
        """Estimate the complexity (0.0-1.0) of a user request with a cheap LLM call.

        Returns the default medium complexity 0.5 when no LLM gateway is
        configured, or when the call or the reply parsing fails.
        """
        if self._llm_gateway is None:
            return self._DEFAULT_COMPLEXITY

        prompt = _COMPLEXITY_CLASSIFY_PROMPT.format(content=content)
        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model=self._model,
            )
            return self._parse_complexity(response.content)
        except Exception as e:
            logger.warning(f"CostAwareRouter quick_classify failed: {e}")
            return self._DEFAULT_COMPLEXITY

    # -- Layer 2: Capability matching / Auction (optional) ---------------------

    async def _route_layer2(
        self,
        content: str,
        skill_registry: Any,
        intent_router: Any,
        default_tools: list,
        default_system_prompt: str | None,
        default_model: str,
        default_agent_name: str,
        agent_tool_registry: Any = None,
        session_id: str = "",
        complexity: float = 0.0,
        trace: list[dict] | None = None,
    ) -> SkillRoutingResult:
        """Layer 2: route a high-complexity request.

        Asks ``org_context.find_best_agent(content)`` (when available) for an
        agent; on success returns a "capability" match that runs as that agent
        with the default prompt, tools and model. When org_context is missing,
        returns ``None`` or raises, falls back to ``resolve_skill_routing``.
        A trace entry is appended to *trace* when it is not ``None``.
        """
        if self._org_context is not None and hasattr(self._org_context, "find_best_agent"):
            try:
                best_agent = await self._org_context.find_best_agent(content)
                if best_agent is not None:
                    # find_best_agent may return a plain name or an agent-like object.
                    agent_name = (
                        best_agent
                        if isinstance(best_agent, str)
                        else getattr(best_agent, "name", str(best_agent))
                    )
                    result = SkillRoutingResult(
                        clean_content=content,
                        matched=True,
                        match_method="capability",
                        match_confidence=0.8,
                        agent_name=agent_name,
                        model=default_model,
                        system_prompt=default_system_prompt,
                        tools=default_tools,
                        complexity=complexity,
                    )
                    if trace is not None:
                        trace.append({
                            "layer": 2,
                            "method": "capability",
                            "agent_name": agent_name,
                            "complexity": complexity,
                        })
                    return result
            except Exception as e:
                logger.warning(f"CostAwareRouter Layer 2 org_context.find_best_agent failed: {e}")

        # Fallback: IntentRouter via resolve_skill_routing
        result = await resolve_skill_routing(
            content=content,
            skill_registry=skill_registry,
            intent_router=intent_router,
            default_tools=default_tools,
            default_system_prompt=default_system_prompt,
            default_model=default_model,
            default_agent_name=default_agent_name,
            agent_tool_registry=agent_tool_registry,
            session_id=session_id,
        )
        result.complexity = complexity
        if trace is not None:
            trace.append({
                "layer": 2,
                "method": "intent_router_fallback",
                "complexity": complexity,
            })
        return result

    # -- Main entry point -------------------------------------------------------

    async def route(
        self,
        content: str,
        skill_registry: Any,
        intent_router: Any,
        default_tools: list,
        default_system_prompt: str | None,
        default_model: str = "default",
        default_agent_name: str = "default",
        agent_tool_registry: Any = None,
        session_id: str = "",
        transparency: str = "SILENT",
    ) -> SkillRoutingResult:
        """Main entry point of the three-layer cost-aware router.

        Args:
            content: Raw user message.
            skill_registry: Skill registry.
            intent_router: IntentRouter instance.
            default_tools: Default tool list.
            default_system_prompt: Default system prompt.
            default_model: Default model.
            default_agent_name: Default agent name.
            agent_tool_registry: Agent tool registry.
            session_id: Session ID (used in logs).
            transparency: Transparency level (SILENT / VERBOSE / TRACE); any
                value other than "SILENT" exposes the execution trace.

        Returns:
            SkillRoutingResult carrying the routing decision and trace.
        """
        trace: list[dict] = []
        # Keyword arguments shared by resolve_skill_routing and _route_layer2.
        routing_kwargs = dict(
            skill_registry=skill_registry,
            intent_router=intent_router,
            default_tools=default_tools,
            default_system_prompt=default_system_prompt,
            default_model=default_model,
            default_agent_name=default_agent_name,
            agent_tool_registry=agent_tool_registry,
            session_id=session_id,
        )

        # ---- Layer 0: rule-based (zero cost) ----
        match_type, clean_content = self._match_layer0(content)

        if match_type == "explicit_skill":
            result = await resolve_skill_routing(content=content, **routing_kwargs)
            result.match_method = result.match_method or "explicit_skill"
            result.complexity = 0.0
            trace.append({
                "layer": 0,
                "method": "explicit_skill",
                "matched": result.matched,
                "cost": "zero",
            })
            return self._finalize(result, trace, transparency)

        if match_type in ("greeting", "chat_mode"):
            result = self._default_result(
                clean_content,
                default_system_prompt,
                default_tools,
                default_model,
                default_agent_name,
                match_method=match_type,
                match_confidence=1.0,
                complexity=0.0,
            )
            trace.append({
                "layer": 0,
                "method": match_type,
                "matched": False,
                "cost": "zero",
            })
            return self._finalize(result, trace, transparency)

        # ---- Layer 1: LLM quick classify (~100 tokens) ----
        complexity = await self.quick_classify(clean_content)
        trace.append({
            "layer": 1,
            "method": "quick_classify",
            "complexity": complexity,
        })

        # Low complexity -> default agent
        if complexity < self._low_complexity_threshold:
            result = self._default_result(
                clean_content,
                default_system_prompt,
                default_tools,
                default_model,
                default_agent_name,
                match_method="low_complexity",
                match_confidence=1.0 - complexity,
                complexity=complexity,
            )
            trace.append({
                "layer": 1,
                "method": "low_complexity",
                "complexity": complexity,
                "routed_to": "default",
            })
            return self._finalize(result, trace, transparency)

        # Medium complexity -> IntentRouter via resolve_skill_routing
        if complexity <= self._high_complexity_threshold:
            result = await resolve_skill_routing(content=content, **routing_kwargs)
            result.complexity = complexity
            trace.append({
                "layer": 1,
                "method": "intent_router",
                "complexity": complexity,
                "matched": result.matched,
            })
            return self._finalize(result, trace, transparency)

        # ---- Layer 2: capability matching / auction (high complexity) ----
        trace.append({
            "layer": 2,
            "method": "capability_or_auction",
            "complexity": complexity,
            "auction_enabled": self._auction_enabled,
        })
        result = await self._route_layer2(
            content=content,
            complexity=complexity,
            trace=trace,
            **routing_kwargs,
        )
        return self._finalize(result, trace, transparency)