# fischer-agentkit/src/agentkit/chat/semantic_router.py
# (source viewer metadata: 208 lines, 7.2 KiB, Python)

"""Semantic Router — Embedding-based intent routing as Layer 1.5.
Uses pre-computed skill embeddings for zero-cost semantic matching,
inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification)
in CostAwareRouter.
Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md
"""
import logging
from dataclasses import dataclass
from typing import Any
from agentkit.memory.embedder import Embedder, EmbeddingCache
from agentkit.utils.vector_math import compute_cosine_similarity
# Module-level logger: used for index-build summaries (info) and for
# embedding / routing failures (warning) in the classes below.
logger = logging.getLogger(__name__)
@dataclass
class SemanticRouteResult:
"""Result of semantic routing."""
confidence: str # "high" | "medium" | "low"
skill_name: str | None
similarity: float
class SkillEmbeddingIndex:
    """Pre-computed embedding index for registered skills.

    Each skill is embedded once, at registration/update time, from text
    built out of its description, intent keywords and capability tags.
    Query-time search is a linear cosine-similarity scan over every indexed
    skill, which is fast for small registries (<100 skills with
    1024-1536 dim vectors).
    """

    def __init__(self, embedder: Embedder):
        self._embedder = embedder
        # skill_name -> (embedding, source_text that was embedded)
        self._index: dict[str, tuple[list[float], str]] = {}

    async def build(self, skill_registry: Any) -> None:
        """Embed every skill returned by ``skill_registry.list_skills()``.

        A ``None`` registry is a no-op. Each skill is stored under
        ``skill.config.name``. Skills whose embedding fails are logged and
        skipped (see ``update_skill``). Existing index entries are not
        cleared first.
        """
        if skill_registry is None:
            return
        for skill in skill_registry.list_skills():
            await self.update_skill(skill.config.name, skill)

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """(Re-)embed a single skill and store it under ``skill_name``.

        Embedding failures are logged and swallowed, so one bad skill does
        not abort a bulk ``build``. On failure, any previous entry for the
        skill is left in place.
        """
        source_text = self._build_source_text(skill)
        try:
            embedding = await self._embedder.embed(source_text)
        except Exception as e:
            # Deliberate best-effort: a missing skill embedding only weakens
            # semantic routing; it must not break skill registration.
            logger.warning("Failed to embed skill '%s': %s", skill_name, e)
            return
        self._index[skill_name] = (embedding, source_text)

    def remove_skill(self, skill_name: str) -> None:
        """Remove a skill from the index (no-op if it is not indexed)."""
        self._index.pop(skill_name, None)

    async def search(
        self, query_embedding: list[float], top_k: int = 5
    ) -> list[tuple[str, float]]:
        """Rank indexed skills by cosine similarity to ``query_embedding``.

        Args:
            query_embedding: Embedding of the query. Presumably it must have
                the same dimensionality as the skill embeddings; this is not
                checked here.
            top_k: Maximum number of results. Values <= 0 yield an empty list.

        Returns:
            Up to ``top_k`` ``(skill_name, similarity)`` pairs, sorted by
            similarity descending. The sort is stable, so ties keep index
            insertion order.
        """
        if top_k <= 0 or not self._index:
            return []
        scored = [
            (skill_name, compute_cosine_similarity(query_embedding, emb))
            for skill_name, (emb, _source) in self._index.items()
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]

    @staticmethod
    def _build_source_text(skill: Any) -> str:
        """Build the text that is embedded for a skill.

        Joins with " | " the non-empty pieces among:

        - the description;
        - the intent keywords, space-separated;
        - the capability tags, space-separated.

        Capabilities may be plain strings, dicts with a ``"tag"`` key, or
        objects with a ``tag`` attribute. When none of these yield any text,
        falls back to the skill name (or "unknown").

        Accepts either a skill exposing ``.config`` or a config object
        directly.
        """
        config = skill.config if hasattr(skill, "config") else skill
        parts: list[str] = []

        description = getattr(config, "description", "") or ""
        if description:
            parts.append(description)

        intent = getattr(config, "intent", None)
        keywords = getattr(intent, "keywords", None) if intent else None
        if keywords:
            parts.append(" ".join(keywords))

        tags: list[str] = []
        for cap in getattr(config, "capabilities", None) or []:
            if isinstance(cap, str):
                tag = cap
            elif isinstance(cap, dict):
                tag = cap.get("tag", "")
            else:
                tag = getattr(cap, "tag", "")
            # Filter empty/None tags *before* deciding whether a tag part
            # exists. Otherwise an all-empty capability list adds a blank
            # part and suppresses the name fallback below.
            if tag:
                tags.append(tag)
        if tags:
            parts.append(" ".join(tags))

        if not parts:
            # `or` also guards against an explicit None/empty name.
            parts.append(getattr(config, "name", None) or "unknown")
        return " | ".join(parts)

    @property
    def size(self) -> int:
        """Number of skills in the index."""
        return len(self._index)
