fischer-agentkit/src/agentkit/router/intent.py

201 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""IntentRouter - 两级意图路由:关键词匹配 → LLM 分类"""
import json
import logging
from dataclasses import dataclass
from typing import Any
from agentkit.llm.gateway import LLMGateway
from agentkit.skills.base import Skill
logger = logging.getLogger(__name__)
@dataclass
class RoutingResult:
    """Outcome of routing one input to a Skill.

    Attributes:
        matched_skill: Name of the Skill that was selected.
        method: How the Skill was selected: "keyword" or "llm".
        confidence: 1.0 for keyword matches; a value in 0.0-1.0 for LLM
            classification.
    """

    matched_skill: str  # name of the selected Skill
    method: str  # "keyword" or "llm"
    confidence: float  # 1.0 for keyword matches, 0.0-1.0 for LLM results
class IntentRouter:
    """Two-level intent router: keyword matching first, then LLM classification.

    Level 1: keyword matching (zero cost, ~0 ms).
    Level 2: LLM classification (fallback, ~200 tokens).
    """

    def __init__(self, llm_gateway: LLMGateway | None = None, model: str = "default"):
        """Store the optional LLM gateway and the model name used for fallback.

        Args:
            llm_gateway: Gateway used for Level 2 classification. When None,
                routing fails if no keyword matches.
            model: Model identifier passed to the gateway's ``chat`` call.
        """
        self._llm_gateway = llm_gateway
        self._model = model

    async def route(
        self,
        input_data: dict[str, Any],
        skills: list[Skill],
    ) -> RoutingResult:
        """Route the input to the best-matching Skill.

        Args:
            input_data: User input data.
            skills: Candidate Skills.

        Returns:
            A RoutingResult with the matched Skill name, the matching method
            and the confidence.

        Raises:
            ValueError: If ``skills`` is empty, or the LLM names a Skill that
                is not among the candidates.
            RuntimeError: If keyword matching fails and no LLM gateway is
                configured.
        """
        if not skills:
            raise ValueError("Skill list cannot be empty")

        # A single candidate needs no matching at all.
        if len(skills) == 1:
            return RoutingResult(
                matched_skill=skills[0].name,
                method="keyword",
                confidence=1.0,
            )

        # Level 1: keyword matching.
        keyword_result = self._match_keywords(input_data, skills)
        if keyword_result is not None:
            # Lazy %-args: the message is only formatted when DEBUG is enabled.
            logger.debug(
                "Keyword match: skill=%s, confidence=%s",
                keyword_result.matched_skill,
                keyword_result.confidence,
            )
            return keyword_result

        # Level 2: LLM classification.
        return await self._classify_with_llm(input_data, skills)

    def _match_keywords(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> RoutingResult | None:
        """Level 1: case-insensitive keyword matching.

        All string values in ``input_data`` (including nested ones) are joined
        into one lower-cased text, and each Skill's ``config.intent.keywords``
        is searched in it as a substring. The first Skill (in list order) with
        a matching keyword wins.

        Returns:
            A RoutingResult with method "keyword" and confidence 1.0, or None
            when nothing matches.
        """
        combined_text = " ".join(self._extract_string_values(input_data)).lower()
        if not combined_text.strip():
            return None

        for skill in skills:
            # `or ()` tolerates a Skill whose keywords are unset (None).
            for keyword in skill.config.intent.keywords or ():
                # A blank keyword is a substring of every text and would make
                # this Skill match unconditionally, so it is skipped.
                if not isinstance(keyword, str) or not keyword.strip():
                    continue
                if keyword.lower() in combined_text:
                    return RoutingResult(
                        matched_skill=skill.name,
                        method="keyword",
                        confidence=1.0,
                    )
        return None

    async def _classify_with_llm(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> RoutingResult:
        """Level 2: ask the LLM which Skill matches best.

        Builds a prompt listing every Skill's name, description and examples,
        sends it as a single user message and parses the reply.

        Raises:
            RuntimeError: If no LLM gateway is configured.
            ValueError: If the reply does not name a candidate Skill.
        """
        if self._llm_gateway is None:
            raise RuntimeError(
                "Keyword matching failed and no LLM Gateway configured for fallback"
            )

        prompt = self._build_classification_prompt(input_data, skills)
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model=self._model,
        )
        return self._parse_llm_response(response.content, skills)

    def _build_classification_prompt(
        self, input_data: dict[str, Any], skills: list[Skill]
    ) -> str:
        """Build the classification prompt sent to the LLM."""
        skill_descriptions = []
        for i, skill in enumerate(skills, 1):
            desc = f"{i}. {skill.name}: {skill.config.intent.description}"
            examples = skill.config.intent.examples
            if examples:
                desc += f"\n Examples: {', '.join(examples)}"
            skill_descriptions.append(desc)
        skills_block = "\n".join(skill_descriptions)

        return (
            "You are an intent classifier. Given the user input, determine which skill best matches.\n"
            "\n"
            "Available skills:\n"
            f"{skills_block}\n"
            "\n"
            f"User input: {input_data}\n"
            "\n"
            'Respond in JSON format:\n'
            '{"skill": "skill_name", "confidence": 0.9}'
        )

    def _parse_llm_response(
        self, content: str, skills: list[Skill]
    ) -> RoutingResult:
        """Parse the LLM reply into a RoutingResult.

        The reply is first parsed as a JSON object with "skill" and
        "confidence" keys. If it is not valid JSON, is not a JSON object, or
        its confidence cannot be converted to a float, the Skill name is
        searched for in the raw text instead and the confidence defaults to
        0.5. Parsed confidences are clamped to the 0.0-1.0 range.

        Raises:
            ValueError: If the resulting Skill name is not a candidate.
        """
        valid_names = {skill.name for skill in skills}
        # A missing (non-string) reply is treated as empty text.
        text = content if isinstance(content, str) else ""

        skill_name: Any = ""
        confidence = 0.5
        parsed_ok = False

        try:
            data = json.loads(text.strip())
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            data = None

        # Valid JSON that is not an object (list, string, number) has no
        # "skill" key to read; fall through to text extraction.
        if isinstance(data, dict):
            try:
                confidence = self._clamp_confidence(
                    float(data.get("confidence", 0.0))
                )
            except (TypeError, ValueError, OverflowError):
                pass
            else:
                skill_name = data.get("skill", "")
                parsed_ok = True

        if not parsed_ok:
            skill_name = self._extract_skill_name_from_text(text, valid_names)
            confidence = 0.5  # default confidence for text extraction

        # The isinstance check also keeps unhashable values (list/dict) away
        # from the set membership test, which would raise TypeError.
        if not isinstance(skill_name, str) or skill_name not in valid_names:
            raise ValueError(
                f"LLM returned unknown skill '{skill_name}', "
                f"valid skills are: {sorted(valid_names)}"
            )

        return RoutingResult(
            matched_skill=skill_name,
            method="llm",
            confidence=confidence,
        )

    @staticmethod
    def _clamp_confidence(value: float) -> float:
        """Clamp a confidence to the 0.0-1.0 range; NaN becomes 0.0."""
        if value != value:  # NaN is the only float that is not equal to itself
            return 0.0
        return min(1.0, max(0.0, value))

    @staticmethod
    def _extract_skill_name_from_text(
        text: str, valid_names: set[str]
    ) -> str:
        """Find a valid Skill name mentioned in free text.

        Matching is a case-insensitive substring search. When several names
        match (e.g. one name is contained in another), the longest one is
        returned, with ties broken alphabetically, so the result does not
        depend on set iteration order.

        Returns:
            The matched Skill name, or "" when none is mentioned.
        """
        text_lower = text.lower()
        matches = [
            name for name in valid_names if name and name.lower() in text_lower
        ]
        if not matches:
            return ""
        return min(matches, key=lambda name: (-len(name), name))

    @staticmethod
    def _extract_string_values(data: Any) -> list[str]:
        """Recursively collect every string value found in ``data``.

        Dict values and list/tuple items are walked in order; dict keys and
        non-string leaves are ignored.
        """
        results: list[str] = []
        if isinstance(data, str):
            results.append(data)
        elif isinstance(data, dict):
            for value in data.values():
                results.extend(IntentRouter._extract_string_values(value))
        elif isinstance(data, (list, tuple)):
            for item in data:
                results.extend(IntentRouter._extract_string_values(item))
        return results