feat: P0 production hardening — LLM cache, semantic routing, state persistence
U1: LLM Cache Core (exact + semantic match, InMemory + Redis backends)
U2: Cache integration into LLMGateway with CacheConfig
U3: Semantic Router as Layer 1.5 in CostAwareRouter
U4: UsageStore persistence (Redis Hash + InMemory fallback)
U5: CascadeStateStore persistence (Redis INCR + InMemory TTL)
U6: EvolutionStore interface unification (Protocol + PostgreSQL backend)
U7: Configuration integration + E2E tests
Code review fixes:
- P0: date iteration bug (day>=28), semantic router index never built,
Redis connection leak (per-call → persistent pool)
- P1: cache degradation recovery, semantic_search degradation,
double miss counting, asyncio.Lock for PG init, LIMIT on queries,
__import__ anti-pattern → _utcnow()
- P2: InMemory TTL cleanup, embedding preservation on put(),
data TTL = max(exact_ttl, semantic_ttl)
This commit is contained in:
parent
4707fd00ba
commit
0ccef7be5c
|
|
@ -35,6 +35,7 @@ build/pyinstaller-work/
|
|||
|
||||
# Frontend build artifacts
|
||||
src/agentkit/server/static/
|
||||
**/node_modules/
|
||||
|
||||
# Env
|
||||
.env
|
||||
|
|
|
|||
|
|
@ -5,23 +5,12 @@ server:
|
|||
rate_limit: 60
|
||||
llm:
|
||||
providers:
|
||||
bailian-coding:
|
||||
api_key: ${DASHSCOPE_API_KEY}
|
||||
base_url: https://coding.dashscope.aliyuncs.com/v1
|
||||
test:
|
||||
type: openai
|
||||
models:
|
||||
qwen3.7-plus:
|
||||
alias: default
|
||||
qwen3.6-plus: {}
|
||||
qwen3.5-plus: {}
|
||||
qwen3-max-2026-01-23: {}
|
||||
qwen3-coder-plus:
|
||||
alias: coder
|
||||
qwen3-coder-next: {}
|
||||
kimi-k2.5: {}
|
||||
glm-5: {}
|
||||
glm-4.7: {}
|
||||
MiniMax-M2.5: {}
|
||||
base_url: ''
|
||||
max_tokens: 4096
|
||||
timeout: 120.0
|
||||
api_key: ''
|
||||
model_aliases:
|
||||
default: bailian-coding/qwen3.7-plus
|
||||
coder: bailian-coding/qwen3-coder-plus
|
||||
|
|
|
|||
|
|
@ -0,0 +1,207 @@
|
|||
"""Semantic Router — Embedding-based intent routing as Layer 1.5.
|
||||
|
||||
Uses pre-computed skill embeddings for zero-cost semantic matching,
|
||||
inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification)
|
||||
in CostAwareRouter.
|
||||
|
||||
Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.embedder import Embedder, EmbeddingCache
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SemanticRouteResult:
|
||||
"""Result of semantic routing."""
|
||||
|
||||
confidence: str # "high" | "medium" | "low"
|
||||
skill_name: str | None
|
||||
similarity: float
|
||||
|
||||
|
||||
class SkillEmbeddingIndex:
    """In-memory embedding index over registered skills.

    Each skill is embedded when it is added (``build`` / ``update_skill``)
    and stored together with the text that was embedded. ``search`` performs
    a linear cosine-similarity scan over every stored vector, which the
    original design targets at small registries (<100 skills).
    """

    def __init__(self, embedder: "Embedder"):
        self._embedder = embedder
        # skill_name -> (embedding, source_text)
        self._index: dict[str, tuple[list[float], str]] = {}

    async def build(self, skill_registry: Any) -> None:
        """Embed every skill returned by ``skill_registry.list_skills()``.

        A ``None`` registry is a no-op. Existing entries are overwritten by
        name; entries for skills no longer in the registry are left in place.
        """
        if skill_registry is None:
            return
        for skill in skill_registry.list_skills():
            await self.update_skill(skill.config.name, skill)

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """(Re-)embed a single skill and store it under ``skill_name``.

        Embedding failures are logged and swallowed so that one bad skill
        cannot break index construction; any previous entry is kept.
        """
        source_text = self._build_source_text(skill)
        try:
            embedding = await self._embedder.embed(source_text)
            self._index[skill_name] = (embedding, source_text)
        except Exception as e:
            logger.warning(f"Failed to embed skill '{skill_name}': {e}")

    def remove_skill(self, skill_name: str) -> None:
        """Drop ``skill_name`` from the index; unknown names are ignored."""
        self._index.pop(skill_name, None)

    async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]:
        """Rank indexed skills by cosine similarity to ``query_embedding``.

        Returns:
            Up to ``top_k`` ``(skill_name, similarity)`` pairs, best first.
            An empty list when the index is empty or ``top_k`` is not positive.
        """
        # A non-positive top_k would otherwise slice from the wrong end
        # (e.g. results[:-1] silently drops only the worst match).
        if top_k <= 0 or not self._index:
            return []

        scored = [
            (skill_name, compute_cosine_similarity(query_embedding, emb))
            for skill_name, (emb, _source_text) in self._index.items()
        ]
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k]

    @staticmethod
    def _build_source_text(skill: Any) -> str:
        """Build the text that gets embedded for a skill.

        Joins, in order and separated by ``" | "``: the description, the
        space-joined intent keywords, and the space-joined capability tags.
        ``skill`` may be an object exposing ``.config`` or the config itself.
        Falls back to the skill name (or ``"unknown"``) when nothing else is
        available, so the result is never empty.
        """
        config = skill.config if hasattr(skill, "config") else skill
        parts: list[str] = []

        description = getattr(config, "description", "") or ""
        if description:
            parts.append(description)

        intent = getattr(config, "intent", None)
        keywords = getattr(intent, "keywords", None) if intent else None
        if keywords:
            parts.append(" ".join(keywords))

        capabilities = getattr(config, "capabilities", None)
        if capabilities:
            tags = []
            for cap in capabilities:
                if isinstance(cap, str):
                    tags.append(cap)
                elif isinstance(cap, dict):
                    tags.append(cap.get("tag", ""))
                elif hasattr(cap, "tag"):
                    tags.append(cap.tag)
            # Filter BEFORE deciding whether to append: a list of empty tags
            # must not contribute an empty part (which would also suppress
            # the name fallback below).
            tag_text = " ".join(str(t) for t in tags if t)
            if tag_text:
                parts.append(tag_text)

        # Fallback: use the skill name if no other text is available.
        if not parts:
            parts.append(getattr(config, "name", None) or "unknown")

        return " | ".join(parts)

    @property
    def size(self) -> int:
        """Number of skills currently held in the index."""
        return len(self._index)
|
||||
|
||||
|
||||
class SemanticRouter:
    """Embedding-based semantic routing ("Layer 1.5").

    The best cosine similarity between the query and any indexed skill is
    mapped to one of three confidence zones:

    - similarity >= similarity_high (default 0.85): "high" — the matched
      skill name is returned.
    - similarity_low (default 0.6) <= similarity < similarity_high: "medium"
      — the skill name is returned as a hint.
    - similarity < similarity_low: "low" — no skill name is returned.
    """

    def __init__(
        self,
        embedder: "Embedder",
        similarity_high: float = 0.85,
        similarity_low: float = 0.6,
    ):
        self._embedder = embedder
        self._similarity_high = similarity_high
        self._similarity_low = similarity_low
        self._index = SkillEmbeddingIndex(embedder)
        # Caches query embeddings so repeated queries skip the embedder.
        # NOTE(review): ttl is presumably in seconds — confirm against EmbeddingCache.
        self._query_cache = EmbeddingCache(max_size=500, ttl=1800)

    async def build_index(self, skill_registry: Any) -> None:
        """Build the skill embedding index from ``skill_registry``."""
        await self._index.build(skill_registry)
        logger.info(f"Semantic router index built: {self._index.size} skills")

    async def update_skill(self, skill_name: str, skill: Any) -> None:
        """Re-embed a single skill in the index."""
        await self._index.update_skill(skill_name, skill)

    def remove_skill(self, skill_name: str) -> None:
        """Remove a skill from the index."""
        self._index.remove_skill(skill_name)

    @staticmethod
    def _no_signal() -> "SemanticRouteResult":
        """Return the result used when no semantic signal is available."""
        return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)

    async def route(self, query: str) -> "SemanticRouteResult":
        """Route a query using semantic similarity.

        Args:
            query: User's input text.

        Returns:
            SemanticRouteResult with confidence, skill_name and similarity.
            Never raises: any failure while embedding or searching is logged
            and reported as a "low" confidence result.
        """
        # Nothing to compare against, or nothing to embed.
        if self._index.size == 0 or not query or not query.strip():
            return self._no_signal()

        try:
            # Get query embedding (with cache).
            query_embedding = self._query_cache.get(query)
            if query_embedding is None:
                query_embedding = await self._embedder.embed(query)
                self._query_cache.put(query, query_embedding)

            results = await self._index.search(query_embedding, top_k=1)
            if not results:
                return self._no_signal()

            best_skill, best_sim = results[0]

            if best_sim >= self._similarity_high:
                confidence = "high"
            elif best_sim >= self._similarity_low:
                confidence = "medium"
            else:
                # Below the low threshold the match is not trusted, so the
                # skill name is withheld while the score is still reported.
                return SemanticRouteResult(
                    confidence="low",
                    skill_name=None,
                    similarity=best_sim,
                )

            return SemanticRouteResult(
                confidence=confidence,
                skill_name=best_skill,
                similarity=best_sim,
            )

        except Exception as e:
            logger.warning(f"Semantic routing failed, returning low confidence: {e}")
            return self._no_signal()
|
||||
|
|
@ -6,6 +6,7 @@ and prompt assembly into a single module used by both chat routes.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -21,6 +22,19 @@ logger = logging.getLogger(__name__)
|
|||
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
|
||||
|
||||
class ExecutionMode(enum.Enum):
    """How the downstream should execute this routing result.

    This is the single source of truth for execution path selection.
    The transport layer (portal.py, chat.py) should branch on this
    field instead of string-matching match_method.
    """

    # Zero-cost: direct LLM call, no ReAct loop.
    DIRECT_CHAT = "direct_chat"
    # Default agent ReAct loop with default tools.
    REACT = "react"
    # Skill-matched ReAct with skill tools + prompt.
    SKILL_REACT = "skill_react"
|
||||
|
||||
|
||||
def validate_skill_name(name: str) -> str:
|
||||
"""Validate and normalize a skill name. Raises ValueError on invalid input."""
|
||||
normalized = name.strip().lower()
|
||||
|
|
@ -49,6 +63,7 @@ class SkillRoutingResult:
|
|||
transparency_level: str = "SILENT"
|
||||
execution_trace: list[dict] = field(default_factory=list)
|
||||
complexity: float = 0.0
|
||||
execution_mode: ExecutionMode = ExecutionMode.DIRECT_CHAT
|
||||
|
||||
|
||||
def parse_skill_prefix(content: str) -> tuple[str | None, str]:
|
||||
|
|
@ -88,6 +103,7 @@ async def resolve_skill_routing(
|
|||
default_agent_name: str = "default",
|
||||
agent_tool_registry: Any = None,
|
||||
session_id: str = "",
|
||||
force_skill: str | None = None,
|
||||
) -> SkillRoutingResult:
|
||||
"""Resolve skill routing for a user message.
|
||||
|
||||
|
|
@ -120,6 +136,20 @@ async def resolve_skill_routing(
|
|||
result.skill_name = None
|
||||
result.skill_config = None
|
||||
|
||||
# Try force_skill match (from semantic router high confidence)
|
||||
if not result.matched and force_skill and skill_registry:
|
||||
try:
|
||||
matched_skill = skill_registry.get(force_skill)
|
||||
result.skill_name = force_skill
|
||||
result.skill_config = matched_skill.config
|
||||
result.skill_tools = matched_skill.tools or []
|
||||
result.matched = True
|
||||
result.match_method = "semantic_force"
|
||||
result.match_confidence = 1.0
|
||||
logger.info(f"Session {session_id}: using force-matched skill '{force_skill}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"Session {session_id}: force skill '{force_skill}' not found: {e}")
|
||||
|
||||
# Try IntentRouter if no explicit match
|
||||
if not result.matched and skill_registry and intent_router:
|
||||
skills = skill_registry.list_skills()
|
||||
|
|
@ -205,11 +235,14 @@ async def resolve_skill_routing(
|
|||
|
||||
result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model
|
||||
result.agent_name = result.skill_name
|
||||
result.execution_mode = ExecutionMode.SKILL_REACT
|
||||
else:
|
||||
result.system_prompt = default_system_prompt
|
||||
result.tools = default_tools
|
||||
result.model = default_model
|
||||
result.agent_name = default_agent_name
|
||||
# No skill matched — if we have tools, use ReAct; otherwise direct chat
|
||||
result.execution_mode = ExecutionMode.REACT if default_tools else ExecutionMode.DIRECT_CHAT
|
||||
|
||||
# Append available tools to system prompt so LLM knows what it can call
|
||||
if result.tools:
|
||||
|
|
@ -257,6 +290,14 @@ _CHAT_MODE_RE = re.compile(
|
|||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Simple identity/meta questions — zero-cost direct chat, no skill routing needed
|
||||
_IDENTITY_RE = re.compile(
|
||||
r"^(你是谁|你叫什么|你是什么|你是哪个|who are you|what are you|what's your name"
|
||||
r"|介绍一下你自己|自我介绍|你叫啥|你叫什么名字|你的名字)"
|
||||
r"\s*[??!!.。]*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SENTENCE_SPLIT_RE = re.compile(r'[,。!?;\n,.!?;]')
|
||||
|
||||
|
||||
|
|
@ -319,8 +360,9 @@ class HeuristicClassifier:
|
|||
}
|
||||
|
||||
# 中等复杂度暗示词(简单问题但需思考)
|
||||
# 注意:不包含"怎么",因为"怎么样"是闲聊而非工具需求
|
||||
_MEDIUM_COMPLEXITY_HINTS_CN = {
|
||||
"如何", "怎么", "怎样", "为什么", "什么原因", "区别",
|
||||
"如何", "怎样", "为什么", "什么原因", "区别",
|
||||
"推荐", "建议", "选择", "哪个",
|
||||
}
|
||||
|
||||
|
|
@ -428,6 +470,7 @@ class CostAwareRouter:
|
|||
auction_enabled: bool = False,
|
||||
classifier: str = "heuristic",
|
||||
merged_llm_classify: bool = True,
|
||||
semantic_router: Any = None, # SemanticRouter | None
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._model = model
|
||||
|
|
@ -435,6 +478,7 @@ class CostAwareRouter:
|
|||
self._auction_enabled = auction_enabled
|
||||
self._classifier = classifier
|
||||
self._merged_llm_classify = merged_llm_classify
|
||||
self._semantic_router = semantic_router
|
||||
self._auction_house = AuctionHouse() if auction_enabled else None
|
||||
if classifier not in ("heuristic", "llm"):
|
||||
raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'")
|
||||
|
|
@ -462,6 +506,10 @@ class CostAwareRouter:
|
|||
if _CHAT_MODE_RE.match(stripped):
|
||||
return "chat_mode", stripped
|
||||
|
||||
# 身份/元问题模式("你是谁"等)— 零成本直接对话
|
||||
if _IDENTITY_RE.match(stripped):
|
||||
return "identity", stripped
|
||||
|
||||
return None, stripped
|
||||
|
||||
# -- Layer 1: LLM quick classify (~100 tokens) -------------------------
|
||||
|
|
@ -577,6 +625,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm",
|
||||
match_confidence=0.7,
|
||||
complexity=merged_complexity,
|
||||
execution_mode=ExecutionMode.SKILL_REACT,
|
||||
)
|
||||
# Merge tools
|
||||
agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools
|
||||
|
|
@ -590,6 +639,18 @@ class CostAwareRouter:
|
|||
result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model
|
||||
result.agent_name = skill_hint
|
||||
result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt
|
||||
# Append available tools to system prompt so LLM knows what it can call
|
||||
if result.tools:
|
||||
tools_desc = _build_tools_description(result.tools)
|
||||
tool_instruction = (
|
||||
"\n\n## Tool Usage\n"
|
||||
"You have access to the following tools. When you need to use a tool, "
|
||||
"respond with a tool call in the format specified by the system.\n"
|
||||
"Never make up information or guess answers when you can use a tool to find the answer.\n"
|
||||
"Always prefer using tools over guessing.\n"
|
||||
)
|
||||
if result.system_prompt:
|
||||
result.system_prompt += f"{tool_instruction}\n## Available Tools\n{tools_desc}"
|
||||
logger.info(
|
||||
f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' "
|
||||
f"(complexity={merged_complexity:.2f})"
|
||||
|
|
@ -610,6 +671,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm_low",
|
||||
match_confidence=1.0 - merged_complexity,
|
||||
complexity=merged_complexity,
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
elif merged_complexity > 0.7:
|
||||
# High complexity — delegate to Layer 2
|
||||
|
|
@ -623,6 +685,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm_high",
|
||||
match_confidence=merged_complexity,
|
||||
complexity=merged_complexity,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
else:
|
||||
# Medium complexity, no skill match — default agent
|
||||
|
|
@ -636,6 +699,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm_medium",
|
||||
match_confidence=0.5,
|
||||
complexity=merged_complexity,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default")
|
||||
|
|
@ -649,6 +713,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm_fallback",
|
||||
match_confidence=0.5,
|
||||
complexity=0.5,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default")
|
||||
|
|
@ -662,6 +727,7 @@ class CostAwareRouter:
|
|||
match_method="merged_llm_fallback",
|
||||
match_confidence=0.5,
|
||||
complexity=0.5,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
|
||||
# -- Layer 2: Capability matching / Auction (optional) -----------------
|
||||
|
|
@ -746,6 +812,7 @@ class CostAwareRouter:
|
|||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
complexity=complexity,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.append({
|
||||
|
|
@ -776,6 +843,7 @@ class CostAwareRouter:
|
|||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
complexity=complexity,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.append({
|
||||
|
|
@ -876,7 +944,7 @@ class CostAwareRouter:
|
|||
span.set_attribute("route.target", result.skill_name or "default")
|
||||
return result
|
||||
|
||||
if match_type in ("greeting", "chat_mode"):
|
||||
if match_type in ("greeting", "chat_mode", "identity"):
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
system_prompt=default_system_prompt,
|
||||
|
|
@ -887,6 +955,7 @@ class CostAwareRouter:
|
|||
match_method=match_type,
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 0,
|
||||
|
|
@ -916,7 +985,7 @@ class CostAwareRouter:
|
|||
"complexity": complexity,
|
||||
})
|
||||
|
||||
# Low complexity → default agent
|
||||
# Low complexity → direct chat
|
||||
if complexity < 0.3:
|
||||
result = SkillRoutingResult(
|
||||
clean_content=clean_content,
|
||||
|
|
@ -928,6 +997,7 @@ class CostAwareRouter:
|
|||
match_method="low_complexity",
|
||||
match_confidence=1.0 - complexity,
|
||||
complexity=complexity,
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
|
|
@ -941,6 +1011,59 @@ class CostAwareRouter:
|
|||
span.set_attribute("route.target", "default")
|
||||
return result
|
||||
|
||||
# ---- Layer 1.5: Semantic Router (zero LLM cost) ----
|
||||
skill_hint = None
|
||||
if self._semantic_router is not None and complexity >= 0.3:
|
||||
try:
|
||||
semantic_result = await self._semantic_router.route(clean_content)
|
||||
if semantic_result.confidence == "high" and semantic_result.skill_name:
|
||||
# Direct skill match — skip Layer 2
|
||||
trace.append({
|
||||
"layer": 1.5,
|
||||
"method": "semantic_high",
|
||||
"skill": semantic_result.skill_name,
|
||||
"similarity": round(semantic_result.similarity, 3),
|
||||
"cost": "zero",
|
||||
})
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
force_skill=semantic_result.skill_name,
|
||||
)
|
||||
result.match_method = "semantic_high"
|
||||
result.match_confidence = semantic_result.similarity
|
||||
result.complexity = complexity
|
||||
if result.matched:
|
||||
result.execution_mode = ExecutionMode.SKILL_REACT
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", "semantic_high")
|
||||
span.set_attribute("route.target", result.skill_name or "default")
|
||||
return result
|
||||
elif semantic_result.confidence == "medium" and semantic_result.skill_name:
|
||||
# Pass skill hint to Layer 1.5 merged classify or Layer 2
|
||||
skill_hint = semantic_result.skill_name
|
||||
trace.append({
|
||||
"layer": 1.5,
|
||||
"method": "semantic_medium",
|
||||
"skill_hint": skill_hint,
|
||||
"similarity": round(semantic_result.similarity, 3),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic routing failed, falling through: {e}")
|
||||
trace.append({
|
||||
"layer": 1.5,
|
||||
"method": "semantic_error",
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
# Medium complexity → merged LLM classify or IntentRouter
|
||||
if complexity <= 0.7:
|
||||
if self._merged_llm_classify and self._llm_gateway is not None:
|
||||
|
|
@ -994,7 +1117,7 @@ class CostAwareRouter:
|
|||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = result.complexity or complexity
|
||||
result.complexity = result.complexity if result.complexity > 0 else complexity
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": result.match_method or "merged_llm",
|
||||
|
|
|
|||
|
|
@ -685,7 +685,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
retrieval_config = self._config.memory.get("retrieval", {}) if self._config.memory else {}
|
||||
result = await self._react_engine.execute(
|
||||
messages=user_messages,
|
||||
tools=self._tools if self._tools else None,
|
||||
tools=self.get_tools() or None,
|
||||
model=self._config.llm.get("model", "default") if self._config.llm else "default",
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
|
|
@ -735,7 +735,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
result = await rewoo_engine.execute(
|
||||
messages=user_messages,
|
||||
tools=self._tools if self._tools else None,
|
||||
tools=self.get_tools() or None,
|
||||
model=self._config.llm.get("model", "default") if self._config.llm else "default",
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
|
|
@ -781,7 +781,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
result = await plan_exec_engine.execute(
|
||||
messages=user_messages,
|
||||
tools=self._tools if self._tools else None,
|
||||
tools=self.get_tools() or None,
|
||||
model=self._config.llm.get("model", "default") if self._config.llm else "default",
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
|
|
@ -829,7 +829,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
result = await reflexion_engine.execute(
|
||||
messages=user_messages,
|
||||
tools=self._tools if self._tools else None,
|
||||
tools=self.get_tools() or None,
|
||||
model=self._config.llm.get("model", "default") if self._config.llm else "default",
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
|
|
|
|||
|
|
@ -441,7 +441,14 @@ class ReActEngine:
|
|||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
# Non-dangerous tool: confirmation was for the overall action,
|
||||
# re-execute with skip flag to avoid re-triggering confirmation
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
else:
|
||||
tool_result = {
|
||||
"output": "",
|
||||
|
|
@ -905,7 +912,13 @@ class ReActEngine:
|
|||
finally:
|
||||
pass # No shared state mutation needed
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
# Non-dangerous tool: re-execute with skip flag
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
|
|
@ -1261,7 +1274,13 @@ class ReActEngine:
|
|||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
# Non-dangerous tool: re-execute with skip flag
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from agentkit.evolution.strategy_tuner import StrategyTuner
|
|||
from agentkit.evolution.ab_tester import ABTester
|
||||
from agentkit.evolution.evolution_store import (
|
||||
EvolutionStore,
|
||||
EvolutionStoreProtocol,
|
||||
InMemoryEvolutionStore,
|
||||
PersistentEvolutionStore,
|
||||
create_evolution_store,
|
||||
|
|
@ -30,6 +31,7 @@ __all__ = [
|
|||
"StrategyTuner",
|
||||
"ABTester",
|
||||
"EvolutionStore",
|
||||
"EvolutionStoreProtocol",
|
||||
"PersistentEvolutionStore",
|
||||
"InMemoryEvolutionStore",
|
||||
"create_evolution_store",
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
"""EvolutionStore - 进化日志存储
|
||||
|
||||
提供三种后端实现:
|
||||
- EvolutionStore: 基于外部注入的异步 SQLAlchemy session(原有实现)
|
||||
- PersistentEvolutionStore: 基于 SQLite 的持久化存储
|
||||
- InMemoryEvolutionStore: 基于内存字典的轻量存储(用于测试)
|
||||
提供统一 Protocol 和四种后端实现:
|
||||
- EvolutionStoreProtocol: 统一接口 Protocol(所有后端必须实现)
|
||||
- EvolutionStore: 基于外部注入的异步 SQLAlchemy session(原有实现,仅事件操作)
|
||||
- PersistentEvolutionStore: 基于 SQLite 的持久化存储(完整 Protocol)
|
||||
- InMemoryEvolutionStore: 基于内存字典的轻量存储(完整 Protocol,用于测试)
|
||||
- PostgreSQLEvolutionStore: 基于 PostgreSQL 的异步持久化存储(完整 Protocol,见 pg_store.py)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -13,7 +15,7 @@ import os
|
|||
import time
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from sqlalchemy import create_engine, event as sa_event, select
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
|
@ -30,6 +32,34 @@ from agentkit.evolution.models import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Unified Protocol ──────────────────────────────────────
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class EvolutionStoreProtocol(Protocol):
|
||||
"""进化存储统一接口 Protocol
|
||||
|
||||
所有后端必须实现以下方法。不支持的操作应抛出 NotImplementedError。
|
||||
"""
|
||||
|
||||
async def record(self, event: EvolutionEvent) -> str: ...
|
||||
async def rollback(self, event_id: str) -> bool: ...
|
||||
async def list_events(
|
||||
self,
|
||||
agent_name: str | None = ...,
|
||||
change_type: str | None = ...,
|
||||
status: str | None = ...,
|
||||
) -> list[dict]: ...
|
||||
async def record_skill_version(
|
||||
self, skill_name: str, version: str, content: str, parent_version: str | None = ...
|
||||
) -> str: ...
|
||||
async def list_skill_versions(self, skill_name: str) -> list[dict]: ...
|
||||
async def record_ab_test_result(
|
||||
self, test_id: str, variant: str, score: float, sample_count: int = ...
|
||||
) -> str: ...
|
||||
async def get_ab_test_results(self, test_id: str) -> list[dict]: ...
|
||||
|
||||
|
||||
class EvolutionStore:
|
||||
"""进化日志存储
|
||||
|
||||
|
|
@ -133,6 +163,40 @@ class EvolutionStore:
|
|||
logger.error(f"Failed to list evolution events: {e}")
|
||||
return []
|
||||
|
||||
# ── Protocol compatibility methods (the legacy EvolutionStore does not support skill_version / ab_test) ──
|
||||
|
||||
async def record_skill_version(
|
||||
self, skill_name: str, version: str, content: str, parent_version: str | None = None
|
||||
) -> str:
|
||||
"""记录技能版本(旧版 SQL 后端不支持,抛出 NotImplementedError)"""
|
||||
raise NotImplementedError(
|
||||
"EvolutionStore (SQL backend) does not support skill_version operations. "
|
||||
"Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead."
|
||||
)
|
||||
|
||||
async def list_skill_versions(self, skill_name: str) -> list[dict]:
    """List the version history of a skill.

    Not supported by the legacy SQL backend; present only so the class
    exposes the full protocol surface.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "EvolutionStore (SQL backend) does not support skill_version operations. "
        "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead."
    )
|
||||
|
||||
async def record_ab_test_result(
    self, test_id: str, variant: str, score: float, sample_count: int = 0
) -> str:
    """Record an A/B test result.

    Not supported by the legacy SQL backend; present only so the class
    exposes the full protocol surface.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "EvolutionStore (SQL backend) does not support A/B test operations. "
        "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead."
    )
|
||||
|
||||
async def get_ab_test_results(self, test_id: str) -> list[dict]:
    """Fetch the results of an A/B test.

    Not supported by the legacy SQL backend; present only so the class
    exposes the full protocol surface.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "EvolutionStore (SQL backend) does not support A/B test operations. "
        "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead."
    )
|
||||
|
||||
|
||||
class PersistentEvolutionStore:
|
||||
"""SQLite 持久化进化存储
|
||||
|
|
@ -464,19 +528,32 @@ def create_evolution_store(
|
|||
db_path: str = "~/.agentkit/evolution.db",
|
||||
session_factory: Any = None,
|
||||
evolution_model: Any = None,
|
||||
database_url: str | None = None,
|
||||
) -> EvolutionStore | PersistentEvolutionStore | InMemoryEvolutionStore:
|
||||
"""工厂函数:创建进化存储实例
|
||||
|
||||
Args:
|
||||
backend: 存储后端类型 - "memory" | "sqlite" | "sql"
|
||||
backend: 存储后端类型 - "memory" | "sqlite" | "sql" | "postgresql"
|
||||
db_path: SQLite 数据库路径(仅 backend="sqlite" 时使用)
|
||||
session_factory: 异步 SQLAlchemy session 工厂(仅 backend="sql" 时使用)
|
||||
evolution_model: SQLAlchemy ORM 模型类(仅 backend="sql" 时使用)
|
||||
database_url: PostgreSQL 连接字符串(仅 backend="postgresql" 时使用)
|
||||
|
||||
Returns:
|
||||
对应后端的进化存储实例
|
||||
"""
|
||||
if backend == "sqlite":
|
||||
if backend == "postgresql":
|
||||
from agentkit.evolution.pg_store import PostgreSQLEvolutionStore
|
||||
|
||||
url = database_url or os.environ.get("AGENTKIT_DATABASE_URL")
|
||||
if url:
|
||||
return PostgreSQLEvolutionStore(database_url=url)
|
||||
logger.warning(
|
||||
"PostgreSQL backend requested but no database_url provided, "
|
||||
"falling back to InMemoryEvolutionStore"
|
||||
)
|
||||
return InMemoryEvolutionStore()
|
||||
elif backend == "sqlite":
|
||||
return PersistentEvolutionStore(db_path=db_path)
|
||||
elif backend == "sql" and session_factory and evolution_model:
|
||||
return EvolutionStore(session_factory=session_factory, evolution_model=evolution_model)
|
||||
|
|
|
|||
|
|
@ -483,7 +483,15 @@ class EvolutionMixin:
|
|||
|
||||
tool = MemoryTool(memory_store)
|
||||
section = category
|
||||
content = "; ".join(reflections[-1]["reflection"].suggestions[:2])
|
||||
# 汇总所有累积反思的建议(去重,最多取 5 条)
|
||||
all_suggestions: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for r in reflections:
|
||||
for suggestion in r["reflection"].suggestions:
|
||||
if suggestion not in seen:
|
||||
seen.add(suggestion)
|
||||
all_suggestions.append(suggestion)
|
||||
content = "; ".join(all_suggestions[:5])
|
||||
reason = f"连续{len(reflections)}次低质量反思 (category: {category})"
|
||||
|
||||
update_result = await tool.execute(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,329 @@
|
|||
"""PostgreSQLEvolutionStore - 基于 PostgreSQL 的异步进化存储
|
||||
|
||||
使用 async SQLAlchemy + asyncpg 实现完整的 EvolutionStoreProtocol,
|
||||
支持进化事件、技能版本、A/B 测试结果的持久化。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PG 专用 Base(与 SQLite models 的 Base 隔离,避免表名冲突)
|
||||
PGBase = declarative_base()
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    current = datetime.now(tz=timezone.utc)
    return current
|
||||
|
||||
|
||||
class PGEvolutionEventModel(PGBase):
    """ORM model for one evolution event row in PostgreSQL."""

    __tablename__ = "evolution_events"

    # Primary key: a random UUID4 rendered as a string.
    id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4()))
    agent_name = Column(String, index=True)
    change_type = Column(String, nullable=True)
    # Snapshots and metrics are stored as JSONB documents.
    before = Column(JSONB, nullable=True)
    after = Column(JSONB, nullable=True)
    metrics = Column(JSONB, nullable=True)
    status = Column(String, default="active")
    # _utcnow() returns a timezone-aware datetime. asyncpg refuses to bind an
    # aware value to TIMESTAMP WITHOUT TIME ZONE, so the column has to be
    # declared as TIMESTAMP WITH TIME ZONE.
    created_at = Column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
|
||||
class PGSkillVersionModel(PGBase):
    """ORM model for one stored version of a skill in PostgreSQL."""

    __tablename__ = "skill_versions"
    # A given version label may appear only once per skill.
    __table_args__ = (UniqueConstraint("skill_name", "version"),)

    # Primary key: a random UUID4 rendered as a string.
    id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4()))
    skill_name = Column(String, index=True)
    version = Column(String)
    content = Column(Text)
    parent_version = Column(String, nullable=True)
    # _utcnow() returns a timezone-aware datetime. asyncpg refuses to bind an
    # aware value to TIMESTAMP WITHOUT TIME ZONE, so the column has to be
    # declared as TIMESTAMP WITH TIME ZONE.
    created_at = Column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
|
||||
class PGABTestResultModel(PGBase):
    """ORM model for one A/B test result row in PostgreSQL."""

    __tablename__ = "ab_test_results"

    # Primary key: a random UUID4 rendered as a string.
    id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4()))
    test_id = Column(String, index=True)
    variant = Column(String)
    score = Column(Float)
    sample_count = Column(Integer, default=0)
    # _utcnow() returns a timezone-aware datetime. asyncpg refuses to bind an
    # aware value to TIMESTAMP WITHOUT TIME ZONE, so the column has to be
    # declared as TIMESTAMP WITH TIME ZONE.
    created_at = Column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
|
||||
class PostgreSQLEvolutionStore:
    """Asynchronous evolution store backed by PostgreSQL.

    Persists evolution events, skill versions and A/B test results through
    async SQLAlchemy sessions. The engine and the session factory are built
    lazily the first time any operation needs them.

    Usage:
        store = PostgreSQLEvolutionStore(
            database_url="postgresql+asyncpg://user:pass@localhost/dbname"
        )
        await store.ensure_tables()
        event_id = await store.record(event)
    """

    def __init__(self, database_url: str) -> None:
        self._database_url = database_url
        self._engine: Any = None
        self._session_factory: Any = None
        self._initialized = False
        # Serialises the lazy engine creation across concurrent tasks.
        self._init_lock = asyncio.Lock()

    async def _ensure_initialized(self) -> None:
        """Build the async engine and session factory on first use."""
        if not self._initialized:
            async with self._init_lock:
                # Re-check under the lock: another task may have completed
                # initialisation while this one was waiting.
                if not self._initialized:
                    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
                    from sqlalchemy.orm import sessionmaker

                    self._engine = create_async_engine(self._database_url, echo=False)
                    self._session_factory = sessionmaker(
                        self._engine, class_=AsyncSession, expire_on_commit=False
                    )
                    self._initialized = True

    async def ensure_tables(self) -> None:
        """Create every table registered on PGBase if it does not exist yet.

        Intended to be called once at startup.
        """
        await self._ensure_initialized()
        async with self._engine.begin() as connection:
            await connection.run_sync(PGBase.metadata.create_all)

    async def close(self) -> None:
        """Dispose of the engine and reset the store to its uninitialised state."""
        if self._engine is None:
            return
        await self._engine.dispose()
        self._engine = None
        self._session_factory = None
        self._initialized = False

    async def __aenter__(self) -> "PostgreSQLEvolutionStore":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.close()

    # ── Evolution events ─────────────────────────────────

    async def record(self, event: EvolutionEvent) -> str:
        """Insert an evolution event and return its generated id.

        The generated id is also written back to ``event.event_id``.
        Any failure is rolled back, logged and re-raised.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                event_id = str(_uuid.uuid4())
                session.add(
                    PGEvolutionEventModel(
                        id=event_id,
                        agent_name=event.agent_name,
                        change_type=event.change_type,
                        before=event.before,
                        after=event.after,
                        metrics=event.metrics,
                        status="active",
                    )
                )
                await session.commit()
                event.event_id = event_id
                logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'")
                return event_id
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to record evolution event: {e}")
                raise

    async def rollback(self, event_id: str) -> bool:
        """Mark an evolution event as rolled back.

        Returns:
            True when the event was found and updated; False when it does
            not exist or the update failed (the failure is logged).
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                query = select(PGEvolutionEventModel).where(
                    PGEvolutionEventModel.id == event_id
                )
                found = (await session.execute(query)).scalar_one_or_none()

                if not found:
                    logger.error(f"Evolution event {event_id} not found")
                    return False

                found.status = "rolled_back"
                await session.commit()
                logger.info(f"Evolution event {event_id} rolled back")
                return True
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to rollback evolution event {event_id}: {e}")
                return False

    async def list_events(
        self,
        agent_name: str | None = None,
        change_type: str | None = None,
        status: str | None = None,
        limit: int = 100,
    ) -> list[dict]:
        """List evolution events, newest first.

        Each truthy filter argument narrows the query by equality. At most
        ``limit`` rows are returned. On failure the error is logged and an
        empty list is returned.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                query = select(PGEvolutionEventModel)
                equality_filters = (
                    (agent_name, PGEvolutionEventModel.agent_name),
                    (change_type, PGEvolutionEventModel.change_type),
                    (status, PGEvolutionEventModel.status),
                )
                for wanted, column in equality_filters:
                    if wanted:
                        query = query.where(column == wanted)
                query = query.order_by(PGEvolutionEventModel.created_at.desc()).limit(limit)
                rows = (await session.execute(query)).scalars().all()

                events = []
                for row in rows:
                    events.append(
                        {
                            "id": row.id,
                            "agent_name": row.agent_name,
                            "change_type": row.change_type,
                            "before": row.before,
                            "after": row.after,
                            "metrics": row.metrics,
                            "status": row.status,
                            "created_at": row.created_at.isoformat() if row.created_at else None,
                        }
                    )
                return events
            except Exception as e:
                logger.error(f"Failed to list evolution events: {e}")
                return []

    # ── Skill versions ───────────────────────────────────

    async def record_skill_version(
        self, skill_name: str, version: str, content: str, parent_version: str | None = None
    ) -> str:
        """Insert a skill version and return its generated id.

        Any failure is rolled back, logged and re-raised.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                version_id = str(_uuid.uuid4())
                session.add(
                    PGSkillVersionModel(
                        id=version_id,
                        skill_name=skill_name,
                        version=version,
                        content=content,
                        parent_version=parent_version,
                    )
                )
                await session.commit()
                return version_id
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to record skill version: {e}")
                raise

    async def list_skill_versions(self, skill_name: str, limit: int = 100) -> list[dict]:
        """List the stored versions of one skill, newest first.

        At most ``limit`` rows are returned. On failure the error is logged
        and an empty list is returned.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                query = (
                    select(PGSkillVersionModel)
                    .where(PGSkillVersionModel.skill_name == skill_name)
                    .order_by(PGSkillVersionModel.created_at.desc())
                    .limit(limit)
                )
                rows = (await session.execute(query)).scalars().all()

                versions = []
                for row in rows:
                    versions.append(
                        {
                            "id": row.id,
                            "skill_name": row.skill_name,
                            "version": row.version,
                            "content": row.content,
                            "parent_version": row.parent_version,
                            "created_at": row.created_at.isoformat() if row.created_at else None,
                        }
                    )
                return versions
            except Exception as e:
                logger.error(f"Failed to list skill versions: {e}")
                return []

    # ── A/B test results ─────────────────────────────────

    async def record_ab_test_result(
        self, test_id: str, variant: str, score: float, sample_count: int = 0
    ) -> str:
        """Insert an A/B test result and return its generated id.

        Any failure is rolled back, logged and re-raised.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                result_id = str(_uuid.uuid4())
                session.add(
                    PGABTestResultModel(
                        id=result_id,
                        test_id=test_id,
                        variant=variant,
                        score=score,
                        sample_count=sample_count,
                    )
                )
                await session.commit()
                return result_id
            except Exception as e:
                await session.rollback()
                logger.error(f"Failed to record A/B test result: {e}")
                raise

    async def get_ab_test_results(self, test_id: str, limit: int = 100) -> list[dict]:
        """Return the results recorded for one A/B test, newest first.

        At most ``limit`` rows are returned. On failure the error is logged
        and an empty list is returned.
        """
        await self._ensure_initialized()
        async with self._session_factory() as session:
            try:
                query = (
                    select(PGABTestResultModel)
                    .where(PGABTestResultModel.test_id == test_id)
                    .order_by(PGABTestResultModel.created_at.desc())
                    .limit(limit)
                )
                rows = (await session.execute(query)).scalars().all()

                results = []
                for row in rows:
                    results.append(
                        {
                            "id": row.id,
                            "test_id": row.test_id,
                            "variant": row.variant,
                            "score": row.score,
                            "sample_count": row.sample_count,
                            "created_at": row.created_at.isoformat() if row.created_at else None,
                        }
                    )
                return results
            except Exception as e:
                logger.error(f"Failed to get A/B test results: {e}")
                return []
|
||||
|
|
@ -5,7 +5,8 @@ from agentkit.llm.gateway import LLMGateway
|
|||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.usage_store import UsageRecord
|
||||
from agentkit.llm.retry import (
|
||||
CircuitBreaker,
|
||||
CircuitBreakerConfig,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,632 @@
|
|||
"""LLM Response Cache — Exact-match + Semantic-match dual cache for LLM responses.
|
||||
|
||||
Architecture:
|
||||
- LLMCache Protocol: async interface for cache backends
|
||||
- InMemoryLLMCache: OrderedDict LRU + embedding index
|
||||
- RedisLLMCache: Redis keys + SET index + lazy init
|
||||
- create_llm_cache(): Factory with auto-detection
|
||||
|
||||
Design doc: docs/plans/2026-06-14-002-u1-llm-cache-architecture.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """A cached LLM response with metadata."""

    # The LLM response being cached.
    response: LLMResponse
    # Embedding of the originating query; an empty list means none was stored.
    query_embedding: list[float] = field(default_factory=list)
    # Timestamp taken when the entry was stored. The clock depends on the
    # backend: the in-memory cache uses time.monotonic(), the Redis cache
    # uses time.time().
    created_at: float = 0.0
    # Number of hits served from this entry.
    hit_count: int = 0
|
||||
|
||||
|
||||
@dataclass
class CacheResult:
    """Result of a cache lookup."""

    # True when a cached response was found.
    hit: bool = False
    # The cached response on a hit; None on a miss.
    response: LLMResponse | None = None
    # How the hit was obtained; empty on a miss.
    match_type: str = ""  # "exact" | "semantic" | "" (miss)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization helpers (for Redis backend)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _serialize_response(response: LLMResponse) -> dict:
    """Convert an LLMResponse into a JSON-compatible dict."""
    payload: dict = {}
    payload["content"] = response.content
    payload["model"] = response.model
    payload["usage"] = {
        "prompt_tokens": response.usage.prompt_tokens,
        "completion_tokens": response.usage.completion_tokens,
    }
    # Each tool call is flattened to its id, name and arguments.
    flattened_calls = []
    for call in response.tool_calls:
        flattened_calls.append(
            {"id": call.id, "name": call.name, "arguments": call.arguments}
        )
    payload["tool_calls"] = flattened_calls
    payload["latency_ms"] = response.latency_ms
    return payload
|
||||
|
||||
|
||||
def _deserialize_response(data: dict) -> LLMResponse:
    """Rebuild an LLMResponse from its serialized dict form.

    ``content`` and ``model`` are required keys; usage counters, tool calls
    and latency fall back to zero / empty when absent.
    """
    usage_fields = data.get("usage", {})
    content = data["content"]
    model = data["model"]
    usage = TokenUsage(
        prompt_tokens=usage_fields.get("prompt_tokens", 0),
        completion_tokens=usage_fields.get("completion_tokens", 0),
    )
    calls = []
    for raw_call in data.get("tool_calls", []):
        calls.append(
            ToolCall(id=raw_call["id"], name=raw_call["name"], arguments=raw_call["arguments"])
        )
    return LLMResponse(
        content=content,
        model=model,
        usage=usage,
        tool_calls=calls,
        latency_ms=data.get("latency_ms", 0.0),
    )
|
||||
|
||||
|
||||
def _serialize_entry(entry: CacheEntry) -> dict:
    """Convert a CacheEntry into a JSON-compatible dict."""
    serialized: dict = {"response": _serialize_response(entry.response)}
    serialized["query_embedding"] = entry.query_embedding
    serialized["created_at"] = entry.created_at
    serialized["hit_count"] = entry.hit_count
    return serialized
|
||||
|
||||
|
||||
def _deserialize_entry(data: dict) -> CacheEntry:
    """Rebuild a CacheEntry from its serialized dict form.

    ``response`` is required; the remaining fields fall back to their
    empty / zero defaults when absent.
    """
    response = _deserialize_response(data["response"])
    embedding = data.get("query_embedding", [])
    stored_at = data.get("created_at", 0.0)
    hits = data.get("hit_count", 0)
    return CacheEntry(
        response=response,
        query_embedding=embedding,
        created_at=stored_at,
        hit_count=hits,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMCache Protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
class LLMCache(Protocol):
    """Async interface every LLM response cache backend implements."""

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup by cache key."""
        ...

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Semantic similarity search across cached entries.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum similarity for a hit. ``None`` (the default,
                matching the concrete backends) means "use the threshold the
                cache was configured with".
        """
        ...

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store a response in the cache with an optional embedding."""
        ...

    async def invalidate(self, pattern: str | None = None) -> int:
        """Invalidate cache entries and return how many were invalidated.

        ``None`` invalidates everything.
        """
        ...

    async def stats(self) -> dict[str, int]:
        """Return cache statistics."""
        ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InMemoryLLMCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InMemoryLLMCache:
    """In-memory LLM cache with LRU eviction and semantic search.

    Entries live in an OrderedDict ordered from least to most recently used.
    A parallel dict holds the embedding of every entry that has one and is
    scanned linearly by semantic_search().

    An entry stays stored for max(exact_ttl, semantic_ttl) seconds: get()
    only serves it during the exact window, semantic_search() only during
    the semantic window.
    """

    # Upper bound on how many least-recently-used entries put() inspects for
    # expiry on each call.
    _CLEANUP_SCAN_LIMIT = 20

    def __init__(
        self,
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
    ):
        """Create an empty cache.

        Args:
            max_entries: Capacity; least recently used entries are evicted beyond it.
            exact_ttl: Seconds an entry may be served by get().
            semantic_ttl: Seconds an entry may be served by semantic_search().
            similarity_threshold: Default minimum cosine similarity for a semantic hit.
        """
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold
        # How long an entry is kept at all; it must cover both windows.
        self._retention_ttl = max(exact_ttl, semantic_ttl)

        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._embeddings: dict[str, list[float]] = {}

        self._hits = 0
        self._misses = 0

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup; counts a hit or a miss."""
        entry = self._cache.get(key)

        if entry is not None:
            age = time.monotonic() - entry.created_at
            if age <= self._exact_ttl:
                # Hit: refresh the LRU position and the counters.
                self._cache.move_to_end(key)
                entry.hit_count += 1
                self._hits += 1
                return CacheResult(hit=True, response=entry.response, match_type="exact")
            # Past the exact window. Drop the entry only once it is useless
            # for semantic lookups as well; otherwise it must stay available
            # to semantic_search().
            if age > self._retention_ttl:
                del self._cache[key]
                self._embeddings.pop(key, None)

        self._misses += 1
        return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar unexpired entry if it reaches the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum cosine similarity; None uses the configured default.
        """
        if not self._embeddings:
            return CacheResult(hit=False)

        # `is None` rather than `or`, so an explicit 0.0 threshold is honoured.
        effective_threshold = (
            self._similarity_threshold if threshold is None else threshold
        )
        now = time.monotonic()
        best_key: str | None = None
        best_sim: float = 0.0

        for key, emb in self._embeddings.items():
            entry = self._cache.get(key)
            if entry is None:
                continue
            # Skip entries outside the semantic window.
            if now - entry.created_at > self._semantic_ttl:
                continue
            sim = compute_cosine_similarity(query_embedding, emb)
            if sim > best_sim:
                best_sim = sim
                best_key = key

        if best_key is not None and best_sim >= effective_threshold:
            entry = self._cache[best_key]
            entry.hit_count += 1
            self._cache.move_to_end(best_key)
            self._hits += 1
            return CacheResult(hit=True, response=entry.response, match_type="semantic")

        self._misses += 1
        return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store (or overwrite) a response.

        When the key already exists and no new embedding is supplied, the
        previously stored embedding is kept.
        """
        now = time.monotonic()

        existing = self._cache.get(key)
        if query_embedding is not None:
            effective_embedding = query_embedding
        elif existing is not None:
            effective_embedding = existing.query_embedding
        else:
            effective_embedding = []

        self._cache[key] = CacheEntry(
            response=response,
            query_embedding=effective_embedding,
            created_at=now,
            hit_count=0,
        )
        # Overwriting a key keeps its old position; make it most recent.
        self._cache.move_to_end(key)

        if effective_embedding:
            self._embeddings[key] = effective_embedding
        else:
            # Keep the index in step with the entry: no embedding, no index row.
            self._embeddings.pop(key, None)

        # Evict least recently used entries while over capacity.
        while len(self._cache) > self._max_entries:
            evicted_key, _ = self._cache.popitem(last=False)
            self._embeddings.pop(evicted_key, None)

        # Lazy cleanup: inspect a bounded number of entries from the front of
        # the OrderedDict and drop the expired ones. Iterating the dict view
        # avoids copying every key on each put().
        expired_keys = []
        for scanned, (candidate, entry) in enumerate(self._cache.items()):
            if scanned >= self._CLEANUP_SCAN_LIMIT:
                break
            if now - entry.created_at > self._retention_ttl:
                expired_keys.append(candidate)
        for candidate in expired_keys:
            self._cache.pop(candidate, None)
            self._embeddings.pop(candidate, None)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries and return how many were removed.

        Args:
            pattern: None clears everything. Otherwise every ``*`` is stripped
                and the remainder is used as a key prefix.
        """
        if pattern is None:
            count = len(self._cache)
            self._cache.clear()
            self._embeddings.clear()
            return count

        # Simple prefix matching; computed once rather than per key.
        prefix = pattern.replace("*", "")
        keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
        for key in keys_to_remove:
            del self._cache[key]
            self._embeddings.pop(key, None)
        return len(keys_to_remove)

    async def stats(self) -> dict[str, int]:
        """Return the entry count and the hit/miss counters."""
        return {
            "total_entries": len(self._cache),
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RedisLLMCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RedisLLMCache:
    """Redis-backed LLM cache with a SET index for semantic search.

    Key schema:
        agentkit:llm_cache:{sha256_hex}     → JSON(CacheEntry), TTL = max(exact_ttl, semantic_ttl)
        agentkit:llm_cache_emb:{sha256_hex} → JSON(list[float]), TTL = semantic_ttl
        agentkit:llm_cache_index            → SET of cache keys that have an embedding

    When a Redis operation fails, the cache degrades to an in-memory fallback
    and retries Redis after every 100 degraded operations.

    NOTE(review): max_entries is only forwarded to the in-memory fallback;
    nothing in this class bounds the number of keys stored in Redis.
    """

    KEY_PREFIX = "agentkit:llm_cache:"
    EMB_PREFIX = "agentkit:llm_cache_emb:"
    INDEX_KEY = "agentkit:llm_cache_index"

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
        max_entries_to_scan: int = 500,
        fallback: InMemoryLLMCache | None = None,
    ):
        """Configure the cache; no connection is opened until first use.

        Args:
            redis_url: Redis connection URL.
            max_entries: Capacity handed to the in-memory fallback.
            exact_ttl: Seconds an entry may be served by get().
            semantic_ttl: Seconds an embedding stays searchable.
            similarity_threshold: Default minimum cosine similarity for a semantic hit.
            max_entries_to_scan: Maximum number of indexed keys sampled per semantic search.
            fallback: Optional in-memory cache used while Redis is unreachable;
                one is created on first degradation when omitted.
        """
        self._redis_url = redis_url
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold
        self._max_entries_to_scan = max_entries_to_scan
        self._redis: Any = None
        self._fallback: InMemoryLLMCache | None = fallback  # For auto-degradation
        self._degraded = False  # True if Redis is unreachable
        self._degrade_count = 0  # Operations served while degraded since the last retry

        self._hits = 0
        self._misses = 0

    async def _get_redis(self):
        """Return the shared Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._redis

    async def aclose(self) -> None:
        """Close the Redis connection pool."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    def _degrade_to_fallback(self) -> None:
        """Mark Redis as unreachable and switch to the in-memory fallback."""
        if not self._degraded:
            self._degraded = True
            self._degrade_count = 0
            if self._fallback is None:
                self._fallback = InMemoryLLMCache(
                    max_entries=self._max_entries,
                    exact_ttl=self._exact_ttl,
                    semantic_ttl=self._semantic_ttl,
                    similarity_threshold=self._similarity_threshold,
                )
            logger.warning("Redis cache unreachable, degraded to in-memory fallback")

    def _try_recover(self) -> None:
        """Count a degraded operation and periodically clear the degraded flag.

        The flag is reset optimistically; the next real Redis operation
        verifies connectivity and re-triggers degradation if it fails.
        """
        if not self._degraded:
            return
        self._degrade_count += 1
        # Try recovery every 100 operations
        if self._degrade_count >= 100:
            self._degrade_count = 0
            self._degraded = False
            logger.info("Redis cache: attempting recovery from degraded state")

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup; counts a hit or a miss."""
        # While degraded, serve from the fallback unless a retry is due.
        if self._degraded and self._fallback is not None:
            self._try_recover()
            if self._degraded:
                return await self._fallback.get(key)
            # Recovery attempted — fall through to try Redis

        try:
            redis = await self._get_redis()
            data = await redis.get(f"{self.KEY_PREFIX}{key}")
            if data is not None:
                entry = _deserialize_entry(json.loads(data))
                # The data key lives for max(exact_ttl, semantic_ttl) so that
                # semantic hits can still load it; the exact window therefore
                # has to be enforced here from the stored wall-clock time.
                if time.time() - entry.created_at <= self._exact_ttl:
                    self._hits += 1
                    return CacheResult(
                        hit=True, response=entry.response, match_type="exact"
                    )
            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning(f"Redis cache get failed, returning miss: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.get(key)
            return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar sampled entry if it reaches the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum cosine similarity; None uses the configured default.
        """
        # Same degraded handling as get()/put(): do not hit a dead Redis on
        # every call.
        if self._degraded and self._fallback is not None:
            self._try_recover()
            if self._degraded:
                return await self._fallback.semantic_search(query_embedding, threshold)
            # Recovery attempted — fall through to try Redis

        # `is None` rather than `or`, so an explicit 0.0 threshold is honoured.
        effective_threshold = (
            self._similarity_threshold if threshold is None else threshold
        )

        try:
            if self._max_entries_to_scan <= 0:
                return CacheResult(hit=False)

            redis = await self._get_redis()

            # Let Redis pick the random sample so that only the sampled keys
            # cross the network, instead of fetching the whole index.
            cache_keys_list = await redis.srandmember(
                self.INDEX_KEY, self._max_entries_to_scan
            )
            if not cache_keys_list:
                return CacheResult(hit=False)

            # Batch fetch embeddings
            emb_keys = [f"{self.EMB_PREFIX}{k}" for k in cache_keys_list]
            emb_values = await redis.mget(emb_keys)

            best_key: str | None = None
            best_sim: float = 0.0
            stale_keys: list[str] = []  # Indexed keys whose embedding has expired

            for cache_key, emb_json in zip(cache_keys_list, emb_values):
                if emb_json is None:
                    stale_keys.append(cache_key)
                    continue
                emb = json.loads(emb_json)
                sim = compute_cosine_similarity(query_embedding, emb)
                if sim > best_sim:
                    best_sim = sim
                    best_key = cache_key

            # Lazy cleanup: remove stale index entries (best effort).
            if stale_keys:
                try:
                    pipe = redis.pipeline()
                    for k in stale_keys:
                        pipe.srem(self.INDEX_KEY, k)
                    await pipe.execute()
                except Exception as cleanup_err:
                    logger.debug(f"Redis cache index cleanup failed: {cleanup_err}")

            if best_key is not None and best_sim >= effective_threshold:
                data = await redis.get(f"{self.KEY_PREFIX}{best_key}")
                if data is not None:
                    entry = _deserialize_entry(json.loads(data))
                    self._hits += 1
                    return CacheResult(
                        hit=True, response=entry.response, match_type="semantic"
                    )
                # Data key is gone although the embedding remains: drop the
                # index entry (best effort).
                try:
                    await redis.srem(self.INDEX_KEY, best_key)
                except Exception as cleanup_err:
                    logger.debug(f"Redis cache index cleanup failed: {cleanup_err}")

            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning(f"Redis semantic search failed, returning miss: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.semantic_search(query_embedding, threshold)
            self._misses += 1
            return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store a response, plus its embedding and index entry when given."""
        # While degraded, write to the fallback unless a retry is due.
        if self._degraded and self._fallback is not None:
            self._try_recover()
            if self._degraded:
                await self._fallback.put(key, response, query_embedding)
                return
            # Recovery attempted — fall through to try Redis

        try:
            redis = await self._get_redis()
            entry = CacheEntry(
                response=response,
                query_embedding=query_embedding or [],
                created_at=time.time(),  # Wall-clock for cross-process comparability in Redis
                hit_count=0,
            )

            pipe = redis.pipeline()
            # Data key TTL must cover both exact and semantic windows
            # so semantic hits don't return None data
            data_ttl = max(self._exact_ttl, self._semantic_ttl)
            pipe.set(
                f"{self.KEY_PREFIX}{key}",
                json.dumps(_serialize_entry(entry)),
                ex=data_ttl,
            )
            if query_embedding is not None:
                pipe.set(
                    f"{self.EMB_PREFIX}{key}",
                    json.dumps(query_embedding),
                    ex=self._semantic_ttl,
                )
                pipe.sadd(self.INDEX_KEY, key)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Redis cache put failed: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                await self._fallback.put(key, response, query_embedding)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries from Redis and from the fallback.

        Args:
            pattern: None removes everything. Otherwise every ``*`` is stripped
                and the remainder is used as a cache-key prefix.

        Returns:
            Number of entries removed (fallback entries plus Redis entries).
            If Redis fails, only the fallback count is returned.
        """
        removed = 0
        # Entries written while Redis was unreachable live in the fallback.
        if self._fallback is not None:
            removed += await self._fallback.invalidate(pattern)

        try:
            redis = await self._get_redis()
            prefix = "" if pattern is None else pattern.replace("*", "")

            # Entries stored without an embedding are never indexed, so the
            # data keys are discovered with SCAN; the index is merged in to
            # also clear index rows whose data key already expired.
            cache_keys: set[str] = set()
            async for data_key in redis.scan_iter(match=f"{self.KEY_PREFIX}*"):
                cache_key = data_key[len(self.KEY_PREFIX):]
                if cache_key.startswith(prefix):
                    cache_keys.add(cache_key)
            for cache_key in await redis.smembers(self.INDEX_KEY):
                if cache_key.startswith(prefix):
                    cache_keys.add(cache_key)

            if not cache_keys:
                return removed

            pipe = redis.pipeline()
            for key in cache_keys:
                pipe.delete(f"{self.KEY_PREFIX}{key}")
                pipe.delete(f"{self.EMB_PREFIX}{key}")
                pipe.srem(self.INDEX_KEY, key)
            await pipe.execute()
            return removed + len(cache_keys)
        except Exception as e:
            logger.warning(f"Redis cache invalidate failed: {e}")
            return removed

    async def stats(self) -> dict[str, int]:
        """Return the indexed entry count and the hit/miss counters.

        ``total_entries`` is the size of the index set, or 0 when Redis
        cannot be reached.
        """
        try:
            redis = await self._get_redis()
            total_entries = await redis.scard(self.INDEX_KEY)
        except Exception:
            total_entries = 0
        return {
            "total_entries": total_entries,
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_llm_cache(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    max_entries: int = 10000,
    exact_ttl: int = 3600,
    semantic_ttl: int = 86400,
    similarity_threshold: float = 0.92,
) -> LLMCache:
    """Build the LLM cache backend selected by ``backend``.

    Args:
        backend: ``"auto"`` or ``"redis"`` selects the Redis backend when the
            ``redis`` package can be imported; any other value (including
            ``"memory"``) selects the in-memory backend.
        redis_url: Redis connection URL, only passed to the Redis backend.
        max_entries: Maximum number of cache entries.
        exact_ttl: TTL in seconds for exact-match cache entries.
        semantic_ttl: TTL in seconds for semantic-match embeddings.
        similarity_threshold: Cosine similarity threshold for semantic match.

    Returns:
        A ``RedisLLMCache`` or, when Redis was not requested or the import
        fails, an ``InMemoryLLMCache``.
    """
    tuning = {
        "max_entries": max_entries,
        "exact_ttl": exact_ttl,
        "semantic_ttl": semantic_ttl,
        "similarity_threshold": similarity_threshold,
    }

    if backend in ("auto", "redis"):
        try:
            # Import only to probe availability; the backend imports it again lazily.
            import redis.asyncio as aioredis  # noqa: F401

            return RedisLLMCache(redis_url=redis_url, **tuning)
        except ImportError:
            logger.warning(
                "redis package not available, falling back to in-memory cache"
            )

    return InMemoryLLMCache(**tuning)
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""LLM Cache Key Generation — Deterministic SHA-256 cache key from LLM request parameters."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_cache_key(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    tools: list[dict[str, Any]] | None = None,
    tool_choice: str = "auto",
    max_tokens: int = 2000,
) -> str:
    """Derive a deterministic SHA-256 cache key from LLM request parameters.

    Each input is hashed on its own and the hex digests are concatenated in
    a fixed order (model, system prompt, messages, temperature, tools,
    tool_choice, max_tokens) before a final SHA-256 pass.

    Args:
        model: Model identifier (e.g. "openai/gpt-4o").
        messages: Chat messages list (may include a system message).
        temperature: Sampling temperature; formatted with two decimals, so
            values that only differ beyond that share a key.
        tools: Optional list of tool definitions.
        tool_choice: Tool selection mode ("auto", "none", etc.).
        max_tokens: Maximum response tokens.

    Returns:
        64-character hex SHA-256 hash string.
    """
    system_text = _extract_system_prompt(messages)
    digests = (
        _hash_str(model),
        _hash_str(system_text),
        _hash_json(messages),
        _hash_str(f"{temperature:.2f}"),
        _hash_json(tools),
        _hash_str(tool_choice),
        _hash_str(str(max_tokens)),
    )
    # Hashing the joined digests is the same SHA-256-of-UTF-8 step _hash_str does.
    return _hash_str("".join(digests))
|
||||
|
||||
|
||||
def _extract_system_prompt(messages: list[dict[str, str]]) -> str:
    """Return the content of the first system message as a string.

    Args:
        messages: Chat messages; each is a dict with at least a "role" key.

    Returns:
        The first system message's content. A missing or ``None`` content
        yields ``""``; non-string content (e.g. a list of content parts) is
        serialized with ``json.dumps`` using sorted keys so the result is
        always a hashable-by-``encode`` string. ``""`` when there is no
        system message.
    """
    for msg in messages:
        if msg.get("role") != "system":
            continue
        content = msg.get("content", "")
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        # Structured content: serialize deterministically instead of handing
        # a non-str back to callers that call .encode() on the result.
        return json.dumps(content, sort_keys=True, ensure_ascii=False)
    return ""
|
||||
|
||||
|
||||
def _hash_str(s: str) -> str:
    """Return the hex SHA-256 digest of ``s`` encoded with the default codec (UTF-8)."""
    hasher = hashlib.sha256()
    hasher.update(s.encode())
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def _hash_json(obj: Any) -> str:
    """Return the hex SHA-256 digest of ``obj`` rendered as canonical JSON.

    ``None`` hashes the literal bytes ``b"null"``. Anything else is dumped
    with sorted keys and ``ensure_ascii=False`` so dict key order does not
    influence the digest.
    """
    if obj is None:
        payload = b"null"
    else:
        payload = json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()
    return hashlib.sha256(payload).hexdigest()
|
||||
|
|
@ -8,6 +8,43 @@ import yaml
|
|||
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
"""LLM Cache 配置"""
|
||||
|
||||
enabled: bool = False
|
||||
backend: str = "auto" # "auto" | "redis" | "memory"
|
||||
redis_url: str = "redis://localhost:6379"
|
||||
exact_ttl: int = 3600
|
||||
semantic_ttl: int = 86400
|
||||
similarity_threshold: float = 0.92
|
||||
max_entries: int = 10000
|
||||
# Embedding config for semantic cache (Chinese-first: bge-m3 via Xinference)
|
||||
embedding_provider: str = "openai" # "openai" | "xinference" | "local"
|
||||
embedding_model: str = "bge-m3"
|
||||
embedding_base_url: str | None = None
|
||||
embedding_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "CacheConfig":
|
||||
if not data:
|
||||
return cls()
|
||||
emb = data.get("embedding", {})
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
backend=data.get("backend", "auto"),
|
||||
redis_url=data.get("redis_url", "redis://localhost:6379"),
|
||||
exact_ttl=data.get("exact_ttl", 3600),
|
||||
semantic_ttl=data.get("semantic_ttl", 86400),
|
||||
similarity_threshold=data.get("similarity_threshold", 0.92),
|
||||
max_entries=data.get("max_entries", 10000),
|
||||
embedding_provider=emb.get("provider", "openai"),
|
||||
embedding_model=emb.get("model", "bge-m3"),
|
||||
embedding_base_url=emb.get("base_url"),
|
||||
embedding_api_key=emb.get("api_key"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Provider 配置"""
|
||||
|
|
@ -32,6 +69,7 @@ class LLMConfig:
|
|||
providers: dict[str, ProviderConfig] = field(default_factory=dict)
|
||||
model_aliases: dict[str, str] = field(default_factory=dict)
|
||||
fallbacks: dict[str, list[str]] = field(default_factory=dict)
|
||||
cache: CacheConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str) -> "LLMConfig":
|
||||
|
|
@ -77,8 +115,14 @@ class LLMConfig:
|
|||
retry=retry,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
cache = None
|
||||
cache_data = data.get("cache")
|
||||
if cache_data:
|
||||
cache = CacheConfig.from_dict(cache_data)
|
||||
|
||||
return cls(
|
||||
providers=providers,
|
||||
model_aliases=data.get("model_aliases", {}),
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||
from agentkit.llm.config import LLMConfig
|
||||
|
|
@ -14,13 +15,53 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class LLMGateway:
|
||||
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪"""
|
||||
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
||||
|
||||
def __init__(self, config: LLMConfig | None = None):
|
||||
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
|
||||
self._providers: dict[str, LLMProvider] = {}
|
||||
self._usage_tracker = UsageTracker()
|
||||
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
||||
self._config = config or LLMConfig()
|
||||
|
||||
# Cache (opt-in, disabled by default)
|
||||
self._cache: Any = None # LLMCache | None
|
||||
self._embedder: Any = None # Embedder | None
|
||||
if self._config.cache and self._config.cache.enabled:
|
||||
from agentkit.llm.cache import create_llm_cache
|
||||
self._cache = create_llm_cache(
|
||||
backend=self._config.cache.backend,
|
||||
redis_url=self._config.cache.redis_url,
|
||||
max_entries=self._config.cache.max_entries,
|
||||
exact_ttl=self._config.cache.exact_ttl,
|
||||
semantic_ttl=self._config.cache.semantic_ttl,
|
||||
similarity_threshold=self._config.cache.similarity_threshold,
|
||||
)
|
||||
self._embedder = self._create_embedder(self._config.cache)
|
||||
logger.info(
|
||||
f"LLM cache enabled (backend={self._config.cache.backend}, "
|
||||
f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})"
|
||||
)
|
||||
|
||||
def _create_embedder(self, cache_config) -> Any:
    """Instantiate the embedder used for semantic cache lookups.

    The "xinference" and "local" providers get an ``OpenAIEmbedder``
    with placeholder credentials and a localhost base URL whenever the
    config leaves those unset; every other provider value passes the
    configured API key and base URL through unchanged.

    Returns:
        The embedder, or ``None`` when the import or construction fails
        (the failure is logged as a warning).
    """
    try:
        from agentkit.memory.embedder import OpenAIEmbedder

        self_hosted = cache_config.embedding_provider in ("xinference", "local")
        if self_hosted:
            key = cache_config.embedding_api_key or "not-needed"
            model_name = cache_config.embedding_model
            endpoint = cache_config.embedding_base_url or "http://localhost:9997/v1"
        else:
            # Default: OpenAI
            key = cache_config.embedding_api_key
            model_name = cache_config.embedding_model
            endpoint = cache_config.embedding_base_url
        return OpenAIEmbedder(api_key=key, model=model_name, base_url=endpoint)
    except Exception as e:
        logger.warning(f"Failed to create embedder for semantic cache: {e}")
        return None
|
||||
|
||||
def register_provider(self, name: str, provider: LLMProvider) -> None:
|
||||
"""注册 Provider"""
|
||||
self._providers[name] = provider
|
||||
|
|
@ -66,6 +107,66 @@ class LLMGateway:
|
|||
_span = _span_cm.__enter__()
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
# ── Cache check ──
|
||||
cache_key = None
|
||||
query_embedding = None
|
||||
if self._cache is not None:
|
||||
from agentkit.llm.cache_key import generate_cache_key
|
||||
|
||||
cache_key = generate_cache_key(
|
||||
model=resolved_model,
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 0.7),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=kwargs.get("max_tokens", 2000),
|
||||
)
|
||||
result = await self._cache.get(cache_key)
|
||||
if result.hit:
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
self._usage_tracker.record(
|
||||
agent_name=agent_name,
|
||||
model=result.response.model,
|
||||
usage=result.response.usage,
|
||||
cost=0.0,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
if _span is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", True)
|
||||
_span.set_attribute("gen_ai.cache.match_type", result.match_type)
|
||||
return result.response
|
||||
|
||||
# Semantic match (only for temperature == 0)
|
||||
temperature = kwargs.get("temperature", 0.7)
|
||||
if temperature == 0 and self._embedder is not None:
|
||||
try:
|
||||
# Embed last N messages for context-aware semantic matching
|
||||
# (not just last user message — avoids cross-context false hits)
|
||||
recent_messages = messages[-3:] if len(messages) > 3 else messages
|
||||
embed_text = " | ".join(
|
||||
m.get("content", "") for m in recent_messages if m.get("content")
|
||||
)
|
||||
if embed_text:
|
||||
query_embedding = await self._embedder.embed(embed_text)
|
||||
result = await self._cache.semantic_search(query_embedding)
|
||||
if result.hit:
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
self._usage_tracker.record(
|
||||
agent_name=agent_name,
|
||||
model=result.response.model,
|
||||
usage=result.response.usage,
|
||||
cost=0.0,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
if _span is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", True)
|
||||
_span.set_attribute("gen_ai.cache.match_type", "semantic")
|
||||
return result.response
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic cache search failed: {e}")
|
||||
|
||||
# ── Normal provider call ──
|
||||
models_to_try = self._get_models_to_try(resolved_model)
|
||||
last_error: LLMProviderError | None = None
|
||||
|
||||
|
|
@ -95,6 +196,13 @@ class LLMGateway:
|
|||
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
# ── Cache write ──
|
||||
if self._cache is not None and cache_key is not None:
|
||||
try:
|
||||
await self._cache.put(cache_key, response, query_embedding)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed: {e}")
|
||||
|
||||
# 计算成本
|
||||
cost = self._calculate_cost(response.model, response.usage)
|
||||
|
||||
|
|
@ -112,7 +220,9 @@ class LLMGateway:
|
|||
_span.set_attribute("gen_ai.usage.input_tokens", response.usage.prompt_tokens)
|
||||
_span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
|
||||
_span.set_attribute("gen_ai.response.model", response.model)
|
||||
_span.set_attribute("gen_ai.duration_ms", int(latency_ms))
|
||||
_span.set_attribute("gen_ai.duration.ms", int(latency_ms))
|
||||
if self._cache is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", False)
|
||||
llm_token_histogram().record(
|
||||
response.usage.total_tokens,
|
||||
{"gen_ai.request.model": resolved_model},
|
||||
|
|
@ -138,6 +248,8 @@ class LLMGateway:
|
|||
If the primary model fails before any chunk is yielded, tries fallback
|
||||
models. If it fails after chunks have been sent, yields an error chunk
|
||||
and terminates (cannot switch mid-stream).
|
||||
|
||||
Note: Streaming responses are NOT cached in this iteration.
|
||||
"""
|
||||
resolved_model = self._resolve_model_alias(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from agentkit.llm.providers.anthropic import AnthropicProvider
|
|||
from agentkit.llm.providers.doubao import DoubaoProvider
|
||||
from agentkit.llm.providers.gemini import GeminiProvider
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.usage_store import UsageRecord
|
||||
from agentkit.llm.providers.wenxin import WenxinProvider
|
||||
from agentkit.llm.providers.yuanbao import YuanbaoProvider
|
||||
|
||||
|
|
|
|||
|
|
@ -1,42 +1,20 @@
|
|||
"""Usage Tracker - 使用量追踪"""
|
||||
"""Usage Tracker - 使用量追踪(委托给 UsageStore)"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
from agentkit.llm.protocol import TokenUsage
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageRecord:
|
||||
"""使用量记录"""
|
||||
|
||||
agent_name: str
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
cost: float
|
||||
latency_ms: float
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageSummary:
|
||||
"""使用量汇总"""
|
||||
|
||||
total_tokens: int = 0
|
||||
total_cost: float = 0.0
|
||||
by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
|
||||
records: list[UsageRecord] = field(default_factory=list)
|
||||
from agentkit.llm.providers.usage_store import (
|
||||
InMemoryUsageStore,
|
||||
UsageStore,
|
||||
UsageSummary,
|
||||
)
|
||||
|
||||
|
||||
class UsageTracker:
|
||||
"""使用量追踪器"""
|
||||
"""使用量追踪器 — 委托给可插拔的 UsageStore"""
|
||||
|
||||
MAX_RECORDS = 10000 # 最大记录数,防止内存无限增长
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: list[UsageRecord] = []
|
||||
def __init__(self, store: UsageStore | None = None) -> None:
|
||||
self._store: UsageStore = store or InMemoryUsageStore()
|
||||
|
||||
def record(
|
||||
self,
|
||||
|
|
@ -47,19 +25,7 @@ class UsageTracker:
|
|||
latency_ms: float,
|
||||
) -> None:
|
||||
"""记录一次使用"""
|
||||
rec = UsageRecord(
|
||||
agent_name=agent_name,
|
||||
model=model,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
self._records.append(rec)
|
||||
# 超过上限时删除最早的记录
|
||||
if len(self._records) > self.MAX_RECORDS:
|
||||
self._records = self._records[-self.MAX_RECORDS:]
|
||||
self._store.record(agent_name, model, usage, cost, latency_ms)
|
||||
|
||||
def get_usage(
|
||||
self,
|
||||
|
|
@ -68,32 +34,4 @@ class UsageTracker:
|
|||
end_time: datetime | None = None,
|
||||
) -> UsageSummary:
|
||||
"""查询使用量汇总"""
|
||||
filtered = self._records
|
||||
|
||||
if agent_name is not None:
|
||||
filtered = [r for r in filtered if r.agent_name == agent_name]
|
||||
if start_time is not None:
|
||||
filtered = [r for r in filtered if r.timestamp >= start_time]
|
||||
if end_time is not None:
|
||||
filtered = [r for r in filtered if r.timestamp <= end_time]
|
||||
|
||||
if not filtered:
|
||||
return UsageSummary()
|
||||
|
||||
total_tokens = sum(r.total_tokens for r in filtered)
|
||||
total_cost = sum(r.cost for r in filtered)
|
||||
|
||||
by_model: dict[str, dict[str, int | float]] = {}
|
||||
for r in filtered:
|
||||
if r.model not in by_model:
|
||||
by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0}
|
||||
by_model[r.model]["total_tokens"] += r.total_tokens
|
||||
by_model[r.model]["total_cost"] += r.cost
|
||||
by_model[r.model]["count"] += 1
|
||||
|
||||
return UsageSummary(
|
||||
total_tokens=total_tokens,
|
||||
total_cost=total_cost,
|
||||
by_model=by_model,
|
||||
records=filtered,
|
||||
)
|
||||
return self._store.get_usage(agent_name, start_time, end_time)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,373 @@
|
|||
"""Usage Store — Persistent usage tracking with Redis Hash backend.
|
||||
|
||||
Provides UsageStore Protocol with InMemoryUsageStore and RedisUsageStore
|
||||
backends. Replaces the in-memory list in UsageTracker with a pluggable
|
||||
store that survives restarts and supports multi-instance deployment.
|
||||
|
||||
Key schema (Redis):
|
||||
agentkit:usage:{date}           → Hash: {agent_name:model:metric → counter}  (metric = cost | prompt_tokens | completion_tokens | total_tokens | count)
|
||||
agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from agentkit.llm.protocol import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class UsageRecord:
    """A single usage event for one LLM call."""

    agent_name: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    latency_ms: float
    timestamp: str = ""  # ISO 8601 string for JSON serialization

    def __post_init__(self):
        # An empty timestamp means "now": stamp the record with the current UTC time.
        self.timestamp = self.timestamp or datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
@dataclass
class UsageBucket:
    """Running totals for one agent+model pair on a single date."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    cost: float = 0.0
    count: int = 0  # number of recorded calls
|
||||
|
||||
|
||||
@dataclass
class UsageSummary:
    """Aggregated usage returned by a usage query."""

    total_tokens: int = 0
    total_cost: float = 0.0
    # model name -> {"total_tokens", "total_cost", "count"}
    by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
    # the individual records that were aggregated
    records: list[UsageRecord] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UsageStore Protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
class UsageStore(Protocol):
    """Structural interface every usage-store backend implements."""

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Persist one usage event for ``agent_name`` calling ``model``."""
        ...

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Return aggregated usage, optionally filtered by agent and time window."""
        ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InMemoryUsageStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InMemoryUsageStore:
    """In-memory usage store (drop-in replacement for old UsageTracker).

    Keeps at most ``MAX_RECORDS`` of the most recent records in a list.
    """

    MAX_RECORDS = 10000

    def __init__(self):
        self._records: list[UsageRecord] = []

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Append one usage record, discarding the oldest beyond ``MAX_RECORDS``."""
        rec = UsageRecord(
            agent_name=agent_name,
            model=model,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
            cost=cost,
            latency_ms=latency_ms,
        )
        self._records.append(rec)
        if len(self._records) > self.MAX_RECORDS:
            self._records = self._records[-self.MAX_RECORDS:]

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Aggregate the stored records that match the given filters.

        Args:
            agent_name: Keep only records of this agent when given.
            start_time: Inclusive lower bound on the record timestamp.
            end_time: Inclusive upper bound on the record timestamp.

        Returns:
            A ``UsageSummary``; empty when nothing matches. ``records`` is
            always a fresh list, so mutating it never touches the store.
        """
        time_filtered = start_time is not None or end_time is not None

        selected: list[UsageRecord] = []
        for rec in self._records:
            if agent_name is not None and rec.agent_name != agent_name:
                continue
            if time_filtered:
                # Parse once per record and reuse for both bounds.
                stamp = datetime.fromisoformat(rec.timestamp)
                if start_time is not None and stamp < start_time:
                    continue
                if end_time is not None and stamp > end_time:
                    continue
            selected.append(rec)

        if not selected:
            return UsageSummary()

        by_model: dict[str, dict[str, int | float]] = {}
        for rec in selected:
            bucket = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            bucket["total_tokens"] += rec.total_tokens
            bucket["total_cost"] += rec.cost
            bucket["count"] += 1

        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in selected),
            total_cost=sum(rec.cost for rec in selected),
            by_model=by_model,
            records=selected,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RedisUsageStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RedisUsageStore:
    """Redis-backed usage store using one Hash and one List per UTC date.

    Key schema:
        agentkit:usage:{YYYY-MM-DD}         → Hash of counters, one field per
                                              "{agent}:{model}:{metric}" where
                                              metric is cost, prompt_tokens,
                                              completion_tokens, total_tokens
                                              or count
        agentkit:usage_records:{YYYY-MM-DD} → List: JSON(UsageRecord) with LTRIM

    On the first failed write the store switches permanently to an
    in-memory fallback for the lifetime of the instance.
    """

    USAGE_PREFIX = "agentkit:usage:"
    RECORDS_PREFIX = "agentkit:usage_records:"
    MAX_RECORDS_PER_DAY = 50000
    TTL_DAYS = 90  # Auto-expire after 90 days

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self._redis_url = redis_url
        self._redis: Any = None
        self._sync_redis: Any = None
        self._fallback: InMemoryUsageStore | None = None
        self._degraded = False

    async def _get_redis(self):
        """Get or lazily create the async Redis client."""
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            import redis as sync_redis
            self._sync_redis = sync_redis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._sync_redis

    async def aclose(self) -> None:
        """Close whichever Redis clients were created and forget them."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None
        if self._sync_redis is not None:
            # The synchronous client exposes close(), not aclose().
            self._sync_redis.close()
            self._sync_redis = None

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback (idempotent; logs once)."""
        if not self._degraded:
            self._degraded = True
            if self._fallback is None:
                self._fallback = InMemoryUsageStore()
            logger.warning("Redis usage store unreachable, degraded to in-memory")

    def _today_key(self) -> str:
        """Return today's UTC date as ``YYYY-MM-DD``."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    @staticmethod
    def _utc_date(moment: datetime):
        """Return the UTC calendar date of ``moment`` (naive values are used as-is)."""
        if moment.tzinfo is None:
            return moment.date()
        return moment.astimezone(timezone.utc).date()

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Record usage with the synchronous Redis client.

        This is a sync method because UsageTracker.record() is sync, so a
        sync client is used to avoid needing an event loop in the caller.
        Any failure degrades the store to the in-memory fallback, which
        also receives the record that failed.
        """
        if self._degraded and self._fallback is not None:
            self._fallback.record(agent_name, model, usage, cost, latency_ms)
            return

        try:
            r = self._get_sync_redis()

            date_key = self._today_key()
            hash_key = f"{self.USAGE_PREFIX}{date_key}"
            list_key = f"{self.RECORDS_PREFIX}{date_key}"
            bucket_field = f"{agent_name}:{model}"

            # Server-side increments keep the per-bucket counters race-free.
            pipe = r.pipeline()
            pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost)
            pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:count", 1)

            # Append the raw record and cap the list length.
            rec = UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
            pipe.rpush(list_key, json.dumps({
                "agent_name": rec.agent_name,
                "model": rec.model,
                "prompt_tokens": rec.prompt_tokens,
                "completion_tokens": rec.completion_tokens,
                "total_tokens": rec.total_tokens,
                "cost": rec.cost,
                "latency_ms": rec.latency_ms,
                "timestamp": rec.timestamp,
            }))
            pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)

            # The TTL is (re)applied on every write, so a key expires
            # TTL_DAYS after its last write.
            pipe.expire(hash_key, self.TTL_DAYS * 86400)
            pipe.expire(list_key, self.TTL_DAYS * 86400)

            pipe.execute()
        except Exception as e:
            logger.warning(f"Redis usage record failed: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.record(agent_name, model, usage, cost, latency_ms)

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Query usage summary from the per-day record lists in Redis.

        Args:
            agent_name: Keep only records of this agent when given.
            start_time: Inclusive lower bound; when omitted the scan starts
                ``TTL_DAYS + 1`` days ago, because older keys have expired.
            end_time: Inclusive upper bound; defaults to now (UTC).

        Returns:
            A ``UsageSummary``. On a Redis error the fallback store is
            queried if one exists, otherwise an empty summary is returned.
        """
        if self._degraded and self._fallback is not None:
            return self._fallback.get_usage(agent_name, start_time, end_time)

        try:
            client = self._get_sync_redis()

            now = datetime.now(timezone.utc)
            end = end_time or now

            # Keys are named by UTC date, so derive the scan range in UTC.
            if start_time is not None:
                day = self._utc_date(start_time)
            else:
                day = (now - timedelta(days=self.TTL_DAYS + 1)).date()
            last_day = self._utc_date(end)

            matched: list[UsageRecord] = []
            one_day = timedelta(days=1)
            while day <= last_day:
                list_key = f"{self.RECORDS_PREFIX}{day.isoformat()}"
                for raw in client.lrange(list_key, 0, -1):
                    rec = UsageRecord(**json.loads(raw))
                    stamp = datetime.fromisoformat(rec.timestamp)
                    if start_time is not None and stamp < start_time:
                        continue
                    if stamp > end:
                        continue
                    if agent_name is None or rec.agent_name == agent_name:
                        matched.append(rec)
                day += one_day

            if not matched:
                return UsageSummary()

            by_model: dict[str, dict[str, int | float]] = {}
            for rec in matched:
                bucket = by_model.setdefault(
                    rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
                )
                bucket["total_tokens"] += rec.total_tokens
                bucket["total_cost"] += rec.cost
                bucket["count"] += 1

            return UsageSummary(
                total_tokens=sum(rec.total_tokens for rec in matched),
                total_cost=sum(rec.cost for rec in matched),
                by_model=by_model,
                records=matched,
            )
        except Exception as e:
            logger.warning(f"Redis usage query failed: {e}")
            if self._fallback is not None:
                return self._fallback.get_usage(agent_name, start_time, end_time)
            return UsageSummary()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_usage_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
) -> UsageStore:
    """Build the usage store selected by ``backend``.

    Args:
        backend: ``"auto"`` or ``"redis"`` picks the Redis store when the
            ``redis`` package imports; anything else picks the in-memory store.
        redis_url: Redis connection URL.

    Returns:
        A ``RedisUsageStore`` or, when Redis was not requested or the
        import fails, an ``InMemoryUsageStore``.
    """
    wants_redis = backend in ("auto", "redis")
    if wants_redis:
        try:
            import redis  # noqa: F401

            return RedisUsageStore(redis_url=redis_url)
        except ImportError:
            logger.warning("redis package not available, falling back to in-memory usage store")

    return InMemoryUsageStore()
|
||||
|
|
@ -10,7 +10,7 @@ import re
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class MemoryFile:
|
||||
|
|
@ -26,9 +26,11 @@ class MemoryFile:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, char_budget: int | None = None):
|
||||
def __init__(self, path: Path, char_budget: int | None = None,
|
||||
protected_sections: set[str] | None = None):
|
||||
self.path = Path(path)
|
||||
self.char_budget = char_budget
|
||||
self._protected_sections = protected_sections or set()
|
||||
|
||||
def read(self) -> str:
|
||||
"""读取整个文件内容,文件不存在返回空字符串."""
|
||||
|
|
@ -37,11 +39,14 @@ class MemoryFile:
|
|||
return self.path.read_text(encoding="utf-8")
|
||||
|
||||
def write(self, content: str) -> None:
|
||||
"""写入内容,自动创建父目录,超容量时自动裁剪."""
|
||||
"""写入内容,自动创建父目录,超容量时自动裁剪.
|
||||
|
||||
在内存中完成裁剪后一次性写入,避免中间不一致状态。
|
||||
"""
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.path.write_text(content, encoding="utf-8")
|
||||
if self.char_budget and len(content) > self.char_budget:
|
||||
self.trim_to_budget()
|
||||
content = self._trim_content(content, self._protected_sections or None)
|
||||
self.path.write_text(content, encoding="utf-8")
|
||||
|
||||
def read_section(self, name: str) -> str:
|
||||
"""读取指定 section 的内容(不含标题行)."""
|
||||
|
|
@ -104,15 +109,64 @@ class MemoryFile:
|
|||
return []
|
||||
return re.findall(r"^## (.+)$", content, re.MULTILINE)
|
||||
|
||||
def trim_to_budget(self) -> None:
|
||||
"""裁剪内容到容量上限,优先保留前面的 section."""
|
||||
def trim_to_budget(self, protected_sections: set[str] | None = None) -> None:
|
||||
"""裁剪内容到容量上限,按 section 边界截断.
|
||||
|
||||
保持原始 section 顺序,仅从后向前移除非保护 section。
|
||||
protected_sections 中的 section 始终保留,不参与裁剪。
|
||||
"""
|
||||
if not self.char_budget:
|
||||
return
|
||||
content = self.read()
|
||||
if len(content) <= self.char_budget:
|
||||
return
|
||||
# 从末尾裁剪,保留前面的 section
|
||||
self.write(content[: self.char_budget])
|
||||
trimmed = self._trim_content(content, protected_sections)
|
||||
self.path.write_text(trimmed, encoding="utf-8")
|
||||
|
||||
def _trim_content(self, content: str, protected_sections: set[str] | None = None) -> str:
|
||||
"""在内存中裁剪内容到容量上限,返回裁剪后的字符串(不写文件).
|
||||
|
||||
保持原始 section 顺序,仅从后向前移除非保护 section。
|
||||
"""
|
||||
if not self.char_budget or len(content) <= self.char_budget:
|
||||
return content
|
||||
|
||||
protected = protected_sections or set()
|
||||
|
||||
# 解析所有 section 及其位置
|
||||
sections: list[tuple[str, int, int]] = [] # (name, start, end)
|
||||
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
|
||||
name = match.group(1).strip()
|
||||
start = match.start()
|
||||
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE)
|
||||
if next_match:
|
||||
end = match.end() + next_match.start()
|
||||
else:
|
||||
end = len(content)
|
||||
sections.append((name, start, end))
|
||||
|
||||
if not sections:
|
||||
return content[:self.char_budget]
|
||||
|
||||
# 保持原始顺序,标记每个 section 是否受保护
|
||||
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
|
||||
for name, start, end in sections:
|
||||
ordered.append((name, content[start:end], name in protected))
|
||||
|
||||
# 从后向前移除非保护 section,直到总长度在预算内
|
||||
while ordered:
|
||||
total = len("\n\n".join(s[1] for s in ordered))
|
||||
if total <= self.char_budget:
|
||||
break
|
||||
# 从后向前找第一个非保护 section 移除
|
||||
for i in range(len(ordered) - 1, -1, -1):
|
||||
if not ordered[i][2]:
|
||||
ordered.pop(i)
|
||||
break
|
||||
else:
|
||||
break # 所有剩余 section 都是受保护的
|
||||
|
||||
return "\n\n".join(s[1] for s in ordered).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -168,14 +222,21 @@ class MemoryStore:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir: Path | str | None = None):
|
||||
def __init__(self, base_dir: Path | str | None = None,
|
||||
on_change: Callable[[str], None] | None = None):
|
||||
if base_dir is None:
|
||||
base_dir = Path.home() / ".agentkit"
|
||||
self.base_dir = Path(base_dir)
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._on_change = on_change
|
||||
self._base_prompt: str = ""
|
||||
|
||||
# 初始化四个 MemoryFile
|
||||
self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET)
|
||||
self._soul = MemoryFile(
|
||||
self.base_dir / "SOUL.md",
|
||||
char_budget=SOUL_BUDGET,
|
||||
protected_sections={"版本", "更新历史"},
|
||||
)
|
||||
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
|
||||
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
|
||||
self._daily_dir = self.base_dir / "memories" / "daily"
|
||||
|
|
@ -277,6 +338,10 @@ class MemoryStore:
|
|||
|
||||
[base_prompt]
|
||||
"""
|
||||
# 保存 base_prompt 供后续刷新使用
|
||||
if base_prompt:
|
||||
self._base_prompt = base_prompt
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
if snapshot.soul:
|
||||
|
|
@ -292,3 +357,23 @@ class MemoryStore:
|
|||
parts.append(base_prompt)
|
||||
|
||||
return "\n\n".join(parts) if parts else base_prompt
|
||||
|
||||
def refresh_system_prompt(self) -> str:
    """Reload every memory file and rebuild the system prompt.

    Intended to be called after a memory write so the prompt reflects the
    latest memory contents.  The base prompt used is ``self._base_prompt``.

    Returns:
        The value returned by ``build_system_prompt`` for a fresh snapshot.
    """
    fresh_snapshot = self.load_all()
    rebuilt = self.build_system_prompt(fresh_snapshot, self._base_prompt)
    return rebuilt
|
||||
|
||||
def notify_change(self) -> None:
    """Invoke the ``on_change`` callback with a freshly rebuilt system prompt.

    Does nothing when no callback is registered.  Any exception raised while
    rebuilding the prompt or inside the callback is logged as a warning and
    swallowed, so a failing subscriber never breaks the caller.
    """
    if self._on_change is not None:
        try:
            self._on_change(self.refresh_system_prompt())
        except Exception:
            import logging

            log = logging.getLogger(__name__)
            log.warning("Memory notify_change failed", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
"""CascadeDetector - 独立的级联故障检测工具"""
|
||||
"""CascadeDetector - 独立的级联故障检测工具(委托给 CascadeStateStore)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agentkit.quality.cascade_state_store import (
|
||||
CascadeStateStore,
|
||||
InMemoryCascadeStateStore,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CascadeAlert:
|
||||
|
|
@ -19,18 +24,19 @@ class CascadeAlert:
|
|||
class CascadeDetector:
|
||||
"""检测多 agent 交互中的级联故障"""
|
||||
|
||||
def __init__(self, max_interactions: int = 10, max_depth: int = 3):
|
||||
def __init__(
|
||||
self,
|
||||
max_interactions: int = 10,
|
||||
max_depth: int = 3,
|
||||
store: CascadeStateStore | None = None,
|
||||
):
|
||||
self._max_interactions = max_interactions
|
||||
self._max_depth = max_depth
|
||||
self._interaction_counts: dict[str, int] = {}
|
||||
self._loop_depths: dict[str, int] = {}
|
||||
self._store: CascadeStateStore = store or InMemoryCascadeStateStore()
|
||||
|
||||
def check_interaction(self, session_id: str) -> CascadeAlert | None:
|
||||
"""递增并检查交互计数"""
|
||||
self._interaction_counts[session_id] = (
|
||||
self._interaction_counts.get(session_id, 0) + 1
|
||||
)
|
||||
count = self._interaction_counts[session_id]
|
||||
count = self._store.increment_interaction(session_id)
|
||||
if count > self._max_interactions:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
|
|
@ -46,7 +52,7 @@ class CascadeDetector:
|
|||
|
||||
def check_depth(self, session_id: str, depth: int) -> CascadeAlert | None:
|
||||
"""检查循环深度"""
|
||||
self._loop_depths[session_id] = depth
|
||||
self._store.set_depth(session_id, depth)
|
||||
if depth > self._max_depth:
|
||||
return CascadeAlert(
|
||||
session_id=session_id,
|
||||
|
|
@ -62,12 +68,11 @@ class CascadeDetector:
|
|||
|
||||
def reset(self, session_id: str) -> None:
|
||||
"""重置某个 session 的计数器"""
|
||||
self._interaction_counts.pop(session_id, None)
|
||||
self._loop_depths.pop(session_id, None)
|
||||
self._store.reset(session_id)
|
||||
|
||||
def get_stats(self, session_id: str) -> dict[str, int]:
|
||||
"""获取某个 session 的当前统计"""
|
||||
return {
|
||||
"interactions": self._interaction_counts.get(session_id, 0),
|
||||
"depth": self._loop_depths.get(session_id, 0),
|
||||
"interactions": self._store.get_interaction(session_id),
|
||||
"depth": self._store.get_depth(session_id),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,245 @@
|
|||
"""Cascade State Store — Persistent cascade detection state with Redis INCR backend.
|
||||
|
||||
Provides CascadeStateStore Protocol with InMemoryCascadeStateStore and
|
||||
RedisCascadeStateStore backends. Enables CascadeDetector state to survive
|
||||
restarts and work across multiple instances.
|
||||
|
||||
Key schema (Redis):
|
||||
agentkit:cascade:interactions:{session_id} → INCR counter with TTL
|
||||
agentkit:cascade:depths:{session_id} → SET counter with TTL
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
class CascadeStateStore(Protocol):
    """Structural interface for persisting cascade-detection state.

    Any object providing these five methods satisfies the protocol; because
    it is ``runtime_checkable``, ``isinstance`` checks method presence only.
    """

    def increment_interaction(self, session_id: str) -> int:
        """Atomically bump the interaction counter and return the new value."""

    def get_interaction(self, session_id: str) -> int:
        """Return the current interaction counter for the session."""

    def set_depth(self, session_id: str, depth: int) -> None:
        """Record the loop depth for the session."""

    def get_depth(self, session_id: str) -> int:
        """Return the recorded loop depth for the session."""

    def reset(self, session_id: str) -> None:
        """Clear every counter held for the session."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InMemoryCascadeStateStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InMemoryCascadeStateStore:
    """In-memory cascade state store (default, process-local).

    Interaction counts and loop depths live in plain dicts keyed by session
    id.  Every write stamps the session with ``time.monotonic()``; a session
    whose stamp is older than ``session_ttl`` seconds counts as expired and
    its state is dropped lazily on the next access, which bounds memory use.
    """

    DEFAULT_SESSION_TTL = 86400  # 24 hours

    def __init__(self, session_ttl: int = DEFAULT_SESSION_TTL):
        """Create an empty store.

        Args:
            session_ttl: Seconds of inactivity after which a session's state
                is considered expired.
        """
        self._session_ttl = session_ttl
        self._interaction_counts: dict[str, int] = {}
        self._loop_depths: dict[str, int] = {}
        self._timestamps: dict[str, float] = {}

    def _is_expired(self, session_id: str) -> bool:
        """Return True when the session has a timestamp older than the TTL."""
        ts = self._timestamps.get(session_id)
        if ts is None:
            return False
        return (time.monotonic() - ts) > self._session_ttl

    def _cleanup_expired(self) -> None:
        """Lazy cleanup: remove the state of every expired session."""
        expired = [sid for sid in self._timestamps if self._is_expired(sid)]
        for sid in expired:
            self.reset(sid)

    def _drop_if_expired(self, session_id: str) -> bool:
        """Reset the session if it has expired; return whether it was reset."""
        if self._is_expired(session_id):
            self.reset(session_id)
            return True
        return False

    def _touch(self, session_id: str) -> None:
        """Stamp the session with the current monotonic time."""
        self._timestamps[session_id] = time.monotonic()

    def increment_interaction(self, session_id: str) -> int:
        """Increment the session's interaction count and return the new value."""
        self._cleanup_expired()
        count = self._interaction_counts.get(session_id, 0) + 1
        self._interaction_counts[session_id] = count
        self._touch(session_id)
        return count

    def get_interaction(self, session_id: str) -> int:
        """Return the interaction count, or 0 for unknown/expired sessions."""
        if self._drop_if_expired(session_id):
            return 0
        return self._interaction_counts.get(session_id, 0)

    def set_depth(self, session_id: str, depth: int) -> None:
        """Record the loop depth for the session."""
        # Drop stale state first: refreshing the timestamp of an expired
        # session would otherwise revive its old interaction count.
        self._drop_if_expired(session_id)
        self._loop_depths[session_id] = depth
        self._touch(session_id)

    def get_depth(self, session_id: str) -> int:
        """Return the loop depth, or 0 for unknown/expired sessions."""
        if self._drop_if_expired(session_id):
            return 0
        return self._loop_depths.get(session_id, 0)

    def reset(self, session_id: str) -> None:
        """Forget all state held for the session."""
        self._interaction_counts.pop(session_id, None)
        self._loop_depths.pop(session_id, None)
        self._timestamps.pop(session_id, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RedisCascadeStateStore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RedisCascadeStateStore:
    """Redis-backed cascade state store using INCR for atomic increments.

    Key schema:
        agentkit:cascade:interactions:{session_id} → INCR counter with TTL
        agentkit:cascade:depths:{session_id}       → SET value with TTL

    Degradation: when a write operation (increment, set_depth, reset) raises,
    the store switches to an in-memory fallback and keeps using it; nothing
    in this class switches back.  Read operations that raise do not trigger
    degradation; they answer from an existing fallback if there is one and
    otherwise return 0.
    """

    INTER_PREFIX = "agentkit:cascade:interactions:"
    DEPTH_PREFIX = "agentkit:cascade:depths:"
    SESSION_TTL = 86400  # 24 hours — sessions rarely last longer

    def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = SESSION_TTL):
        """Store connection settings; no connection is opened here.

        Args:
            redis_url: URL handed to ``redis.from_url`` on first use.
            session_ttl: Expiry in seconds applied to every key written.
        """
        self._redis_url = redis_url
        self._session_ttl = session_ttl
        self._sync_redis: Any = None
        self._fallback: InMemoryCascadeStateStore | None = None
        self._degraded = False

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            import redis as sync_redis

            self._sync_redis = sync_redis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._sync_redis

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback (idempotent, warns once)."""
        if self._degraded:
            return
        self._degraded = True
        if self._fallback is None:
            # Carry the configured TTL over so sessions expire on the same
            # schedule whether they live in Redis or in memory.
            self._fallback = InMemoryCascadeStateStore(session_ttl=self._session_ttl)
        logger.warning("Redis cascade store unreachable, degraded to in-memory")

    def increment_interaction(self, session_id: str) -> int:
        """Increment the interaction counter and refresh its TTL; return the new count."""
        if self._degraded and self._fallback is not None:
            return self._fallback.increment_interaction(session_id)
        try:
            key = f"{self.INTER_PREFIX}{session_id}"
            pipe = self._get_sync_redis().pipeline()
            pipe.incr(key)
            pipe.expire(key, self._session_ttl)
            return pipe.execute()[0]
        except Exception as e:
            logger.warning("Redis cascade increment failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                return self._fallback.increment_interaction(session_id)
            return 0

    def get_interaction(self, session_id: str) -> int:
        """Return the interaction counter, or 0 when the key is absent."""
        if self._degraded and self._fallback is not None:
            return self._fallback.get_interaction(session_id)
        try:
            val = self._get_sync_redis().get(f"{self.INTER_PREFIX}{session_id}")
            return int(val) if val is not None else 0
        except Exception as e:
            logger.warning("Redis cascade get failed: %s", e)
            if self._fallback is not None:
                return self._fallback.get_interaction(session_id)
            return 0

    def set_depth(self, session_id: str, depth: int) -> None:
        """Store the loop depth and refresh its TTL."""
        if self._degraded and self._fallback is not None:
            self._fallback.set_depth(session_id, depth)
            return
        try:
            key = f"{self.DEPTH_PREFIX}{session_id}"
            pipe = self._get_sync_redis().pipeline()
            pipe.set(key, depth)
            pipe.expire(key, self._session_ttl)
            pipe.execute()
        except Exception as e:
            logger.warning("Redis cascade set_depth failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.set_depth(session_id, depth)

    def get_depth(self, session_id: str) -> int:
        """Return the stored loop depth, or 0 when the key is absent."""
        if self._degraded and self._fallback is not None:
            return self._fallback.get_depth(session_id)
        try:
            val = self._get_sync_redis().get(f"{self.DEPTH_PREFIX}{session_id}")
            return int(val) if val is not None else 0
        except Exception as e:
            logger.warning("Redis cascade get_depth failed: %s", e)
            if self._fallback is not None:
                return self._fallback.get_depth(session_id)
            return 0

    def reset(self, session_id: str) -> None:
        """Delete both keys held for the session."""
        if self._degraded and self._fallback is not None:
            self._fallback.reset(session_id)
            return
        try:
            pipe = self._get_sync_redis().pipeline()
            pipe.delete(f"{self.INTER_PREFIX}{session_id}")
            pipe.delete(f"{self.DEPTH_PREFIX}{session_id}")
            pipe.execute()
        except Exception as e:
            logger.warning("Redis cascade reset failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.reset(session_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_cascade_state_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    session_ttl: int = 86400,
) -> CascadeStateStore:
    """Create a cascade state store backend.

    Args:
        backend: ``"auto"`` or ``"redis"`` selects the Redis store when the
            ``redis`` package can be imported; any other value selects the
            in-memory store.
        redis_url: Connection URL passed to the Redis store.
        session_ttl: Session expiry in seconds, applied to whichever backend
            is returned.

    Returns:
        A ``RedisCascadeStateStore`` or an ``InMemoryCascadeStateStore``.
    """
    if backend in ("auto", "redis"):
        try:
            import redis  # noqa: F401
        except ImportError:
            logger.warning("redis package not available, falling back to in-memory cascade store")
        else:
            return RedisCascadeStateStore(redis_url=redis_url, session_ttl=session_ttl)
    # Pass the TTL through so the configured expiry is honoured in memory too.
    return InMemoryCascadeStateStore(session_ttl=session_ttl)
|
||||
|
|
@ -40,7 +40,19 @@ _ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
|||
|
||||
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||
"""Build LLMGateway from ServerConfig, registering all providers."""
|
||||
gateway = LLMGateway(config=config.llm_config)
|
||||
# Initialize UsageStore if configured
|
||||
usage_store = None
|
||||
if config.usage_store:
|
||||
try:
|
||||
from agentkit.llm.providers.usage_store import create_usage_store
|
||||
usage_store = create_usage_store(
|
||||
backend=config.usage_store.get("backend", "memory"),
|
||||
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize usage store: {e}, using in-memory")
|
||||
|
||||
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
||||
|
||||
for name, pconf in config.llm_config.providers.items():
|
||||
if not pconf.api_key:
|
||||
|
|
@ -111,6 +123,15 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# Start MCP servers if configured
|
||||
mcp_manager = getattr(app.state, "mcp_manager", None)
|
||||
|
||||
# Build semantic router index after skill registry is populated
|
||||
semantic_router = getattr(getattr(app.state, "cost_aware_router", None), "_semantic_router", None)
|
||||
if semantic_router is not None:
|
||||
try:
|
||||
await semantic_router.build_index(app.state.skill_registry)
|
||||
logger.info(f"Semantic router index built with {len(app.state.skill_registry.list_skills())} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build semantic router index: {e}")
|
||||
if mcp_manager is not None:
|
||||
await mcp_manager.start_all()
|
||||
|
||||
|
|
@ -142,6 +163,23 @@ async def lifespan(app: FastAPI):
|
|||
)
|
||||
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
|
||||
|
||||
# Register on_change callback to refresh all agents' system prompts
|
||||
# when MemoryTool writes to memory files
|
||||
def _on_memory_change(new_prompt: str) -> None:
    """Assign ``new_prompt`` as the system prompt of every pooled agent.

    A failure on one agent is logged with its traceback and does not stop
    the remaining agents from being updated.
    """
    agent_pool = app.state.agent_pool
    refreshed = 0
    for name in agent_pool.list_agents():
        try:
            target = agent_pool.get_agent(name)
            if target is None:
                continue
            target._system_prompt = new_prompt
            refreshed += 1
        except Exception:
            logger.warning(f"Failed to update system prompt for agent '{name}'", exc_info=True)
    total = len(agent_pool.list_agents())
    logger.info(f"Memory changed: refreshed system prompt for {refreshed}/{total} agents")
|
||||
|
||||
memory_store._on_change = _on_memory_change
|
||||
|
||||
# Store memory_store on app.state for chat routes to use
|
||||
app.state.memory_store = memory_store
|
||||
|
||||
|
|
@ -219,6 +257,34 @@ async def lifespan(app: FastAPI):
|
|||
from agentkit.memory.profile import MemoryStore
|
||||
memory_store = MemoryStore()
|
||||
memory_store.ensure_defaults()
|
||||
# Initialize _base_prompt so refresh_system_prompt works correctly
|
||||
snapshot = memory_store.load_all()
|
||||
base_prompt = (
|
||||
"你是一个有帮助的AI助手。请记住我们对话的上下文,并在后续对话中引用之前的内容。回答要清晰简洁,请使用中文回复。\n\n"
|
||||
"重要提示:当你不确定事实信息、时事新闻或任何你不确信的话题时,"
|
||||
"你必须先使用搜索工具查找准确和最新的信息,然后再回答。"
|
||||
"中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。"
|
||||
"在能够搜索到真相的情况下,绝不猜测或编造答案。"
|
||||
"始终优先搜索而不是给出可能不正确的信息。\n\n"
|
||||
"技能安装:当需要安装技能时,使用 skill_install 工具,不要用 shell 执行 npm install。"
|
||||
"skill_install 的 source 参数格式为 owner/repo@skill,例如 vercel-labs/skills@find-skills。"
|
||||
"如果不知道完整 source,先用 shell 执行 `npx skills search <name>` 搜索。"
|
||||
)
|
||||
memory_store.build_system_prompt(snapshot, base_prompt)
|
||||
# Register on_change callback for existing agents
|
||||
def _on_memory_change(new_prompt: str) -> None:
    """Assign ``new_prompt`` as the system prompt of every pooled agent.

    A failure on one agent is logged with its traceback and does not stop
    the remaining agents from being updated.
    """
    agent_pool = app.state.agent_pool
    refreshed = 0
    for name in agent_pool.list_agents():
        try:
            target = agent_pool.get_agent(name)
            if target is None:
                continue
            target._system_prompt = new_prompt
            refreshed += 1
        except Exception:
            logger.warning(f"Failed to update system prompt for agent '{name}'", exc_info=True)
    total = len(agent_pool.list_agents())
    logger.info(f"Memory changed: refreshed system prompt for {refreshed}/{total} agents")
|
||||
memory_store._on_change = _on_memory_change
|
||||
app.state.memory_store = memory_store
|
||||
|
||||
yield
|
||||
|
|
@ -502,12 +568,28 @@ def create_app(
|
|||
auction_enabled = False
|
||||
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
|
||||
auction_enabled = server_config.marketplace.get("auction_enabled", False)
|
||||
|
||||
# Initialize semantic router if configured
|
||||
semantic_router = None
|
||||
router_conf = server_config.router if server_config and server_config.router else {}
|
||||
if router_conf.get("semantic", {}).get("enabled"):
|
||||
try:
|
||||
from agentkit.chat.semantic_router import SemanticRouter
|
||||
semantic_router = SemanticRouter(
|
||||
embedder=app.state.llm_gateway._embedder,
|
||||
similarity_high=router_conf["semantic"].get("similarity_high", 0.85),
|
||||
similarity_low=router_conf["semantic"].get("similarity_low", 0.6),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize semantic router: {e}")
|
||||
|
||||
cost_aware_router = CostAwareRouter(
|
||||
llm_gateway=app.state.llm_gateway,
|
||||
org_context=org_context,
|
||||
auction_enabled=auction_enabled,
|
||||
classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic",
|
||||
merged_llm_classify=server_config.router.get("merged_llm_classify", True) if server_config and server_config.router else True,
|
||||
classifier=router_conf.get("classifier", "heuristic"),
|
||||
merged_llm_classify=router_conf.get("merged_llm_classify", True),
|
||||
semantic_router=semantic_router,
|
||||
)
|
||||
app.state.cost_aware_router = cost_aware_router
|
||||
# Initialize task store from config
|
||||
|
|
@ -555,14 +637,30 @@ def create_app(
|
|||
app.state.evolution_store = create_evolution_store(
|
||||
backend=evo_conf.get("backend", "memory"),
|
||||
db_path=evo_conf.get("db_path", "~/.agentkit/evolution.db"),
|
||||
database_url=evo_conf.get("database_url"),
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to initialize evolution store: {e}")
|
||||
logger.warning(f"Failed to initialize evolution store: {e}")
|
||||
app.state.evolution_store = None
|
||||
else:
|
||||
app.state.evolution_store = None
|
||||
|
||||
# Initialize cascade state store if configured
|
||||
if server_config and hasattr(server_config, 'cascade_store') and server_config.cascade_store:
|
||||
try:
|
||||
from agentkit.quality.cascade_state_store import create_cascade_state_store
|
||||
cs_conf = server_config.cascade_store
|
||||
app.state.cascade_state_store = create_cascade_state_store(
|
||||
backend=cs_conf.get("backend", "memory"),
|
||||
redis_url=cs_conf.get("redis_url", "redis://localhost:6379"),
|
||||
session_ttl=cs_conf.get("session_ttl", 86400),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize cascade state store: {e}")
|
||||
app.state.cascade_state_store = None
|
||||
else:
|
||||
app.state.cascade_state_store = None
|
||||
|
||||
# Initialize memory components if configured
|
||||
if server_config and hasattr(server_config, 'memory') and server_config.memory:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -111,6 +111,9 @@ class ServerConfig:
|
|||
marketplace: dict[str, Any] | None = None,
|
||||
alignment: dict[str, Any] | None = None,
|
||||
router: dict[str, Any] | None = None,
|
||||
usage_store: dict[str, Any] | None = None,
|
||||
cascade_store: dict[str, Any] | None = None,
|
||||
evolution: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -134,6 +137,9 @@ class ServerConfig:
|
|||
self.marketplace = marketplace or {}
|
||||
self.alignment = alignment or {}
|
||||
self.router = router or {}
|
||||
self.usage_store = usage_store or {}
|
||||
self.cascade_store = cascade_store or {}
|
||||
self.evolution = evolution or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -201,6 +207,15 @@ class ServerConfig:
|
|||
# Router config
|
||||
router_data = data.get("router", {})
|
||||
|
||||
# Usage store config
|
||||
usage_store_data = data.get("usage_store", {})
|
||||
|
||||
# Cascade store config
|
||||
cascade_store_data = data.get("cascade_store", {})
|
||||
|
||||
# Evolution store config
|
||||
evolution_data = data.get("evolution", {})
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
port=server.get("port", 8001),
|
||||
|
|
@ -223,11 +238,16 @@ class ServerConfig:
|
|||
marketplace=marketplace_data,
|
||||
alignment=alignment_data,
|
||||
router=router_data,
|
||||
usage_store=usage_store_data,
|
||||
cascade_store=cascade_store_data,
|
||||
evolution=evolution_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_llm_config(data: dict) -> LLMConfig:
|
||||
"""Build LLMConfig from the llm section of agentkit.yaml."""
|
||||
from agentkit.llm.config import CacheConfig
|
||||
|
||||
providers = {}
|
||||
model_aliases = {}
|
||||
|
||||
|
|
@ -254,10 +274,17 @@ class ServerConfig:
|
|||
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
||||
)
|
||||
|
||||
# Build CacheConfig if cache section is present
|
||||
cache_config = None
|
||||
cache_data = data.get("cache")
|
||||
if cache_data and isinstance(cache_data, dict):
|
||||
cache_config = CacheConfig.from_dict(cache_data)
|
||||
|
||||
return LLMConfig(
|
||||
providers=providers,
|
||||
model_aliases=model_aliases,
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,422 @@
|
|||
"""P0 Production Hardening — End-to-End Integration Tests
|
||||
|
||||
Verifies the full configuration wiring and feature integration:
|
||||
- Config from YAML → all features configured correctly
|
||||
- Cache + usage tracking: cached requests show 0 cost
|
||||
- UsageStore persistence via config
|
||||
- CascadeStateStore persistence via config
|
||||
- EvolutionStore via config
|
||||
- Semantic router via config
|
||||
- Graceful degradation when backends unavailable
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent
|
||||
from agentkit.llm.config import CacheConfig, LLMConfig, ProviderConfig
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import TokenUsage
|
||||
from agentkit.llm.providers.usage_store import (
|
||||
InMemoryUsageStore,
|
||||
create_usage_store,
|
||||
)
|
||||
from agentkit.quality.cascade_detector import CascadeDetector
|
||||
from agentkit.quality.cascade_state_store import (
|
||||
InMemoryCascadeStateStore,
|
||||
create_cascade_state_store,
|
||||
)
|
||||
from agentkit.evolution.evolution_store import (
|
||||
EvolutionStoreProtocol,
|
||||
InMemoryEvolutionStore,
|
||||
PersistentEvolutionStore,
|
||||
create_evolution_store,
|
||||
)
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
|
||||
# ── Config parsing tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestConfigParsing:
    """ServerConfig must surface every newly added config section."""

    def test_usage_store_config_parsed(self):
        """The usage_store section is exposed as given."""
        raw = {"usage_store": {"backend": "redis", "redis_url": "redis://custom:6380"}}
        usage_store = ServerConfig.from_dict(raw).usage_store
        assert usage_store["backend"] == "redis"
        assert usage_store["redis_url"] == "redis://custom:6380"

    def test_cascade_store_config_parsed(self):
        """The cascade_store section keeps backend and session_ttl."""
        raw = {
            "cascade_store": {
                "backend": "redis",
                "redis_url": "redis://custom:6380",
                "session_ttl": 3600,
            }
        }
        cascade_store = ServerConfig.from_dict(raw).cascade_store
        assert cascade_store["backend"] == "redis"
        assert cascade_store["session_ttl"] == 3600

    def test_evolution_config_parsed(self):
        """The evolution section keeps backend and db_path."""
        raw = {"evolution": {"backend": "sqlite", "db_path": "/tmp/test.db"}}
        evolution = ServerConfig.from_dict(raw).evolution
        assert evolution["backend"] == "sqlite"
        assert evolution["db_path"] == "/tmp/test.db"

    def test_llm_cache_config_parsed(self):
        """llm.cache is turned into a cache config object on llm_config."""
        cache_section = {"enabled": True, "backend": "memory", "exact_ttl": 7200}
        raw = {"llm": {"providers": {}, "cache": cache_section}}
        cache = ServerConfig.from_dict(raw).llm_config.cache
        assert cache is not None
        assert cache.enabled is True
        assert cache.backend == "memory"
        assert cache.exact_ttl == 7200

    def test_router_semantic_config_parsed(self):
        """router.semantic is reachable through the router dict."""
        semantic_section = {
            "enabled": True,
            "similarity_high": 0.9,
            "similarity_low": 0.5,
        }
        raw = {"router": {"classifier": "heuristic", "semantic": semantic_section}}
        semantic = ServerConfig.from_dict(raw).router["semantic"]
        assert semantic["enabled"] is True
        assert semantic["similarity_high"] == 0.9

    def test_empty_config_defaults(self):
        """An empty dict yields empty store sections and no cache config."""
        cfg = ServerConfig.from_dict({})
        assert cfg.usage_store == {}
        assert cfg.cascade_store == {}
        assert cfg.evolution == {}
        assert cfg.llm_config.cache is None

    def test_config_from_yaml_roundtrip(self, tmp_path):
        """Config can be loaded from a YAML file with all new sections."""
        yaml_text = """
server:
  host: 0.0.0.0
  port: 8001
llm:
  providers: {}
  cache:
    enabled: true
    backend: memory
router:
  classifier: heuristic
  semantic:
    enabled: false
usage_store:
  backend: memory
cascade_store:
  backend: memory
evolution:
  backend: memory
"""
        yaml_file = tmp_path / "test_config.yaml"
        yaml_file.write_text(yaml_text)

        cfg = ServerConfig.from_yaml(str(yaml_file))
        cache = cfg.llm_config.cache
        assert cache is not None
        assert cache.enabled is True
        assert cfg.usage_store["backend"] == "memory"
        assert cfg.cascade_store["backend"] == "memory"
        assert cfg.evolution["backend"] == "memory"
|
||||
|
||||
|
||||
# ── UsageStore integration tests ───────────────────────────
|
||||
|
||||
|
||||
class TestUsageStoreIntegration:
    """UsageStore wiring as seen through LLMGateway and the factory."""

    async def test_gateway_with_usage_store(self):
        """A gateway built with an injected store reports recorded usage."""
        gw = LLMGateway(usage_store=InMemoryUsageStore())

        # Feed one record straight into the tracker.
        record = {
            "agent_name": "test_agent",
            "model": "gpt-4",
            "usage": TokenUsage(prompt_tokens=100, completion_tokens=50),
            "cost": 0.01,
            "latency_ms": 100.0,
        }
        gw._usage_tracker.record(**record)

        totals = gw.get_usage()
        assert totals.total_tokens > 0
        assert totals.total_cost > 0

    async def test_gateway_without_usage_store(self):
        """A gateway built with no store argument still reports recorded usage."""
        gw = LLMGateway()
        record = {
            "agent_name": "test_agent",
            "model": "gpt-4",
            "usage": TokenUsage(prompt_tokens=100, completion_tokens=50),
            "cost": 0.01,
            "latency_ms": 100.0,
        }
        gw._usage_tracker.record(**record)

        totals = gw.get_usage()
        assert totals.total_tokens > 0

    def test_create_usage_store_memory(self):
        """backend="memory" yields the in-memory store."""
        assert isinstance(create_usage_store(backend="memory"), InMemoryUsageStore)

    def test_create_usage_store_redis_lazy(self):
        """backend="redis" yields a RedisUsageStore even for an unreachable URL."""
        from agentkit.llm.providers.usage_store import RedisUsageStore

        created = create_usage_store(
            backend="redis",
            redis_url="redis://nonexistent:6379",
        )
        # Creation succeeds without contacting the host.
        assert isinstance(created, RedisUsageStore)
|
||||
|
||||
|
||||
# ── CascadeStateStore integration tests ────────────────────
|
||||
|
||||
|
||||
class TestCascadeStateStoreIntegration:
    """Check CascadeDetector with an injected state store and with none."""

    async def test_cascade_detector_with_store(self):
        """A detector given an InMemoryCascadeStateStore raises no alert on a first interaction."""
        detector = CascadeDetector(store=InMemoryCascadeStateStore())

        # A single interaction should not trigger a cascade alert
        alert = detector.check_interaction(session_id="test-session")
        assert alert is None

    async def test_cascade_detector_without_store(self):
        """A detector built with no store argument behaves the same way."""
        alert = CascadeDetector().check_interaction(session_id="test-session")
        assert alert is None

    def test_create_cascade_state_store_memory(self):
        """The "memory" backend yields an InMemoryCascadeStateStore."""
        created = create_cascade_state_store(backend="memory")
        assert isinstance(created, InMemoryCascadeStateStore)
|
||||
|
||||
|
||||
# ── EvolutionStore integration tests ───────────────────────
|
||||
|
||||
|
||||
class TestEvolutionStoreIntegration:
    """Check the stores produced by create_evolution_store()."""

    async def test_create_evolution_store_from_config(self, tmp_path):
        """The "sqlite" backend records an event and lists it back."""
        sqlite_file = tmp_path / "evo_test.db"
        store = create_evolution_store(backend="sqlite", db_path=str(sqlite_file))
        assert isinstance(store, PersistentEvolutionStore)

        recorded_id = await store.record(
            EvolutionEvent(
                agent_name="test_agent",
                change_type="prompt",
                before={"old": 1},
                after={"new": 2},
            )
        )
        assert recorded_id is not None

        listed = await store.list_events()
        assert len(listed) == 1
        assert listed[0]["agent_name"] == "test_agent"

    async def test_create_evolution_store_memory(self):
        """The "memory" backend yields an InMemoryEvolutionStore that accepts events."""
        store = create_evolution_store(backend="memory")
        assert isinstance(store, InMemoryEvolutionStore)

        empty_change = EvolutionEvent(
            agent_name="test_agent",
            change_type="prompt",
            before={},
            after={},
        )
        recorded_id = await store.record(empty_change)
        assert recorded_id is not None

    async def test_evolution_store_protocol_compliance(self):
        """The memory-backed store passes an isinstance check against EvolutionStoreProtocol."""
        assert isinstance(
            create_evolution_store(backend="memory"),
            EvolutionStoreProtocol,
        )
|
||||
|
||||
|
||||
# ── Cache integration tests ────────────────────────────────
|
||||
|
||||
|
||||
class TestCacheIntegration:
    """Check whether LLMGateway builds a cache for different LLMConfig settings."""

    def test_gateway_with_cache_config(self):
        """An enabled memory CacheConfig gives the gateway a cache object."""
        cache_settings = CacheConfig(enabled=True, backend="memory")
        llm_config = LLMConfig(providers={}, cache=cache_settings)

        assert LLMGateway(config=llm_config)._cache is not None

    def test_gateway_without_cache_config(self):
        """With no cache section the gateway's cache stays None."""
        llm_config = LLMConfig(providers={})

        assert LLMGateway(config=llm_config)._cache is None

    def test_gateway_cache_disabled(self):
        """A CacheConfig with enabled=False leaves the gateway's cache as None."""
        cache_settings = CacheConfig(enabled=False)
        llm_config = LLMConfig(providers={}, cache=cache_settings)

        assert LLMGateway(config=llm_config)._cache is None
|
||||
|
||||
|
||||
# ── Graceful degradation tests ─────────────────────────────
|
||||
|
||||
|
||||
class TestGracefulDegradation:
    """Check what the factory functions return when the backing service is unreachable."""

    # URL that no test environment is expected to resolve.
    _UNREACHABLE_REDIS = "redis://nonexistent:6379"

    def test_usage_store_auto_creates_redis(self):
        """The "auto" usage-store backend yields a RedisUsageStore (lazy connection)."""
        from agentkit.llm.providers.usage_store import RedisUsageStore

        created = create_usage_store(
            backend="auto",
            redis_url=self._UNREACHABLE_REDIS,
        )

        # Redis is available as package, so RedisUsageStore is created
        assert isinstance(created, RedisUsageStore)

    def test_cascade_store_redis_lazy(self):
        """The "redis" cascade-store backend yields a RedisCascadeStateStore (lazy connection)."""
        from agentkit.quality.cascade_state_store import RedisCascadeStateStore

        created = create_cascade_state_store(
            backend="redis",
            redis_url=self._UNREACHABLE_REDIS,
        )
        assert isinstance(created, RedisCascadeStateStore)

    def test_evolution_store_postgresql_unavailable(self):
        """Requesting "postgresql" with database_url=None yields an InMemoryEvolutionStore."""
        created = create_evolution_store(backend="postgresql", database_url=None)
        assert isinstance(created, InMemoryEvolutionStore)

    def test_cache_auto_creates_redis(self):
        """The "auto" cache backend yields a RedisLLMCache (lazy connection)."""
        from agentkit.llm.cache import create_llm_cache, RedisLLMCache

        created = create_llm_cache(
            backend="auto",
            redis_url=self._UNREACHABLE_REDIS,
        )

        # Redis package is available, so RedisLLMCache is created (lazy connection)
        assert isinstance(created, RedisLLMCache)
|
||||
|
||||
|
||||
# ── Full flow test (in-memory) ─────────────────────────────
|
||||
|
||||
|
||||
class TestFullFlowInMemory:
    """Walk config -> components -> usage with memory-only backends (no external deps)."""

    async def test_config_to_components(self):
        """A ServerConfig dict yields a working gateway, cascade detector and evolution store."""
        llm_section = {
            "providers": {},
            "cache": {
                "enabled": True,
                "backend": "memory",
            },
        }
        router_section = {"classifier": "heuristic"}
        config = ServerConfig.from_dict(
            {
                "llm": llm_section,
                "router": router_section,
                "usage_store": {"backend": "memory"},
                "cascade_store": {"backend": "memory"},
                "evolution": {"backend": "memory"},
            }
        )

        # The parsed config must reflect every section supplied above
        assert config.llm_config.cache is not None
        assert config.llm_config.cache.enabled is True
        assert config.usage_store["backend"] == "memory"
        assert config.cascade_store["backend"] == "memory"
        assert config.evolution["backend"] == "memory"

        # Build each component from its config section
        usage_backend = config.usage_store.get("backend", "memory")
        usage_redis_url = config.usage_store.get("redis_url", "redis://localhost:6379")
        usage_store = create_usage_store(backend=usage_backend, redis_url=usage_redis_url)
        gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
        assert gateway._cache is not None

        cascade_backend = config.cascade_store.get("backend", "memory")
        detector = CascadeDetector(store=create_cascade_state_store(backend=cascade_backend))

        evolution_backend = config.evolution.get("backend", "memory")
        evo_store = create_evolution_store(backend=evolution_backend)

        # Exercise the components
        sample_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        gateway._usage_tracker.record(
            agent_name="test",
            model="gpt-4",
            usage=sample_usage,
            cost=0.01,
            latency_ms=100.0,
        )
        assert gateway.get_usage().total_tokens == 150

        # A first interaction must not raise a cascade alert
        assert detector.check_interaction("s1") is None

        change = EvolutionEvent(
            agent_name="test", change_type="prompt", before={}, after={}
        )
        recorded_id = await evo_store.record(change)
        assert recorded_id is not None
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
"""Integration tests for LLM Cache integration into LLMGateway (U2)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.cache import InMemoryLLMCache
|
||||
from agentkit.llm.config import CacheConfig, LLMConfig
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
||||
|
||||
|
||||
class MockProvider(LLMProvider):
    """Stub provider that returns a canned reply and counts chat() invocations."""

    def __init__(self, response_content: str = "Mock response"):
        self._response_content = response_content
        # Number of times chat() has been awaited on this instance.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Bump the call counter and echo the request's model with the canned content."""
        self.call_count += 1
        fixed_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        return LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=fixed_usage,
        )
|
||||
|
||||
|
||||
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
|
||||
class TestCacheDisabled:
    """Gateway behaviour when no cache has been configured."""

    @pytest.mark.asyncio
    async def test_no_cache_by_default(self):
        """A default gateway forwards every request, even identical ones, to the provider."""
        provider = MockProvider()
        gateway = LLMGateway()
        gateway.register_provider("test", provider)

        conversation = _make_messages()
        for _ in range(2):
            await gateway.chat(conversation, "test/model")

        assert provider.call_count == 2
|
||||
|
||||
|
||||
class TestCacheEnabled:
    """Gateway behaviour with the in-memory cache switched on."""

    @staticmethod
    def _gateway_with_provider():
        """Return ``(gateway, provider)``: a memory-cached gateway with a MockProvider registered as "test"."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        provider = MockProvider()
        gateway.register_provider("test", provider)
        return gateway, provider

    @pytest.mark.asyncio
    async def test_first_request_is_miss(self):
        """First request is a cache miss — provider is called."""
        gateway, provider = self._gateway_with_provider()

        msgs = _make_messages()
        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1
        assert response.content == "Mock response"

    @pytest.mark.asyncio
    async def test_second_request_is_hit(self):
        """Second identical request is a cache hit — provider NOT called."""
        gateway, provider = self._gateway_with_provider()

        msgs = _make_messages()
        await gateway.chat(msgs, "test/model", temperature=0.0)
        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1  # Not called again
        assert response.content == "Mock response"

    @pytest.mark.asyncio
    async def test_cache_hit_usage_has_zero_cost(self):
        """A cache hit is still recorded as usage and adds no cost."""
        gateway, provider = self._gateway_with_provider()

        msgs = _make_messages()
        await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)
        await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)

        # Without this check the test would pass even if the second call
        # bypassed the cache and reached the provider.
        assert provider.call_count == 1

        usage = gateway.get_usage(agent_name="agent1")
        # No cost config is set, so the provider call itself is recorded at
        # cost 0; the total therefore only shows the hit added nothing.
        assert usage.total_cost == 0.0
        # Both the miss and the hit produce a usage record.
        assert len(usage.records) == 2

    @pytest.mark.asyncio
    async def test_different_messages_are_miss(self):
        """Different messages produce cache misses."""
        gateway, provider = self._gateway_with_provider()

        await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
        await gateway.chat(_make_messages("World"), "test/model", temperature=0.0)

        assert provider.call_count == 2
|
||||
|
||||
|
||||
class TestCacheConfig:
    """Parsing of the cache section and its effect on gateway construction."""

    def test_config_from_dict(self):
        """Scalar cache settings survive LLMConfig.from_dict()."""
        cache_section = {
            "enabled": True,
            "backend": "memory",
            "exact_ttl": 7200,
        }
        parsed = LLMConfig.from_dict({"cache": cache_section})

        assert parsed.cache is not None
        assert parsed.cache.enabled is True
        assert parsed.cache.backend == "memory"
        assert parsed.cache.exact_ttl == 7200

    def test_config_from_dict_no_cache(self):
        """An empty config dict leaves cache as None."""
        assert LLMConfig.from_dict({}).cache is None

    def test_config_from_dict_embedding(self):
        """The nested embedding section maps onto the embedding_* attributes."""
        embedding_section = {
            "provider": "xinference",
            "model": "bge-m3",
            "base_url": "http://localhost:9997/v1",
        }
        parsed = LLMConfig.from_dict(
            {"cache": {"enabled": True, "embedding": embedding_section}}
        )

        assert parsed.cache.embedding_provider == "xinference"
        assert parsed.cache.embedding_model == "bge-m3"
        assert parsed.cache.embedding_base_url == "http://localhost:9997/v1"

    def test_gateway_creates_cache_when_enabled(self):
        """An enabled memory cache config produces an InMemoryLLMCache on the gateway."""
        cache_settings = CacheConfig(enabled=True, backend="memory")
        gateway = LLMGateway(config=LLMConfig(cache=cache_settings))

        assert gateway._cache is not None
        assert isinstance(gateway._cache, InMemoryLLMCache)

    def test_gateway_no_cache_when_disabled(self):
        """enabled=False leaves the gateway's cache as None."""
        cache_settings = CacheConfig(enabled=False)
        gateway = LLMGateway(config=LLMConfig(cache=cache_settings))

        assert gateway._cache is None

    def test_gateway_no_cache_when_no_config(self):
        """A gateway built with no config has no cache."""
        assert LLMGateway()._cache is None
|
||||
|
|
@ -0,0 +1,604 @@
|
|||
"""Unit tests for LLM Cache Core (U1).
|
||||
|
||||
Tests cover:
|
||||
- CacheKey generation (deterministic, component isolation)
|
||||
- InMemoryLLMCache (exact match, semantic match, TTL, LRU, stats)
|
||||
- RedisLLMCache (same tests with mocked Redis)
|
||||
- Factory function (backend selection, fallback)
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.cache import (
|
||||
CacheEntry,
|
||||
CacheResult,
|
||||
InMemoryLLMCache,
|
||||
RedisLLMCache,
|
||||
create_llm_cache,
|
||||
_serialize_response,
|
||||
_deserialize_response,
|
||||
_serialize_entry,
|
||||
_deserialize_entry,
|
||||
)
|
||||
from agentkit.llm.cache_key import generate_cache_key
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_response(
    content: str = "Hello",
    model: str = "gpt-4o",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
) -> LLMResponse:
    """Build an LLMResponse fixture with a fixed 100 ms latency.

    A missing (or empty) ``tool_calls`` argument becomes a fresh empty list.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls or []
    return LLMResponse(
        content=content,
        model=model,
        usage=token_usage,
        tool_calls=calls,
        latency_ms=100.0,
    )
|
||||
|
||||
|
||||
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
|
||||
def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
|
||||
"""Create a unit vector for similarity testing."""
|
||||
vec = [base_val] * dim
|
||||
magnitude = sum(x**2 for x in vec) ** 0.5
|
||||
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
||||
|
||||
|
||||
def _make_similar_embedding(base: list[float], noise: float = 0.01) -> list[float]:
|
||||
"""Create a vector similar to base with small noise."""
|
||||
vec = [x + noise for x in base]
|
||||
magnitude = sum(x**2 for x in vec) ** 0.5
|
||||
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
||||
|
||||
|
||||
def _make_different_embedding(dim: int = 128) -> list[float]:
|
||||
"""Create a vector with very low cosine similarity to _make_embedding()."""
|
||||
# _make_embedding(1.0) is all-positive unit vector.
|
||||
# Negate first half to create near-orthogonal vector.
|
||||
half = dim // 2
|
||||
vec = [-1.0] * half + [1.0] * (dim - half)
|
||||
magnitude = sum(x**2 for x in vec) ** 0.5
|
||||
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CacheKey Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheKey:
    """generate_cache_key(): stability and sensitivity to each input component."""

    def test_deterministic(self):
        """Repeating the call with identical inputs returns the identical key."""
        conversation = _make_messages()
        first, second = (
            generate_cache_key("gpt-4o", conversation, 0.0) for _ in range(2)
        )
        assert first == second
        assert len(first) == 64  # SHA-256 hex

    def test_different_model(self):
        """Changing only the model changes the key."""
        conversation = _make_messages()
        assert generate_cache_key("gpt-4o", conversation, 0.0) != generate_cache_key(
            "gpt-3.5-turbo", conversation, 0.0
        )

    def test_different_temperature(self):
        """Changing only the temperature changes the key."""
        conversation = _make_messages()
        cold = generate_cache_key("gpt-4o", conversation, 0.0)
        warm = generate_cache_key("gpt-4o", conversation, 0.7)
        assert cold != warm

    def test_different_messages(self):
        """Changing only the user message changes the key."""
        hello_key = generate_cache_key("gpt-4o", _make_messages("Hello"), 0.0)
        world_key = generate_cache_key("gpt-4o", _make_messages("World"), 0.0)
        assert hello_key != world_key

    def test_different_tools(self):
        """Changing only the tool list changes the key."""
        conversation = _make_messages()
        tool_f1 = [{"type": "function", "function": {"name": "f1"}}]
        tool_f2 = [{"type": "function", "function": {"name": "f2"}}]
        with_f1 = generate_cache_key("gpt-4o", conversation, 0.0, tools=tool_f1)
        with_f2 = generate_cache_key("gpt-4o", conversation, 0.0, tools=tool_f2)
        assert with_f1 != with_f2

    def test_none_tools_same_as_no_tools(self):
        """Passing tools=None is equivalent to omitting tools."""
        conversation = _make_messages()
        explicit_none = generate_cache_key("gpt-4o", conversation, 0.0, tools=None)
        omitted = generate_cache_key("gpt-4o", conversation, 0.0)
        assert explicit_none == omitted

    def test_system_prompt_extracted_from_messages(self):
        """Two conversations differing only in the leading system message get different keys."""
        user_turn = {"role": "user", "content": "Hello"}
        concise = [{"role": "system", "content": "Be concise"}, user_turn]
        concise_key = generate_cache_key("gpt-4o", concise, 0.0)

        verbose = [
            {"role": "system", "content": "Be verbose"},
            {"role": "user", "content": "Hello"},
        ]
        verbose_key = generate_cache_key("gpt-4o", verbose, 0.0)
        assert concise_key != verbose_key

    def test_max_tokens_affects_key(self):
        """Changing only max_tokens changes the key."""
        conversation = _make_messages()
        small = generate_cache_key("gpt-4o", conversation, 0.0, max_tokens=2000)
        large = generate_cache_key("gpt-4o", conversation, 0.0, max_tokens=4000)
        assert small != large
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InMemoryLLMCache Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInMemoryLLMCache:
    """InMemoryLLMCache: exact/semantic lookup, TTL expiry, LRU eviction, invalidation, stats."""

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        """A stored key is returned as an "exact" hit with the cached content."""
        cache = InMemoryLLMCache()
        key = "test_key_1"
        response = _make_response("Cached answer")

        await cache.put(key, response)
        result = await cache.get(key)

        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached answer"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        """Looking up a key that was never stored is a miss with no response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())

        result = await cache.get("key2")
        assert result.hit is False
        assert result.response is None

    @pytest.mark.asyncio
    async def test_semantic_match_hit(self):
        """An embedding close to a stored one is returned as a "semantic" hit."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_similar = _make_similar_embedding(emb1, noise=0.001)

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_similar, threshold=0.9)

        assert result.hit is True
        assert result.match_type == "semantic"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_semantic_match_miss(self):
        """An embedding far from the stored one does not match at threshold 0.9."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_different = _make_different_embedding(dim=64)

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_different, threshold=0.9)

        assert result.hit is False

    @pytest.mark.asyncio
    async def test_semantic_match_empty_cache(self):
        """Semantic search on an empty cache is a miss."""
        cache = InMemoryLLMCache()
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_exact(self):
        """An entry is no longer returned by get() once exact_ttl has elapsed."""
        # Local import so this block does not depend on the file's import header.
        import asyncio

        cache = InMemoryLLMCache(exact_ttl=1)  # 1 second TTL
        await cache.put("key1", _make_response())

        # Wait for expiry without blocking the event loop (time.sleep would).
        await asyncio.sleep(1.1)
        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_semantic(self):
        """An entry is no longer returned by semantic_search() once semantic_ttl has elapsed."""
        # Local import so this block does not depend on the file's import header.
        import asyncio

        cache = InMemoryLLMCache(semantic_ttl=1)
        emb = _make_embedding(dim=64)
        await cache.put("key1", _make_response(), query_embedding=emb)

        # Wait for expiry without blocking the event loop (time.sleep would).
        await asyncio.sleep(1.1)
        result = await cache.semantic_search(emb)
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_lru_eviction(self):
        """Inserting a fourth entry into a 3-entry cache evicts the oldest one."""
        cache = InMemoryLLMCache(max_entries=3)

        for i in range(4):
            await cache.put(f"key{i}", _make_response(f"Response {i}"))

        # key0 should be evicted (oldest)
        result = await cache.get("key0")
        assert result.hit is False

        # key1-key3 should still be present
        for i in range(1, 4):
            result = await cache.get(f"key{i}")
            assert result.hit is True

    @pytest.mark.asyncio
    async def test_lru_access_refreshes(self):
        """Reading an entry protects it from being the next eviction victim."""
        cache = InMemoryLLMCache(max_entries=3)

        await cache.put("key0", _make_response("R0"))
        await cache.put("key1", _make_response("R1"))
        await cache.put("key2", _make_response("R2"))

        # Access key0 to move it to most-recently-used
        await cache.get("key0")

        # Adding key3 should evict key1 (now LRU)
        await cache.put("key3", _make_response("R3"))

        result = await cache.get("key1")
        assert result.hit is False

        result = await cache.get("key0")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        """invalidate() with no pattern removes every entry and returns the count."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.put("key2", _make_response())

        count = await cache.invalidate()
        assert count == 2

        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self):
        """invalidate("abc_*") removes only the keys matching the pattern."""
        cache = InMemoryLLMCache()
        await cache.put("abc_1", _make_response())
        await cache.put("abc_2", _make_response())
        await cache.put("xyz_1", _make_response())

        count = await cache.invalidate("abc_*")
        assert count == 2

        result = await cache.get("xyz_1")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_stats(self):
        """stats() reports entry, hit and miss counts."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.get("key1")  # hit
        await cache.get("key2")  # miss

        stats = await cache.stats()
        assert stats["total_entries"] == 1
        assert stats["total_hits"] == 1
        assert stats["total_misses"] == 1

    @pytest.mark.asyncio
    async def test_tool_calls_cached(self):
        """Tool calls on a response survive a put/get round trip."""
        cache = InMemoryLLMCache()
        tool_calls = [
            ToolCall(id="call_1", name="search", arguments={"query": "test"})
        ]
        response = _make_response(tool_calls=tool_calls)

        await cache.put("key1", response)
        result = await cache.get("key1")

        assert result.hit is True
        assert len(result.response.tool_calls) == 1
        assert result.response.tool_calls[0].name == "search"

    @pytest.mark.asyncio
    async def test_put_without_embedding(self):
        """An entry stored without an embedding is found by key but not by semantic search."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response(), query_embedding=None)

        # Exact match should still work
        result = await cache.get("key1")
        assert result.hit is True

        # Semantic search should return miss (no embeddings)
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_updates_existing_key(self):
        """A second put() on the same key replaces the stored response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response("Old"))
        await cache.put("key1", _make_response("New"))

        result = await cache.get("key1")
        assert result.hit is True
        assert result.response.content == "New"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSerialization:
    """Round-trip checks for the response/entry (de)serialisation helpers."""

    def test_serialize_deserialize_response(self):
        """Every checked field of an LLMResponse survives serialise -> deserialise."""
        single_call = ToolCall(id="c1", name="tool1", arguments={"k": "v"})
        original = _make_response(
            content="Test",
            model="gpt-4o",
            prompt_tokens=5,
            completion_tokens=10,
            tool_calls=[single_call],
        )

        restored = _deserialize_response(_serialize_response(original))

        assert restored.content == "Test"
        assert restored.model == "gpt-4o"
        assert restored.usage.prompt_tokens == 5
        assert restored.usage.completion_tokens == 10
        assert len(restored.tool_calls) == 1
        assert restored.tool_calls[0].name == "tool1"

    def test_serialize_deserialize_entry(self):
        """A CacheEntry keeps its response, embedding, timestamp and hit count."""
        original = CacheEntry(
            response=_make_response(),
            query_embedding=[0.1, 0.2, 0.3],
            created_at=12345.0,
            hit_count=5,
        )

        restored = _deserialize_entry(_serialize_entry(original))

        assert restored.response.content == "Hello"
        assert restored.query_embedding == [0.1, 0.2, 0.3]
        assert restored.created_at == 12345.0
        assert restored.hit_count == 5

    def test_serialize_response_no_tool_calls(self):
        """A response without tool calls serialises to, and restores as, an empty list."""
        payload = _serialize_response(_make_response())
        assert payload["tool_calls"] == []

        restored = _deserialize_response(payload)
        assert restored.tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RedisLLMCache Tests (with mocked Redis)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRedisLLMCache:
|
||||
def _make_mock_redis(self):
|
||||
"""Create a mock Redis client that simulates basic operations."""
|
||||
mock = AsyncMock()
|
||||
mock._data = {}
|
||||
mock._sets = {}
|
||||
|
||||
async def mock_get(key):
|
||||
return mock._data.get(key)
|
||||
|
||||
async def mock_set(key, value, ex=None):
|
||||
mock._data[key] = value
|
||||
|
||||
async def mock_mget(keys):
|
||||
return [mock._data.get(k) for k in keys]
|
||||
|
||||
async def mock_sadd(key, *members):
|
||||
if key not in mock._sets:
|
||||
mock._sets[key] = set()
|
||||
mock._sets[key].update(members)
|
||||
|
||||
async def mock_smembers(key):
|
||||
return mock._sets.get(key, set())
|
||||
|
||||
async def mock_scard(key):
|
||||
return len(mock._sets.get(key, set()))
|
||||
|
||||
async def mock_delete(*keys):
|
||||
for k in keys:
|
||||
mock._data.pop(k, None)
|
||||
|
||||
async def mock_srem(key, *members):
|
||||
if key in mock._sets:
|
||||
mock._sets[key] -= set(members)
|
||||
|
||||
mock.get = mock_get
|
||||
mock.set = mock_set
|
||||
mock.mget = mock_mget
|
||||
mock.sadd = mock_sadd
|
||||
mock.smembers = mock_smembers
|
||||
mock.scard = mock_scard
|
||||
mock.delete = mock_delete
|
||||
mock.srem = mock_srem
|
||||
|
||||
# Pipeline mock — collects commands and executes them on execute()
|
||||
class MockPipeline:
|
||||
def __init__(self):
|
||||
self._commands = []
|
||||
|
||||
def set(self, key, value, ex=None):
|
||||
self._commands.append(("set", key, value, ex))
|
||||
|
||||
def sadd(self, key, *members):
|
||||
self._commands.append(("sadd", key, members))
|
||||
|
||||
def delete(self, *keys):
|
||||
self._commands.append(("delete", keys))
|
||||
|
||||
def srem(self, key, *members):
|
||||
self._commands.append(("srem", key, members))
|
||||
|
||||
async def execute(self):
|
||||
for cmd in self._commands:
|
||||
if cmd[0] == "set":
|
||||
_, key, value, ex = cmd
|
||||
mock._data[key] = value
|
||||
elif cmd[0] == "sadd":
|
||||
_, key, members = cmd
|
||||
if key not in mock._sets:
|
||||
mock._sets[key] = set()
|
||||
mock._sets[key].update(members)
|
||||
elif cmd[0] == "delete":
|
||||
_, keys = cmd
|
||||
for k in keys:
|
||||
mock._data.pop(k, None)
|
||||
elif cmd[0] == "srem":
|
||||
_, key, members = cmd
|
||||
if key in mock._sets:
|
||||
mock._sets[key] -= set(members)
|
||||
|
||||
mock.pipeline = MagicMock(return_value=MockPipeline())
|
||||
|
||||
return mock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match_hit(self):
|
||||
cache = RedisLLMCache()
|
||||
mock_redis = self._make_mock_redis()
|
||||
cache._redis = mock_redis
|
||||
|
||||
key = "test_key"
|
||||
response = _make_response("Cached")
|
||||
entry = CacheEntry(response=response, created_at=time.monotonic(), hit_count=0)
|
||||
entry_json = json.dumps(_serialize_entry(entry))
|
||||
|
||||
# Simulate Redis already has the data
|
||||
mock_redis._data[f"{cache.KEY_PREFIX}{key}"] = entry_json
|
||||
|
||||
result = await cache.get(key)
|
||||
assert result.hit is True
|
||||
assert result.match_type == "exact"
|
||||
assert result.response.content == "Cached"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match_miss(self):
|
||||
cache = RedisLLMCache()
|
||||
mock_redis = self._make_mock_redis()
|
||||
cache._redis = mock_redis
|
||||
|
||||
result = await cache.get("nonexistent_key")
|
||||
assert result.hit is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_redis_failure_returns_miss(self):
    """An exception raised by the Redis client's get() is reported as a miss."""
    cache = RedisLLMCache()
    failing_redis = AsyncMock()
    failing_redis.get = AsyncMock(side_effect=Exception("Connection refused"))
    cache._redis = failing_redis

    lookup = await cache.get("any_key")

    assert lookup.hit is False
|
||||
|
||||
@pytest.mark.asyncio
async def test_put_stores_data(self):
    """put() writes the entry, its embedding, and the key's index membership."""
    cache = RedisLLMCache()
    redis_stub = self._make_mock_redis()
    cache._redis = redis_stub

    cache_key = "test_key"
    llm_response = _make_response()
    embedding = [0.1, 0.2, 0.3]

    await cache.put(cache_key, llm_response, query_embedding=embedding)

    # All three pieces must have landed in the fake store.
    entry_key = f"{cache.KEY_PREFIX}{cache_key}"
    embedding_key = f"{cache.EMB_PREFIX}{cache_key}"
    assert entry_key in redis_stub._data
    assert embedding_key in redis_stub._data
    assert cache_key in redis_stub._sets.get(cache.INDEX_KEY, set())
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidate_all(self):
    """invalidate() without arguments returns 2 when two entries are indexed."""
    cache = RedisLLMCache()
    redis_stub = self._make_mock_redis()
    cache._redis = redis_stub

    # Pre-populate the index and the matching entry payloads.
    seeded = {"key1": "data1", "key2": "data2"}
    redis_stub._sets[cache.INDEX_KEY] = set(seeded)
    for entry_name, payload in seeded.items():
        redis_stub._data[f"{cache.KEY_PREFIX}{entry_name}"] = payload

    removed = await cache.invalidate()

    assert removed == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_stats(self):
    """stats() reports ``total_entries`` equal to the number of indexed keys."""
    cache = RedisLLMCache()
    redis_stub = self._make_mock_redis()
    cache._redis = redis_stub
    indexed_keys = {"k1", "k2", "k3"}
    redis_stub._sets[cache.INDEX_KEY] = indexed_keys

    report = await cache.stats()

    assert report["total_entries"] == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateLLMCache:
    """Backend-selection tests for the ``create_llm_cache`` factory."""

    def test_memory_backend(self):
        """backend="memory" yields the in-memory cache."""
        cache = create_llm_cache(backend="memory")
        assert isinstance(cache, InMemoryLLMCache)

    def test_auto_backend_fallback(self):
        """When redis package is not available, auto falls back to InMemory."""
        # A ``None`` entry in sys.modules makes ``import redis.asyncio`` raise
        # ImportError, which simulates the package being absent.
        with patch.dict("sys.modules", {"redis.asyncio": None}):
            cache = create_llm_cache(backend="auto")
            assert isinstance(cache, InMemoryLLMCache)

    def test_redis_backend_with_redis_available(self):
        """When redis.asyncio is available, backend="redis" returns RedisLLMCache."""
        # Skip (rather than fail) on installations without the redis package.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="redis")
        assert isinstance(cache, RedisLLMCache)

    def test_auto_backend_with_redis_available(self):
        """When redis.asyncio is available, backend="auto" returns RedisLLMCache."""
        # Skip (rather than fail) on installations without the redis package.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="auto")
        assert isinstance(cache, RedisLLMCache)

    def test_custom_parameters(self):
        """Tuning parameters passed to the factory are stored on the cache."""
        cache = create_llm_cache(
            backend="memory",
            max_entries=500,
            exact_ttl=7200,
            semantic_ttl=172800,
            similarity_threshold=0.95,
        )
        assert isinstance(cache, InMemoryLLMCache)
        assert cache._max_entries == 500
        assert cache._exact_ttl == 7200
        assert cache._semantic_ttl == 172800
        assert cache._similarity_threshold == 0.95
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
"""Unit tests for Semantic Router (U3)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.semantic_router import (
|
||||
SemanticRouteResult,
|
||||
SkillEmbeddingIndex,
|
||||
SemanticRouter,
|
||||
)
|
||||
from agentkit.memory.embedder import MockEmbedder
|
||||
|
||||
|
||||
def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
|
||||
"""Create a unit vector for similarity testing."""
|
||||
vec = [base_val] * dim
|
||||
magnitude = sum(x**2 for x in vec) ** 0.5
|
||||
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
||||
|
||||
|
||||
class MockSkill:
    """Stand-in skill exposing ``name`` and a ``config`` built from the arguments."""

    def __init__(
        self,
        name: str,
        description: str = "",
        keywords: list[str] | None = None,
        capabilities: list[str] | None = None,
    ):
        config_kwargs = {
            "name": name,
            "description": description,
            "keywords": keywords or [],
            "capabilities": capabilities or [],
        }
        self.name = name
        self.config = MockSkillConfig(**config_kwargs)
|
||||
|
||||
|
||||
class MockSkillConfig:
    """Stand-in skill config: name, description, intent keywords, capability tags."""

    def __init__(
        self,
        name: str,
        description: str = "",
        keywords: list[str] | None = None,
        capabilities: list[str] | None = None,
    ):
        tag_names = capabilities or []
        self.name = name
        self.description = description
        # Keywords live on a nested intent object; capabilities become tag objects.
        self.intent = MockIntentConfig(keywords=keywords or [])
        self.capabilities = [MockCapabilityTag(tag=tag_name) for tag_name in tag_names]
|
||||
|
||||
|
||||
class MockIntentConfig:
|
||||
def __init__(self, keywords: list[str] | None = None):
|
||||
self.keywords = keywords or []
|
||||
|
||||
|
||||
class MockCapabilityTag:
    """Stand-in capability tag exposing a single ``tag`` attribute."""

    def __init__(self, tag: str) -> None:
        self.tag = tag
|
||||
|
||||
|
||||
class MockSkillRegistry:
    """Stand-in skill registry that keeps skills in a dict keyed by name."""

    def __init__(self, skills: list[MockSkill] | None = None):
        self._skills = {}
        for skill in skills or []:
            self._skills[skill.name] = skill

    def list_skills(self):
        """Return all registered skills as a new list."""
        return [*self._skills.values()]

    def get(self, name: str):
        """Return the skill registered under ``name``.

        Raises:
            KeyError: if no skill with that name is registered.
        """
        if name in self._skills:
            return self._skills[name]
        raise KeyError(f"Skill '{name}' not found")
|
||||
|
||||
|
||||
class TestSkillEmbeddingIndex:
    """Tests for building, searching and editing a ``SkillEmbeddingIndex``."""

    @pytest.mark.asyncio
    async def test_build_from_registry(self):
        """build() indexes one entry per skill in the registry."""
        index = SkillEmbeddingIndex(MockEmbedder(dimension=64))
        registry = MockSkillRegistry(
            [
                MockSkill("content_gen", description="生成文章内容", keywords=["写作", "文章"], capabilities=["content"]),
                MockSkill("data_analysis", description="数据分析与可视化", keywords=["分析", "数据"], capabilities=["analytics"]),
            ]
        )

        await index.build(registry)

        assert index.size == 2

    @pytest.mark.asyncio
    async def test_search_returns_results(self):
        """search() returns the single indexed skill with a positive similarity."""
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)
        await index.update_skill("content_gen", MockSkill("content_gen", description="生成文章内容"))

        # MockEmbedder produces deterministic embeddings based on text hash;
        # different text gives a different embedding.
        query_vector = await embedder.embed("生成文章")
        matches = await index.search(query_vector)

        assert len(matches) >= 1
        best_name, best_similarity = matches[0][0], matches[0][1]
        assert best_name == "content_gen"  # skill_name
        assert best_similarity > 0.0  # similarity

    @pytest.mark.asyncio
    async def test_search_empty_index(self):
        """Searching an index with no skills returns an empty list."""
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)

        query_vector = await embedder.embed("test")
        matches = await index.search(query_vector)

        assert matches == []

    @pytest.mark.asyncio
    async def test_remove_skill(self):
        """remove_skill() drops a previously added skill from the index."""
        index = SkillEmbeddingIndex(MockEmbedder(dimension=64))

        await index.update_skill("test_skill", MockSkill("test_skill", description="Test"))
        assert index.size == 1

        index.remove_skill("test_skill")
        assert index.size == 0

    def test_build_source_text_with_description(self):
        """The source text contains the description, keywords and capability tags."""
        skill = MockSkill("test", description="A test skill", keywords=["test"], capabilities=["testing"])
        text = SkillEmbeddingIndex._build_source_text(skill)
        for fragment in ("A test skill", "test", "testing"):
            assert fragment in text

    def test_build_source_text_fallback_to_name(self):
        """With no description, keywords or capabilities the skill name is used."""
        skill = MockSkill("my_skill", description="", keywords=[], capabilities=[])
        text = SkillEmbeddingIndex._build_source_text(skill)
        assert "my_skill" in text
|
||||
|
||||
|
||||
class TestSemanticRouter:
    """Tests for ``SemanticRouter.route`` and its index management."""

    @pytest.mark.asyncio
    async def test_high_confidence_match(self):
        """Routing with lowered thresholds (0.5 / 0.3) yields a well-formed result."""
        router = SemanticRouter(MockEmbedder(dimension=64), similarity_high=0.5, similarity_low=0.3)

        # Register one skill, then query with the skill's own description text.
        await router.update_skill("content_gen", MockSkill("content_gen", description="生成文章内容"))
        outcome = await router.route("生成文章内容")

        # MockEmbedder is hash-based, so the exact similarity is not pinned
        # here; only the shape of the result is verified.
        assert outcome.confidence in ("high", "medium", "low")
        assert isinstance(outcome.similarity, float)

    @pytest.mark.asyncio
    async def test_low_confidence_empty_index(self):
        """Empty index returns low confidence, no skill and zero similarity."""
        router = SemanticRouter(MockEmbedder(dimension=64))

        outcome = await router.route("任何查询")

        assert outcome.confidence == "low"
        assert outcome.skill_name is None
        assert outcome.similarity == 0.0

    @pytest.mark.asyncio
    async def test_medium_confidence_zone(self):
        """Routing with a wide medium band (0.01 - 0.99) yields a valid label."""
        router = SemanticRouter(MockEmbedder(dimension=64), similarity_high=0.99, similarity_low=0.01)
        await router.update_skill("content_gen", MockSkill("content_gen", description="生成文章内容"))

        # With a very high similarity_high and a very low similarity_low most
        # matches are expected to land in "medium"; the assertion only checks
        # that the label is one of the three valid values.
        outcome = await router.route("生成文章")

        assert outcome.confidence in ("medium", "low", "high")

    @pytest.mark.asyncio
    async def test_embedder_failure_graceful(self):
        """Embedder failure returns low confidence."""

        class FailingEmbedder(MockEmbedder):
            async def embed(self, text):
                raise RuntimeError("Embedding API failed")

        router = SemanticRouter(FailingEmbedder(dimension=64))

        outcome = await router.route("test query")

        assert outcome.confidence == "low"
        assert outcome.skill_name is None

    @pytest.mark.asyncio
    async def test_build_index_from_registry(self):
        """Build index from skill registry."""
        router = SemanticRouter(MockEmbedder(dimension=64))
        registry = MockSkillRegistry(
            [
                MockSkill("skill_a", description="Skill A"),
                MockSkill("skill_b", description="Skill B"),
            ]
        )

        await router.build_index(registry)

        assert router._index.size == 2

    @pytest.mark.asyncio
    async def test_chinese_query(self):
        """Chinese query works with semantic router."""
        router = SemanticRouter(MockEmbedder(dimension=64), similarity_high=0.01, similarity_low=0.001)
        geo_skill = MockSkill(
            "geo_optimizer",
            description="地理内容优化",
            keywords=["优化", "SEO", "地理"],
            capabilities=["optimization"],
        )
        await router.update_skill("geo_optimizer", geo_skill)

        outcome = await router.route("帮我优化内容")

        # With very low thresholds, the only indexed skill should match.
        assert outcome.confidence in ("high", "medium")
        assert outcome.skill_name == "geo_optimizer"
|
||||
|
|
@ -0,0 +1,458 @@
|
|||
"""Tests for unified EvolutionStoreProtocol compliance
|
||||
|
||||
Verifies that all backends implement the full Protocol interface:
|
||||
- InMemoryEvolutionStore
|
||||
- PersistentEvolutionStore
|
||||
- PostgreSQLEvolutionStore (mocked async session)
|
||||
- EvolutionStore (legacy, with NotImplementedError for skill_version/ab_test)
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent
|
||||
from agentkit.evolution.evolution_store import (
|
||||
EvolutionStore,
|
||||
EvolutionStoreProtocol,
|
||||
InMemoryEvolutionStore,
|
||||
PersistentEvolutionStore,
|
||||
create_evolution_store,
|
||||
)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def sample_event():
    """Provide a prompt-change EvolutionEvent for ``test_agent``."""
    event_fields = {
        "agent_name": "test_agent",
        "change_type": "prompt",
        "before": {"prompt": "old prompt"},
        "after": {"prompt": "new prompt"},
        "metrics": {"accuracy": 0.9},
    }
    return EvolutionEvent(**event_fields)
|
||||
|
||||
|
||||
@pytest.fixture
def memory_store():
    """Provide a fresh InMemoryEvolutionStore."""
    store = InMemoryEvolutionStore()
    return store
|
||||
|
||||
|
||||
@pytest.fixture
def sqlite_store(tmp_path):
    """Provide a PersistentEvolutionStore backed by a database file under ``tmp_path``."""
    database_file = tmp_path / "test_unified.db"
    return PersistentEvolutionStore(db_path=str(database_file))
|
||||
|
||||
|
||||
# ── Protocol compliance tests ─────────────────────────────
|
||||
|
||||
|
||||
class TestProtocolCompliance:
    """Verify all stores implement EvolutionStoreProtocol."""

    def test_inmemory_is_protocol(self):
        """InMemoryEvolutionStore passes the isinstance check against the Protocol."""
        assert isinstance(InMemoryEvolutionStore(), EvolutionStoreProtocol)

    def test_persistent_is_protocol(self, tmp_path):
        """PersistentEvolutionStore passes the isinstance check against the Protocol."""
        db_path = str(tmp_path / "protocol_check.db")
        assert isinstance(PersistentEvolutionStore(db_path=db_path), EvolutionStoreProtocol)

    def test_pg_store_is_protocol(self):
        """PostgreSQLEvolutionStore passes the isinstance check against the Protocol."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test")
        assert isinstance(store, EvolutionStoreProtocol)

    def test_legacy_evolution_store_is_protocol(self):
        """Legacy EvolutionStore also satisfies Protocol (has all method signatures)."""
        # Only MagicMock is needed here; the previously imported AsyncMock was unused.
        from unittest.mock import MagicMock

        store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())
        assert isinstance(store, EvolutionStoreProtocol)
|
||||
|
||||
|
||||
# ── InMemoryEvolutionStore: full Protocol ─────────────────
|
||||
|
||||
|
||||
# Class-level marker so every coroutine test below is run by pytest-asyncio,
# whether the plugin is in strict or auto mode.
@pytest.mark.asyncio
class TestInMemoryFullProtocol:
    """InMemoryEvolutionStore implements all Protocol methods."""

    async def test_record_and_list_events(self, memory_store, sample_event):
        """A recorded event gets an id and is returned by list_events()."""
        event_id = await memory_store.record(sample_event)
        assert event_id is not None

        events = await memory_store.list_events()
        assert len(events) == 1
        assert events[0]["agent_name"] == "test_agent"
        assert events[0]["change_type"] == "prompt"

    async def test_rollback(self, memory_store, sample_event):
        """rollback() returns True and marks the event as rolled back."""
        event_id = await memory_store.record(sample_event)
        result = await memory_store.rollback(event_id)
        assert result is True

        events = await memory_store.list_events()
        assert events[0]["status"] == "rolled_back"

    async def test_rollback_nonexistent(self, memory_store):
        """rollback() returns False for an unknown event id."""
        result = await memory_store.rollback("nonexistent")
        assert result is False

    async def test_list_events_with_filters(self, memory_store):
        """list_events(agent_name=...) returns only that agent's events."""
        await memory_store.record(
            EvolutionEvent(agent_name="a", change_type="prompt", before={}, after={})
        )
        await memory_store.record(
            EvolutionEvent(agent_name="b", change_type="strategy", before={}, after={})
        )

        events = await memory_store.list_events(agent_name="a")
        assert len(events) == 1
        assert events[0]["agent_name"] == "a"

    async def test_record_and_list_skill_version(self, memory_store):
        """A recorded skill version is returned with its version and content."""
        vid = await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        assert vid is not None

        versions = await memory_store.list_skill_versions("search")
        assert len(versions) == 1
        assert versions[0]["version"] == "v1"
        assert versions[0]["content"] == '{"prompt": "v1"}'

    async def test_skill_version_with_parent(self, memory_store):
        """The most recently recorded version is listed first with its parent."""
        await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await memory_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        versions = await memory_store.list_skill_versions("search")
        assert len(versions) == 2
        assert versions[0]["version"] == "v2"
        assert versions[0]["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, memory_store):
        """A recorded A/B result is returned with variant, score and sample count."""
        rid = await memory_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert rid is not None

        results = await memory_store.get_ab_test_results("t1")
        assert len(results) == 1
        assert results[0]["variant"] == "control"
        assert results[0]["score"] == 0.8
        assert results[0]["sample_count"] == 5

    async def test_ab_test_multiple_variants(self, memory_store):
        """Results for two variants of the same test are both returned."""
        await memory_store.record_ab_test_result("t1", "control", 0.8, 10)
        await memory_store.record_ab_test_result("t1", "experiment", 0.9, 10)

        results = await memory_store.get_ab_test_results("t1")
        assert len(results) == 2

    async def test_list_skill_versions_empty(self, memory_store):
        """An unknown skill has no versions."""
        versions = await memory_store.list_skill_versions("nonexistent")
        assert versions == []

    async def test_get_ab_test_results_empty(self, memory_store):
        """An unknown test id has no results."""
        results = await memory_store.get_ab_test_results("nonexistent")
        assert results == []
|
||||
|
||||
|
||||
# ── PersistentEvolutionStore: full Protocol ───────────────
|
||||
|
||||
|
||||
# Class-level marker so every coroutine test below is run by pytest-asyncio,
# whether the plugin is in strict or auto mode.
@pytest.mark.asyncio
class TestSQLiteFullProtocol:
    """PersistentEvolutionStore implements all Protocol methods."""

    async def test_record_and_list_events(self, sqlite_store, sample_event):
        """A recorded event gets an id and is returned by list_events()."""
        event_id = await sqlite_store.record(sample_event)
        assert event_id is not None

        events = await sqlite_store.list_events()
        assert len(events) == 1
        assert events[0]["agent_name"] == "test_agent"

    async def test_rollback(self, sqlite_store, sample_event):
        """rollback() returns True and marks the event as rolled back."""
        event_id = await sqlite_store.record(sample_event)
        result = await sqlite_store.rollback(event_id)
        assert result is True

        events = await sqlite_store.list_events()
        assert events[0]["status"] == "rolled_back"

    async def test_record_and_list_skill_version(self, sqlite_store):
        """A recorded skill version is returned by list_skill_versions()."""
        vid = await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        assert vid is not None

        versions = await sqlite_store.list_skill_versions("search")
        assert len(versions) == 1
        assert versions[0]["version"] == "v1"

    async def test_skill_version_with_parent(self, sqlite_store):
        """The most recently recorded version is listed first with its parent."""
        await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await sqlite_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        versions = await sqlite_store.list_skill_versions("search")
        assert len(versions) == 2
        assert versions[0]["version"] == "v2"
        assert versions[0]["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, sqlite_store):
        """A recorded A/B result is returned with its variant."""
        rid = await sqlite_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert rid is not None

        results = await sqlite_store.get_ab_test_results("t1")
        assert len(results) == 1
        assert results[0]["variant"] == "control"

    async def test_ab_test_multiple_variants(self, sqlite_store):
        """Results for two variants of the same test are both returned."""
        await sqlite_store.record_ab_test_result("t1", "control", 0.8, 10)
        await sqlite_store.record_ab_test_result("t1", "experiment", 0.9, 10)

        results = await sqlite_store.get_ab_test_results("t1")
        assert len(results) == 2

    async def test_list_skill_versions_empty(self, sqlite_store):
        """An unknown skill has no versions."""
        versions = await sqlite_store.list_skill_versions("nonexistent")
        assert versions == []

    async def test_get_ab_test_results_empty(self, sqlite_store):
        """An unknown test id has no results."""
        results = await sqlite_store.get_ab_test_results("nonexistent")
        assert results == []
|
||||
|
||||
|
||||
# ── PostgreSQLEvolutionStore: mocked Protocol ─────────────
|
||||
|
||||
|
||||
# Class-level marker so every coroutine test below is run by pytest-asyncio,
# whether the plugin is in strict or auto mode.
@pytest.mark.asyncio
class TestPGStoreMocked:
    """Test PostgreSQLEvolutionStore with mocked async session.

    Since we can't require a running PostgreSQL in unit tests,
    we mock the async session to verify the logic paths.
    """

    def _make_pg_store(self):
        """Create a store pointed at a placeholder URL; no connection is opened here."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        return PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test")

    @staticmethod
    def _install_mock_session(store, execute_result=None):
        """Replace the store's session factory with one yielding a mock session.

        Args:
            store: the store whose ``_session_factory`` and ``_initialized``
                attributes are overwritten.
            execute_result: when given, the value ``session.execute`` resolves to.

        Returns:
            The mock session, so tests can assert on its calls.
        """
        from contextlib import asynccontextmanager
        from unittest.mock import AsyncMock, MagicMock

        session = AsyncMock()
        # ``add`` is called synchronously, so it must not be an AsyncMock.
        session.add = MagicMock()
        session.commit = AsyncMock()
        session.rollback = AsyncMock()
        if execute_result is not None:
            session.execute = AsyncMock(return_value=execute_result)

        @asynccontextmanager
        async def session_factory():
            yield session

        store._session_factory = session_factory
        # Mark the store as initialized so the test does not run real setup.
        store._initialized = True
        return session

    async def test_record_with_mock_session(self, sample_event):
        """record() adds one row, commits once and stamps the event with its id."""
        store = self._make_pg_store()
        session = self._install_mock_session(store)

        event_id = await store.record(sample_event)

        assert event_id is not None
        assert sample_event.event_id == event_id
        session.add.assert_called_once()
        session.commit.assert_called_once()

    async def test_rollback_with_mock_session(self):
        """rollback() flips the found entry to ``rolled_back`` and commits."""
        from unittest.mock import MagicMock

        store = self._make_pg_store()

        entry = MagicMock()
        entry.status = "active"
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = entry
        session = self._install_mock_session(store, execute_result=query_result)

        result = await store.rollback("test-event-id")

        assert result is True
        assert entry.status == "rolled_back"
        session.commit.assert_called_once()

    async def test_rollback_not_found(self):
        """rollback() returns False when the query finds no entry."""
        from unittest.mock import MagicMock

        store = self._make_pg_store()

        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = None
        self._install_mock_session(store, execute_result=query_result)

        result = await store.rollback("nonexistent")

        assert result is False

    async def test_record_skill_version_with_mock(self):
        """record_skill_version() adds one row and commits once."""
        store = self._make_pg_store()
        session = self._install_mock_session(store)

        vid = await store.record_skill_version("search", "v1", '{"prompt": "v1"}')

        assert vid is not None
        session.add.assert_called_once()
        session.commit.assert_called_once()

    async def test_record_ab_test_result_with_mock(self):
        """record_ab_test_result() adds one row and commits once."""
        store = self._make_pg_store()
        session = self._install_mock_session(store)

        rid = await store.record_ab_test_result("t1", "control", 0.8, 5)

        assert rid is not None
        session.add.assert_called_once()
        session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ── Legacy EvolutionStore: NotImplementedError tests ──────
|
||||
|
||||
|
||||
# Class-level marker so every coroutine test below is run by pytest-asyncio,
# whether the plugin is in strict or auto mode.
@pytest.mark.asyncio
class TestLegacyEvolutionStoreStubs:
    """Legacy EvolutionStore raises NotImplementedError for skill_version/ab_test."""

    @staticmethod
    def _make_legacy_store():
        """Create a legacy EvolutionStore wired to MagicMock collaborators."""
        from unittest.mock import MagicMock

        return EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())

    async def test_record_skill_version_raises(self):
        """record_skill_version() is not implemented on the legacy store."""
        store = self._make_legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await store.record_skill_version("s", "v1", "content")

    async def test_list_skill_versions_raises(self):
        """list_skill_versions() is not implemented on the legacy store."""
        store = self._make_legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await store.list_skill_versions("s")

    async def test_record_ab_test_result_raises(self):
        """record_ab_test_result() is not implemented on the legacy store."""
        store = self._make_legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await store.record_ab_test_result("t1", "control", 0.8)

    async def test_get_ab_test_results_raises(self):
        """get_ab_test_results() is not implemented on the legacy store."""
        store = self._make_legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await store.get_ab_test_results("t1")
|
||||
|
||||
|
||||
# ── Factory tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCreateEvolutionStoreExtended:
    """Backend-selection tests for the ``create_evolution_store`` factory."""

    def test_create_memory_backend(self):
        """backend="memory" yields the in-memory store."""
        store = create_evolution_store(backend="memory")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sqlite_backend(self, tmp_path):
        """backend="sqlite" with a db_path yields the persistent store."""
        db_path = str(tmp_path / "factory_test.db")
        store = create_evolution_store(backend="sqlite", db_path=db_path)
        assert isinstance(store, PersistentEvolutionStore)

    def test_create_default_backend(self):
        """Calling the factory with no arguments yields the in-memory store."""
        store = create_evolution_store()
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sql_backend_without_params_falls_back(self):
        """backend="sql" without further parameters falls back to memory."""
        store = create_evolution_store(backend="sql")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_without_url_falls_back(self, monkeypatch):
        """PostgreSQL backend without database_url falls back to memory."""
        # monkeypatch restores the variable after the test, even on failure.
        monkeypatch.delenv("AGENTKIT_DATABASE_URL", raising=False)

        store = create_evolution_store(backend="postgresql")

        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_with_url(self):
        """PostgreSQL backend with database_url returns PostgreSQLEvolutionStore."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = create_evolution_store(
            backend="postgresql",
            database_url="postgresql+asyncpg://user:pass@localhost/db",
        )
        assert isinstance(store, PostgreSQLEvolutionStore)

    def test_create_postgresql_with_env_url(self, monkeypatch):
        """PostgreSQL backend reads database_url from environment variable."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        # monkeypatch restores (or removes) the variable after the test.
        monkeypatch.setenv("AGENTKIT_DATABASE_URL", "postgresql+asyncpg://user:pass@localhost/db")

        store = create_evolution_store(backend="postgresql")

        assert isinstance(store, PostgreSQLEvolutionStore)
|
||||
|
|
@ -5,7 +5,8 @@ from datetime import datetime, timedelta, timezone
|
|||
import pytest
|
||||
|
||||
from agentkit.llm.protocol import TokenUsage
|
||||
from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||
from agentkit.llm.providers.usage_store import UsageRecord
|
||||
|
||||
|
||||
class TestUsageTrackerRecord:
|
||||
|
|
@ -23,8 +24,10 @@ class TestUsageTrackerRecord:
|
|||
latency_ms=200.0,
|
||||
)
|
||||
|
||||
assert len(tracker._records) == 1
|
||||
rec = tracker._records[0]
|
||||
# Verify via get_usage() instead of internal _records
|
||||
summary = tracker.get_usage()
|
||||
assert len(summary.records) == 1
|
||||
rec = summary.records[0]
|
||||
assert rec.agent_name == "test_agent"
|
||||
assert rec.model == "gpt-4o"
|
||||
assert rec.prompt_tokens == 100
|
||||
|
|
@ -41,7 +44,8 @@ class TestUsageTrackerRecord:
|
|||
tracker.record("agent_a", "gpt-4o", usage1, 0.001, 100.0)
|
||||
tracker.record("agent_b", "deepseek-chat", usage2, 0.002, 150.0)
|
||||
|
||||
assert len(tracker._records) == 2
|
||||
summary = tracker.get_usage()
|
||||
assert len(summary.records) == 2
|
||||
|
||||
|
||||
class TestUsageTrackerGetUsage:
|
||||
|
|
@ -80,10 +84,11 @@ class TestUsageTrackerGetUsage:
|
|||
usage2 = TokenUsage(prompt_tokens=200, completion_tokens=100)
|
||||
|
||||
tracker.record("agent_a", "gpt-4o", usage1, 0.005, 100.0)
|
||||
|
||||
# Manually set timestamp of second record to 2 hours ago
|
||||
tracker.record("agent_a", "gpt-4o", usage2, 0.010, 200.0)
|
||||
tracker._records[-1].timestamp = now - timedelta(hours=2)
|
||||
|
||||
# Manually set timestamp of second record to 2 hours ago via store
|
||||
store = tracker._store
|
||||
store._records[-1].timestamp = (now - timedelta(hours=2)).isoformat()
|
||||
|
||||
# Query last hour only
|
||||
summary = tracker.get_usage(start_time=now - timedelta(hours=1), end_time=now + timedelta(hours=1))
|
||||
|
|
|
|||
Loading…
Reference in New Issue