class SemanticRouter:
    """Embedding-based semantic routing (Layer 1.5).

    The best cosine similarity between the query and any indexed skill is
    mapped onto three confidence zones:

    - similarity >= similarity_high (default 0.85): "high". This is a direct
      skill match, and the caller may skip Layer 2.
    - similarity_low (default 0.6) <= similarity < similarity_high:
      "medium". The skill is returned as a hint for Layer 2.
    - similarity < similarity_low: "low". No skill is returned, and normal
      routing applies.

    Query embeddings are memoised in an ``EmbeddingCache`` keyed by the
    exact query string.
    """

    def __init__(
        self,
        embedder: Embedder,
        similarity_high: float = 0.85,
        similarity_low: float = 0.6,
        *,
        query_cache_size: int = 500,
        query_cache_ttl: int = 1800,
    ):
        """Create a router with an empty skill index.

        Args:
            embedder: Embedder used for both skill texts and queries.
            similarity_high: Inclusive lower bound of the "high" zone.
            similarity_low: Inclusive lower bound of the "medium" zone.
            query_cache_size: ``max_size`` passed to the query EmbeddingCache.
            query_cache_ttl: ``ttl`` passed to the query EmbeddingCache
                (presumably seconds; confirm against EmbeddingCache).
        """
        self._embedder = embedder
        self._similarity_high = similarity_high
        self._similarity_low = similarity_low
        self._index = SkillEmbeddingIndex(embedder)
        self._query_cache = EmbeddingCache(max_size=query_cache_size, ttl=query_cache_ttl)

    async def build_index(self, skill_registry: Any) -> None:
        """Build the skill embedding index from ``skill_registry``."""
        await self._index.build(skill_registry)
        logger.info("Semantic router index built: %d skills", self._index.size)

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """Re-embed a single skill (e.g. on registration or update)."""
        await self._index.update_skill(skill_name, skill)

    def remove_skill(self, skill_name: str) -> None:
        """Remove a skill from the index."""
        self._index.remove_skill(skill_name)

    async def route(self, query: str) -> SemanticRouteResult:
        """Route ``query`` by semantic similarity to the indexed skills.

        This is best-effort. Any failure while embedding the query or
        searching the index is logged and reported as "low" confidence, so
        routing can fall through to the next layer.

        Args:
            query: User's input text.

        Returns:
            SemanticRouteResult with confidence, skill_name and similarity.
            ``skill_name`` is None whenever confidence is "low".
        """
        if self._index.size == 0 or not query or not query.strip():
            # Nothing to match against, or nothing meaningful to embed:
            # skip the embedder call entirely.
            return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)

        try:
            query_embedding = self._query_cache.get(query)
            if query_embedding is None:
                query_embedding = await self._embedder.embed(query)
                self._query_cache.put(query, query_embedding)
            results = await self._index.search(query_embedding, top_k=1)
        except Exception as e:
            # Deliberate boundary: semantic routing must never break the
            # surrounding router, so degrade to "no signal".
            logger.warning("Semantic routing failed, returning low confidence: %s", e)
            return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)

        if not results:
            return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
        best_skill, best_sim = results[0]
        return self._classify(best_skill, best_sim)

    def _classify(self, skill_name: str, similarity: float) -> SemanticRouteResult:
        """Map the best match's similarity onto a confidence zone."""
        if similarity >= self._similarity_high:
            return SemanticRouteResult(confidence="high", skill_name=skill_name, similarity=similarity)
        if similarity >= self._similarity_low:
            return SemanticRouteResult(confidence="medium", skill_name=skill_name, similarity=similarity)
        # Below the low threshold the match is not trusted: report the score
        # for diagnostics, but select no skill.
        return SemanticRouteResult(confidence="low", skill_name=None, similarity=similarity)