"""IntentRouter - 两级意图路由:关键词匹配 → LLM 分类""" import json import logging from dataclasses import dataclass from typing import Any from agentkit.llm.gateway import LLMGateway from agentkit.skills.base import Skill logger = logging.getLogger(__name__) @dataclass class RoutingResult: """路由结果""" matched_skill: str # 匹配的 Skill 名称 method: str # "keyword" 或 "llm" confidence: float # 关键词匹配为 1.0,LLM 为 0.0-1.0 class IntentRouter: """两级意图路由:关键词匹配 → LLM 分类 Level 1: 关键词匹配(零成本,~0ms) Level 2: LLM 分类(回退方案,~200 tokens) """ def __init__(self, llm_gateway: LLMGateway | None = None, model: str = "default"): self._llm_gateway = llm_gateway self._model = model async def route( self, input_data: dict[str, Any], skills: list[Skill], ) -> RoutingResult: """将输入路由到最佳匹配的 Skill Args: input_data: 用户输入数据 skills: 候选 Skill 列表 Returns: RoutingResult 包含匹配的 Skill 名称、匹配方法和置信度 Raises: ValueError: 当 skills 列表为空,或 LLM 返回不存在的 Skill 名称时 RuntimeError: 当关键词匹配失败且没有 LLM Gateway 时 """ if not skills: raise ValueError("Skill list cannot be empty") # 只有一个 Skill 时直接返回 if len(skills) == 1: return RoutingResult( matched_skill=skills[0].name, method="keyword", confidence=1.0, ) # Level 1: 关键词匹配 keyword_result = self._match_keywords(input_data, skills) if keyword_result is not None: logger.debug( f"Keyword match: skill={keyword_result.matched_skill}, " f"confidence={keyword_result.confidence}" ) return keyword_result # Level 2: LLM 分类 return await self._classify_with_llm(input_data, skills) def _match_keywords( self, input_data: dict[str, Any], skills: list[Skill] ) -> RoutingResult | None: """Level 1: 关键词匹配 从 input_data 中提取所有字符串值(包括嵌套),对每个 Skill 的 intent.keywords 进行大小写不敏感匹配。 """ text_values = self._extract_string_values(input_data) combined_text = " ".join(text_values).lower() if not combined_text: return None for skill in skills: keywords = skill.config.intent.keywords for keyword in keywords: if keyword.lower() in combined_text: return RoutingResult( matched_skill=skill.name, method="keyword", confidence=1.0, ) return None async def _classify_with_llm( self, input_data: dict[str, Any], skills: list[Skill] ) -> RoutingResult: """Level 2: LLM 分类 构建 prompt 列出所有 Skill 的名称、描述和示例,让 LLM 判断 最佳匹配的 Skill。 """ if self._llm_gateway is None: raise RuntimeError( "Keyword matching failed and no LLM Gateway configured for fallback" ) prompt = self._build_classification_prompt(input_data, skills) response = await self._llm_gateway.chat( messages=[{"role": "user", "content": prompt}], model=self._model, ) return self._parse_llm_response(response.content, skills) def _build_classification_prompt( self, input_data: dict[str, Any], skills: list[Skill] ) -> str: """构建 LLM 分类 prompt""" skill_descriptions = [] for i, skill in enumerate(skills, 1): desc = f"{i}. {skill.name}: {skill.config.intent.description}" examples = skill.config.intent.examples if examples: desc += f"\n Examples: {', '.join(examples)}" skill_descriptions.append(desc) skills_block = "\n".join(skill_descriptions) return ( "You are an intent classifier. Given the user input, determine which skill best matches.\n" "\n" "Available skills:\n" f"{skills_block}\n" "\n" f"User input: {input_data}\n" "\n" 'Respond in JSON format:\n' '{"skill": "skill_name", "confidence": 0.9}' ) def _parse_llm_response( self, content: str, skills: list[Skill] ) -> RoutingResult: """解析 LLM 响应,提取 skill name 和 confidence""" valid_names = {s.name for s in skills} # 尝试 JSON 解析 try: data = json.loads(content.strip()) skill_name = data.get("skill", "") confidence = float(data.get("confidence", 0.0)) except (json.JSONDecodeError, ValueError, TypeError): # JSON 解析失败,尝试从文本中提取 skill name skill_name = self._extract_skill_name_from_text(content, valid_names) confidence = 0.5 # 文本提取时给默认置信度 if skill_name not in valid_names: raise ValueError( f"LLM returned unknown skill '{skill_name}', " f"valid skills are: {sorted(valid_names)}" ) return RoutingResult( matched_skill=skill_name, method="llm", confidence=confidence, ) @staticmethod def _extract_skill_name_from_text( text: str, valid_names: set[str] ) -> str: """从文本中尝试提取有效的 Skill 名称""" text_lower = text.lower() for name in valid_names: if name.lower() in text_lower: return name return "" @staticmethod def _extract_string_values(data: Any) -> list[str]: """递归提取 input_data 中所有字符串值""" results: list[str] = [] if isinstance(data, str): results.append(data) elif isinstance(data, dict): for value in data.values(): results.extend(IntentRouter._extract_string_values(value)) elif isinstance(data, list): for item in data: results.extend(IntentRouter._extract_string_values(item)) return results