225 lines
7.9 KiB
Python
225 lines
7.9 KiB
Python
"""Semantic Router — Embedding-based intent routing as Layer 1.5.
|
|
|
|
Uses pre-computed skill embeddings for zero-cost semantic matching,
|
|
inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification)
|
|
in CostAwareRouter.
|
|
|
|
Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md
|
|
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from agentkit.memory.embedder import Embedder, EmbeddingCache
|
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class SemanticRouteResult:
|
|
"""Result of semantic routing."""
|
|
|
|
confidence: str # "high" | "medium" | "low"
|
|
skill_name: str | None
|
|
similarity: float
|
|
|
|
|
|
class SkillEmbeddingIndex:
    """Pre-computed embedding index for registered skills.

    Embeddings are computed when a skill is registered or updated and are
    kept in an in-memory dict. Query-time search is an O(n) cosine
    similarity scan, which is fast for <100 skills with 1024-1536 dim
    vectors.
    """

    def __init__(self, embedder: Embedder):
        """Create an empty index.

        Args:
            embedder: Object whose awaitable ``embed(text)`` produces the
                embedding vector stored for each skill.
        """
        self._embedder = embedder
        # skill_name → (embedding, source_text)
        self._index: dict[str, tuple[list[float], str]] = {}

    async def build(self, skill_registry: Any) -> None:
        """Build the index from every skill returned by the registry.

        Args:
            skill_registry: Object exposing ``list_skills()``; each returned
                skill must expose ``config.name``. ``None`` is a no-op.
        """
        if skill_registry is None:
            return
        # Skills are embedded one at a time, in registry order.
        for skill in skill_registry.list_skills():
            await self.update_skill(skill.config.name, skill)

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """Re-embed a single skill (on registration/update).

        Embedding failures are logged and swallowed so one bad skill does
        not abort indexing; in that case any entry already stored under
        ``skill_name`` is left untouched.

        Args:
            skill_name: Key under which the embedding is stored.
            skill: Skill object (or bare config) to derive source text from.
        """
        source_text = self._build_source_text(skill)
        try:
            embedding = await self._embedder.embed(source_text)
            self._index[skill_name] = (embedding, source_text)
        except Exception as e:
            # Best-effort: routing still works with the remaining skills.
            logger.warning("Failed to embed skill '%s': %s", skill_name, e)

    def remove_skill(self, skill_name: str) -> None:
        """Remove a skill from the index; unknown names are ignored."""
        self._index.pop(skill_name, None)

    async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]:
        """Search for skills matching the query embedding.

        Args:
            query_embedding: Embedding vector of the query.
            top_k: Maximum number of results; values <= 0 yield no results.

        Returns:
            List of (skill_name, similarity) sorted by similarity descending.
        """
        # A negative top_k would otherwise slice "all but the last k".
        if top_k <= 0 or not self._index:
            return []

        scored = [
            (skill_name, compute_cosine_similarity(query_embedding, emb))
            for skill_name, (emb, _) in self._index.items()
        ]
        # list.sort is stable, so ties keep their insertion order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k]

    @staticmethod
    def _build_source_text(skill: Any) -> str:
        """Build embedding source text from skill metadata.

        Combines description, intent keywords, intent examples, and
        capability tags, joined with " | ". Falls back to the skill name
        (or "unknown") when none of those produce any text.

        Args:
            skill: Object with a ``config`` attribute, or the config itself.

        Returns:
            Non-empty text to feed to the embedder.
        """
        config = skill.config if hasattr(skill, "config") else skill
        parts: list[str] = []

        # Description
        description = getattr(config, "description", "") or ""
        if description:
            parts.append(description)

        # Intent keywords, then intent examples (rich signal for short queries)
        intent = getattr(config, "intent", None)
        if intent:
            for field_name in ("keywords", "examples"):
                values = getattr(intent, field_name, None)
                if values:
                    parts.append(" ".join(values))

        # Capability tags: plain strings, dicts with a "tag" key, or objects
        # with a ``tag`` attribute. Empty/missing tags are dropped *before*
        # deciding whether to add a part, so they never produce an empty
        # segment or mask the name fallback below.
        tags = []
        for cap in getattr(config, "capabilities", None) or ():
            if isinstance(cap, str):
                tag = cap
            elif isinstance(cap, dict):
                tag = cap.get("tag", "")
            else:
                tag = getattr(cap, "tag", "")
            if tag:
                tags.append(tag)
        if tags:
            parts.append(" ".join(tags))

        # Fallback: use skill name if no other text available. ``or`` also
        # covers a name that is present but None/empty.
        if not parts:
            parts.append(getattr(config, "name", None) or "unknown")

        return " | ".join(parts)

    @property
    def size(self) -> int:
        """Number of skills in the index."""
        return len(self._index)
|
|
|
|
|
|
class SemanticRouter:
    """Embedding-based semantic routing as Layer 1.5.

    Three confidence zones (defaults in parentheses):
    - similarity >= similarity_high (0.85): HIGH → direct skill match, skip Layer 2
    - effective_low <= similarity < similarity_high: MEDIUM → skill hint for Layer 2
    - similarity < effective_low: LOW → no semantic signal, normal routing

    ``effective_low`` equals similarity_low (0.4) for normal queries. Short
    text (<20 chars) uses a lower effective threshold because brief queries
    naturally have lower embedding similarity: similarity_low is reduced by
    0.15 with a floor of 0.25, and is never raised above similarity_low.
    """

    _SHORT_TEXT_THRESHOLD = 20  # chars
    _SHORT_TEXT_DISCOUNT = 0.15  # subtracted from similarity_low for short text
    _SHORT_TEXT_FLOOR = 0.25  # the discount never pushes the threshold below this

    def __init__(
        self,
        embedder: Embedder,
        similarity_high: float = 0.85,
        similarity_low: float = 0.4,
    ):
        """Create a router with an empty skill index.

        Args:
            embedder: Object whose awaitable ``embed(text)`` is used for both
                skill and query embeddings.
            similarity_high: Minimum similarity for a "high" result.
            similarity_low: Minimum similarity for a "medium" result
                (before the short-text adjustment).
        """
        self._embedder = embedder
        self._similarity_high = similarity_high
        self._similarity_low = similarity_low
        self._index = SkillEmbeddingIndex(embedder)
        self._query_cache = EmbeddingCache(max_size=500, ttl=1800)

    async def build_index(self, skill_registry: Any) -> None:
        """Build skill embedding index from registry."""
        await self._index.build(skill_registry)
        logger.info("Semantic router index built: %s skills", self._index.size)

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """Update a single skill's embedding."""
        await self._index.update_skill(skill_name, skill)

    def remove_skill(self, skill_name: str) -> None:
        """Remove a skill from the index."""
        self._index.remove_skill(skill_name)

    @staticmethod
    def _no_signal(similarity: float = 0.0) -> SemanticRouteResult:
        """Return a fresh "low" result that proposes no skill."""
        return SemanticRouteResult(confidence="low", skill_name=None, similarity=similarity)

    def _effective_low(self, query: str) -> float:
        """Return the medium-zone lower bound to apply to ``query``."""
        low = self._similarity_low
        if len(query) >= self._SHORT_TEXT_THRESHOLD:
            return low
        discounted = max(self._SHORT_TEXT_FLOOR, low - self._SHORT_TEXT_DISCOUNT)
        # Clamp so a similarity_low configured below the floor is not
        # accidentally *raised* for short text.
        return min(low, discounted)

    async def route(self, query: str) -> SemanticRouteResult:
        """Route a query using semantic similarity.

        Any exception raised while embedding or searching is logged and
        converted into a "low" result, so this method does not raise for
        those failures.

        Args:
            query: User's input text.

        Returns:
            SemanticRouteResult with confidence, skill_name, and similarity.
            "low" results carry ``skill_name=None``; their similarity is the
            best score found, or 0.0 when no search was performed.
        """
        if self._index.size == 0:
            return self._no_signal()

        if not query or not query.strip():
            return self._no_signal()

        try:
            # Get query embedding (with cache)
            query_embedding = self._query_cache.get(query)
            if query_embedding is None:
                query_embedding = await self._embedder.embed(query)
                self._query_cache.put(query, query_embedding)

            # Search skill index; only the single best match is needed.
            results = await self._index.search(query_embedding, top_k=1)
            if not results:
                return self._no_signal()

            best_skill, best_sim = results[0]

            if best_sim >= self._similarity_high:
                confidence = "high"
            elif best_sim >= self._effective_low(query):
                confidence = "medium"
            else:
                # Below the medium zone: report the score but no skill.
                return self._no_signal(best_sim)

            return SemanticRouteResult(
                confidence=confidence,
                skill_name=best_skill,
                similarity=best_sim,
            )

        except Exception as e:
            # Best-effort layer: fall through to normal routing on failure.
            logger.warning("Semantic routing failed, returning low confidence: %s", e)
            return self._no_signal()
|