201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
"""IntentRouter - 两级意图路由:关键词匹配 → LLM 分类"""
|
||
|
||
import json
import logging
import math
from dataclasses import dataclass
from typing import Any

from agentkit.llm.gateway import LLMGateway
from agentkit.skills.base import Skill
|
||
|
||
# Module-level logger named after this module; used for routing debug output.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class RoutingResult:
    """Outcome of routing one input to a skill.

    Produced by ``IntentRouter.route``; records which skill was chosen, how the
    choice was made, and how confident the router is in it.
    """

    # Name of the selected Skill.
    matched_skill: str
    # Routing level that produced the match: "keyword" or "llm".
    method: str
    # 1.0 for keyword matches; expected to lie within 0.0-1.0 for LLM matches.
    confidence: float
|
||
|
||
|
||
class IntentRouter:
    """Two-level intent router: keyword matching first, then LLM classification.

    Level 1: keyword matching (no LLM call; author's estimate ~0ms).
    Level 2: LLM classification (fallback; author's estimate ~200 tokens).
    """

    # Confidence reported when the skill name had to be recovered from free
    # text, or when the LLM's confidence value could not be interpreted.
    _TEXT_FALLBACK_CONFIDENCE = 0.5

    def __init__(self, llm_gateway: LLMGateway | None = None, model: str = "default"):
        """Create a router.

        Args:
            llm_gateway: Gateway used for the Level 2 fallback. When None, an
                input that no keyword matches raises RuntimeError.
            model: Model name passed to ``llm_gateway.chat``.
        """
        self._llm_gateway = llm_gateway
        self._model = model

    async def route(
        self,
        input_data: dict[str, Any],
        skills: list[Skill],
    ) -> RoutingResult:
        """Route the input to the best-matching Skill.

        Args:
            input_data: User input data; all (nested) string values are scanned
                for keywords.
            skills: Candidate Skills, checked in list order for keyword matches.

        Returns:
            RoutingResult with the matched Skill name, the matching method and
            the confidence. A single candidate is returned directly with
            method "keyword" and confidence 1.0.

        Raises:
            ValueError: If ``skills`` is empty, or the LLM names a skill that is
                not among the candidates.
            RuntimeError: If no keyword matches and no LLM Gateway is configured.
        """
        if not skills:
            raise ValueError("Skill list cannot be empty")

        # With only one candidate there is nothing to choose between.
        if len(skills) == 1:
            return RoutingResult(
                matched_skill=skills[0].name,
                method="keyword",
                confidence=1.0,
            )

        # Level 1: keyword matching
        keyword_result = self._match_keywords(input_data, skills)
        if keyword_result is not None:
            logger.debug(
                "Keyword match: skill=%s, confidence=%s",
                keyword_result.matched_skill,
                keyword_result.confidence,
            )
            return keyword_result

        # Level 2: LLM classification
        return await self._classify_with_llm(input_data, skills)

    def _match_keywords(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> RoutingResult | None:
        """Level 1: case-insensitive substring match against skill keywords.

        All string values in ``input_data`` (including nested ones) are joined
        with spaces and lower-cased. The first skill, in list order, that has a
        keyword occurring in that text wins. Returns None when nothing matches.
        """
        combined_text = " ".join(self._extract_string_values(input_data)).lower()
        if not combined_text:
            return None

        for skill in skills:
            for keyword in skill.config.intent.keywords or ():
                needle = keyword.lower()
                # An empty or whitespace-only keyword is a substring of almost
                # any text (values are joined with " "), so it would hijack all
                # routing; skip it. Intentional padding such as " ai " is kept.
                if needle.strip() and needle in combined_text:
                    return RoutingResult(
                        matched_skill=skill.name,
                        method="keyword",
                        confidence=1.0,
                    )

        return None

    async def _classify_with_llm(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> RoutingResult:
        """Level 2: ask the LLM which skill fits best.

        Builds a prompt listing every skill's name, description and examples,
        sends it as a single user message, and parses the reply.

        Raises:
            RuntimeError: If no LLM Gateway was configured.
            ValueError: If the reply names no valid skill.
        """
        if self._llm_gateway is None:
            raise RuntimeError(
                "Keyword matching failed and no LLM Gateway configured for fallback"
            )

        prompt = self._build_classification_prompt(input_data, skills)

        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model=self._model,
        )

        return self._parse_llm_response(response.content, skills)

    def _build_classification_prompt(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> str:
        """Build the LLM classification prompt.

        Skills are numbered from 1. The reply is requested as a JSON object
        with "skill" and "confidence" keys.
        """
        skill_descriptions = []
        for i, skill in enumerate(skills, 1):
            desc = f"{i}. {skill.name}: {skill.config.intent.description}"
            examples = skill.config.intent.examples
            if examples:
                desc += f"\n Examples: {', '.join(examples)}"
            skill_descriptions.append(desc)

        skills_block = "\n".join(skill_descriptions)

        return (
            "You are an intent classifier. Given the user input, determine which skill best matches.\n"
            "\n"
            "Available skills:\n"
            f"{skills_block}\n"
            "\n"
            f"User input: {input_data}\n"
            "\n"
            'Respond in JSON format:\n'
            '{"skill": "skill_name", "confidence": 0.9}'
        )

    def _parse_llm_response(
        self, content: str, skills: list[Skill]
    ) -> RoutingResult:
        """Parse the LLM reply into a RoutingResult.

        If the reply contains a JSON object (optionally wrapped in code fences
        or prose), its "skill" and "confidence" keys are used. The confidence
        is clamped to [0, 1], and uninterpretable values fall back to 0.5.
        Otherwise a valid skill name is searched for in the raw text, with
        confidence 0.5.

        Raises:
            ValueError: If no valid skill name can be determined.
        """
        valid_names = {s.name for s in skills}
        # Guard against gateways that return None content.
        text = content or ""

        data = self._load_json_object(text)
        if data is not None:
            raw_skill: Any = data.get("skill", "")
            confidence = self._coerce_confidence(
                data.get("confidence", 0.0), self._TEXT_FALLBACK_CONFIDENCE
            )
        else:
            raw_skill = self._extract_skill_name_from_text(text, valid_names)
            confidence = self._TEXT_FALLBACK_CONFIDENCE

        skill_name = self._resolve_skill_name(raw_skill, valid_names)
        if skill_name is None:
            raise ValueError(
                f"LLM returned unknown skill '{raw_skill}', "
                f"valid skills are: {sorted(valid_names)}"
            )

        return RoutingResult(
            matched_skill=skill_name,
            method="llm",
            confidence=confidence,
        )

    @staticmethod
    def _load_json_object(text: str) -> dict[str, Any] | None:
        """Decode a JSON object from an LLM reply, or return None.

        Tries the whole stripped text first. It then tries the span from the
        first "{" to the last "}", so replies wrapped in Markdown code fences
        or surrounded by prose still parse. Candidates that decode to a
        non-object (string, number, list) are ignored.
        """
        stripped = text.strip()
        candidates = [stripped]
        start, end = stripped.find("{"), stripped.rfind("}")
        if 0 <= start < end:
            snippet = stripped[start : end + 1]
            if snippet != stripped:
                candidates.append(snippet)

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if isinstance(data, dict):
                return data
        return None

    @staticmethod
    def _coerce_confidence(value: Any, fallback: float) -> float:
        """Convert an LLM-supplied confidence into a float within [0.0, 1.0].

        Values that cannot be converted with ``float`` or are not finite
        (NaN/inf) yield ``fallback``. Finite values are clamped to the range.
        """
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return fallback
        if not math.isfinite(confidence):
            return fallback
        return min(max(confidence, 0.0), 1.0)

    @staticmethod
    def _resolve_skill_name(raw: Any, valid_names: set[str]) -> str | None:
        """Map an LLM-supplied skill name onto one of ``valid_names``.

        An exact match wins, then an exact match after stripping surrounding
        whitespace, then a case-insensitive match if it is unambiguous.
        Returns None when ``raw`` is not a string or matches no valid name.
        """
        if not isinstance(raw, str):
            return None
        if raw in valid_names:
            return raw
        stripped = raw.strip()
        if stripped in valid_names:
            return stripped
        wanted = stripped.casefold()
        matches = [name for name in valid_names if name.casefold() == wanted]
        return matches[0] if len(matches) == 1 else None

    @staticmethod
    def _extract_skill_name_from_text(
        text: str, valid_names: set[str]
    ) -> str:
        """Return the valid skill name that occurs in ``text``, or "".

        Matching is a case-insensitive substring test. Longer names are tried
        first so "code_review" beats "review" when both occur. Ties are broken
        alphabetically, which keeps the result independent of set iteration
        order (string hashing is randomized per process).
        """
        text_lower = text.lower()
        for name in sorted(valid_names, key=lambda n: (-len(n), n)):
            # An empty name would "match" any text; never return it.
            if name and name.lower() in text_lower:
                return name
        return ""

    @staticmethod
    def _extract_string_values(data: Any) -> list[str]:
        """Recursively collect every string value in ``data``, depth-first.

        Descends into dict values (not keys), lists and tuples; other types
        contribute nothing.
        """
        results: list[str] = []
        if isinstance(data, str):
            results.append(data)
        elif isinstance(data, dict):
            for value in data.values():
                results.extend(IntentRouter._extract_string_values(value))
        elif isinstance(data, (list, tuple)):
            for item in data:
                results.extend(IntentRouter._extract_string_values(item))
        return results
|