From 0ccef7be5cf0501a68d200d9e1162d9b05e94706 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 14 Jun 2026 15:16:00 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20P0=20production=20hardening=20=E2=80=94?= =?UTF-8?q?=20LLM=20cache,=20semantic=20routing,=20state=20persistence?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit U1: LLM Cache Core (exact + semantic match, InMemory + Redis backends) U2: Cache integration into LLMGateway with CacheConfig U3: Semantic Router as Layer 1.5 in CostAwareRouter U4: UsageStore persistence (Redis Hash + InMemory fallback) U5: CascadeStateStore persistence (Redis INCR + InMemory TTL) U6: EvolutionStore interface unification (Protocol + PostgreSQL backend) U7: Configuration integration + E2E tests Code review fixes: - P0: date iteration bug (day>=28), semantic router index never built, Redis connection leak (per-call → persistent pool) - P1: cache degradation recovery, semantic_search degradation, double miss counting, asyncio.Lock for PG init, LIMIT on queries, __import__ anti-pattern → _utcnow() - P2: InMemory TTL cleanup, embedding preservation on put(), data TTL = max(exact_ttl, semantic_ttl) --- .gitignore | 1 + agentkit.yaml | 21 +- src/agentkit/chat/semantic_router.py | 207 +++++++ src/agentkit/chat/skill_routing.py | 131 +++- src/agentkit/core/config_driven.py | 8 +- src/agentkit/core/react.py | 25 +- src/agentkit/evolution/__init__.py | 2 + src/agentkit/evolution/evolution_store.py | 91 ++- src/agentkit/evolution/lifecycle.py | 10 +- src/agentkit/evolution/pg_store.py | 329 ++++++++++ src/agentkit/llm/__init__.py | 3 +- src/agentkit/llm/cache.py | 632 ++++++++++++++++++++ src/agentkit/llm/cache_key.py | 66 ++ src/agentkit/llm/config.py | 44 ++ src/agentkit/llm/gateway.py | 120 +++- src/agentkit/llm/providers/__init__.py | 3 +- src/agentkit/llm/providers/tracker.py | 86 +-- src/agentkit/llm/providers/usage_store.py | 373 ++++++++++++ src/agentkit/memory/profile.py | 107 +++- src/agentkit/quality/cascade_detector.py | 31 +- src/agentkit/quality/cascade_state_store.py | 245 ++++++++ src/agentkit/server/app.py | 108 +++- src/agentkit/server/config.py | 27 + tests/integration/test_p0_hardening.py | 422 +++++++++++++ tests/unit/test_gateway_cache.py | 162 +++++ tests/unit/test_llm_cache.py | 604 +++++++++++++++++++ tests/unit/test_semantic_router.py | 219 +++++++ tests/unit/test_unified_evolution_store.py | 458 ++++++++++++++ tests/unit/test_usage_tracker.py | 19 +- 29 files changed, 4403 insertions(+), 151 deletions(-) create mode 100644 src/agentkit/chat/semantic_router.py create mode 100644 src/agentkit/evolution/pg_store.py create mode 100644 src/agentkit/llm/cache.py create mode 100644 src/agentkit/llm/cache_key.py create mode 100644 src/agentkit/llm/providers/usage_store.py create mode 100644 src/agentkit/quality/cascade_state_store.py create mode 100644 tests/integration/test_p0_hardening.py create mode 100644 tests/unit/test_gateway_cache.py create mode 100644 tests/unit/test_llm_cache.py create mode 100644 tests/unit/test_semantic_router.py create mode 100644 tests/unit/test_unified_evolution_store.py diff --git a/.gitignore b/.gitignore index fc79f89..a1745fb 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ build/pyinstaller-work/ # Frontend build artifacts src/agentkit/server/static/ +**/node_modules/ # Env .env diff --git a/agentkit.yaml b/agentkit.yaml index 46b1ebb..f5b3c31 100644 --- a/agentkit.yaml +++ b/agentkit.yaml @@ -5,23 +5,12 @@ server: rate_limit: 60 llm: providers: - bailian-coding: - api_key: ${DASHSCOPE_API_KEY} - base_url: https://coding.dashscope.aliyuncs.com/v1 + test: type: openai - models: - qwen3.7-plus: - alias: default - qwen3.6-plus: {} - qwen3.5-plus: {} - qwen3-max-2026-01-23: {} - qwen3-coder-plus: - alias: coder - qwen3-coder-next: {} - kimi-k2.5: {} - glm-5: {} - glm-4.7: {} - MiniMax-M2.5: {} + base_url: '' + max_tokens: 4096 + timeout: 120.0 + api_key: '' model_aliases: default: bailian-coding/qwen3.7-plus coder: bailian-coding/qwen3-coder-plus diff --git a/src/agentkit/chat/semantic_router.py b/src/agentkit/chat/semantic_router.py new file mode 100644 index 0000000..c9180ae --- /dev/null +++ b/src/agentkit/chat/semantic_router.py @@ -0,0 +1,207 @@ +"""Semantic Router — Embedding-based intent routing as Layer 1.5. + +Uses pre-computed skill embeddings for zero-cost semantic matching, +inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification) +in CostAwareRouter. + +Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md +""" + +import logging +from dataclasses import dataclass +from typing import Any + +from agentkit.memory.embedder import Embedder, EmbeddingCache +from agentkit.utils.vector_math import compute_cosine_similarity + +logger = logging.getLogger(__name__) + + +@dataclass +class SemanticRouteResult: + """Result of semantic routing.""" + + confidence: str # "high" | "medium" | "low" + skill_name: str | None + similarity: float + + +class SkillEmbeddingIndex: + """Pre-computed embedding index for registered skills. + + Embeddings are computed at skill registration time and cached. + Query-time search is O(n) cosine similarity scan, which is fast + for <100 skills with 1024-1536 dim vectors. + """ + + def __init__(self, embedder: Embedder): + self._embedder = embedder + # skill_name → (embedding, source_text) + self._index: dict[str, tuple[list[float], str]] = {} + + async def build(self, skill_registry: Any) -> None: + """Build index from all registered skills.""" + if skill_registry is None: + return + skills = skill_registry.list_skills() + for skill in skills: + await self.update_skill(skill.config.name, skill) + + async def update_skill(self, skill_name: str, skill: Any) -> None: + """Re-embed a single skill (on registration/update).""" + source_text = self._build_source_text(skill) + try: + embedding = await self._embedder.embed(source_text) + self._index[skill_name] = (embedding, source_text) + except Exception as e: + logger.warning(f"Failed to embed skill '{skill_name}': {e}") + + def remove_skill(self, skill_name: str) -> None: + """Remove a skill from the index.""" + self._index.pop(skill_name, None) + + async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]: + """Search for skills matching the query embedding. + + Returns: + List of (skill_name, similarity) sorted by similarity descending. + """ + if not self._index: + return [] + + results: list[tuple[str, float]] = [] + for skill_name, (emb, _) in self._index.items(): + sim = compute_cosine_similarity(query_embedding, emb) + results.append((skill_name, sim)) + + results.sort(key=lambda x: x[1], reverse=True) + return results[:top_k] + + @staticmethod + def _build_source_text(skill: Any) -> str: + """Build embedding source text from skill metadata. + + Combines description, intent keywords, and capability tags + for rich semantic representation. + """ + config = skill.config if hasattr(skill, "config") else skill + parts = [] + + # Description + description = getattr(config, "description", "") or "" + if description: + parts.append(description) + + # Intent keywords + intent = getattr(config, "intent", None) + if intent and hasattr(intent, "keywords") and intent.keywords: + parts.append(" ".join(intent.keywords)) + + # Capability tags + capabilities = getattr(config, "capabilities", None) + if capabilities: + tags = [] + for cap in capabilities: + if isinstance(cap, str): + tags.append(cap) + elif isinstance(cap, dict): + tags.append(cap.get("tag", "")) + elif hasattr(cap, "tag"): + tags.append(cap.tag) + if tags: + parts.append(" ".join(t for t in tags if t)) + + # Fallback: use skill name if no other text available + if not parts: + parts.append(getattr(config, "name", "unknown")) + + return " | ".join(parts) + + @property + def size(self) -> int: + """Number of skills in the index.""" + return len(self._index) + + +class SemanticRouter: + """Embedding-based semantic routing as Layer 1.5. + + Three confidence zones: + - similarity > similarity_high (0.85): HIGH → direct skill match, skip Layer 2 + - similarity_low (0.6) <= similarity <= similarity_high: MEDIUM → skill hint for Layer 2 + - similarity < similarity_low (0.6): LOW → no semantic signal, normal routing + """ + + def __init__( + self, + embedder: Embedder, + similarity_high: float = 0.85, + similarity_low: float = 0.6, + ): + self._embedder = embedder + self._similarity_high = similarity_high + self._similarity_low = similarity_low + self._index = SkillEmbeddingIndex(embedder) + self._query_cache = EmbeddingCache(max_size=500, ttl=1800) + + async def build_index(self, skill_registry: Any) -> None: + """Build skill embedding index from registry.""" + await self._index.build(skill_registry) + logger.info(f"Semantic router index built: {self._index.size} skills") + + async def update_skill(self, skill_name: str, skill: Any) -> None: + """Update a single skill's embedding.""" + await self._index.update_skill(skill_name, skill) + + def remove_skill(self, skill_name: str) -> None: + """Remove a skill from the index.""" + self._index.remove_skill(skill_name) + + async def route(self, query: str) -> SemanticRouteResult: + """Route a query using semantic similarity. + + Args: + query: User's input text. + + Returns: + SemanticRouteResult with confidence, skill_name, and similarity. + """ + if self._index.size == 0: + return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) + + try: + # Get query embedding (with cache) + query_embedding = self._query_cache.get(query) + if query_embedding is None: + query_embedding = await self._embedder.embed(query) + self._query_cache.put(query, query_embedding) + + # Search skill index + results = await self._index.search(query_embedding, top_k=1) + if not results: + return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) + + best_skill, best_sim = results[0] + + if best_sim >= self._similarity_high: + return SemanticRouteResult( + confidence="high", + skill_name=best_skill, + similarity=best_sim, + ) + elif best_sim >= self._similarity_low: + return SemanticRouteResult( + confidence="medium", + skill_name=best_skill, + similarity=best_sim, + ) + else: + return SemanticRouteResult( + confidence="low", + skill_name=None, + similarity=best_sim, + ) + + except Exception as e: + logger.warning(f"Semantic routing failed, returning low confidence: {e}") + return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index 74188dd..c9b5935 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -6,6 +6,7 @@ and prompt assembly into a single module used by both chat routes. from __future__ import annotations +import enum import json import logging import re @@ -21,6 +22,19 @@ logger = logging.getLogger(__name__) _SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$") +class ExecutionMode(enum.Enum): + """How the downstream should execute this routing result. + + This is the single source of truth for execution path selection. + The transport layer (portal.py, chat.py) should branch on this + field instead of string-matching match_method. + """ + + DIRECT_CHAT = "direct_chat" # Zero-cost: direct LLM call, no ReAct loop + REACT = "react" # Default agent ReAct loop with default tools + SKILL_REACT = "skill_react" # Skill-matched ReAct with skill tools + prompt + + def validate_skill_name(name: str) -> str: """Validate and normalize a skill name. Raises ValueError on invalid input.""" normalized = name.strip().lower() @@ -49,6 +63,7 @@ class SkillRoutingResult: transparency_level: str = "SILENT" execution_trace: list[dict] = field(default_factory=list) complexity: float = 0.0 + execution_mode: ExecutionMode = ExecutionMode.DIRECT_CHAT def parse_skill_prefix(content: str) -> tuple[str | None, str]: @@ -88,6 +103,7 @@ async def resolve_skill_routing( default_agent_name: str = "default", agent_tool_registry: Any = None, session_id: str = "", + force_skill: str | None = None, ) -> SkillRoutingResult: """Resolve skill routing for a user message. @@ -120,6 +136,20 @@ async def resolve_skill_routing( result.skill_name = None result.skill_config = None + # Try force_skill match (from semantic router high confidence) + if not result.matched and force_skill and skill_registry: + try: + matched_skill = skill_registry.get(force_skill) + result.skill_name = force_skill + result.skill_config = matched_skill.config + result.skill_tools = matched_skill.tools or [] + result.matched = True + result.match_method = "semantic_force" + result.match_confidence = 1.0 + logger.info(f"Session {session_id}: using force-matched skill '{force_skill}'") + except Exception as e: + logger.warning(f"Session {session_id}: force skill '{force_skill}' not found: {e}") + # Try IntentRouter if no explicit match if not result.matched and skill_registry and intent_router: skills = skill_registry.list_skills() @@ -205,11 +235,14 @@ async def resolve_skill_routing( result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model result.agent_name = result.skill_name + result.execution_mode = ExecutionMode.SKILL_REACT else: result.system_prompt = default_system_prompt result.tools = default_tools result.model = default_model result.agent_name = default_agent_name + # No skill matched — if we have tools, use ReAct; otherwise direct chat + result.execution_mode = ExecutionMode.REACT if default_tools else ExecutionMode.DIRECT_CHAT # Append available tools to system prompt so LLM knows what it can call if result.tools: @@ -257,6 +290,14 @@ _CHAT_MODE_RE = re.compile( re.IGNORECASE, ) +# Simple identity/meta questions — zero-cost direct chat, no skill routing needed +_IDENTITY_RE = re.compile( + r"^(你是谁|你叫什么|你是什么|你是哪个|who are you|what are you|what's your name" + r"|介绍一下你自己|自我介绍|你叫啥|你叫什么名字|你的名字)" + r"\s*[??!!.。]*$", + re.IGNORECASE, +) + _SENTENCE_SPLIT_RE = re.compile(r'[,。!?;\n,.!?;]') @@ -319,8 +360,9 @@ class HeuristicClassifier: } # 中等复杂度暗示词(简单问题但需思考) + # 注意:不包含"怎么",因为"怎么样"是闲聊而非工具需求 _MEDIUM_COMPLEXITY_HINTS_CN = { - "如何", "怎么", "怎样", "为什么", "什么原因", "区别", + "如何", "怎样", "为什么", "什么原因", "区别", "推荐", "建议", "选择", "哪个", } @@ -428,6 +470,7 @@ class CostAwareRouter: auction_enabled: bool = False, classifier: str = "heuristic", merged_llm_classify: bool = True, + semantic_router: Any = None, # SemanticRouter | None ): self._llm_gateway = llm_gateway self._model = model @@ -435,6 +478,7 @@ class CostAwareRouter: self._auction_enabled = auction_enabled self._classifier = classifier self._merged_llm_classify = merged_llm_classify + self._semantic_router = semantic_router self._auction_house = AuctionHouse() if auction_enabled else None if classifier not in ("heuristic", "llm"): raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'") @@ -462,6 +506,10 @@ class CostAwareRouter: if _CHAT_MODE_RE.match(stripped): return "chat_mode", stripped + # 身份/元问题模式("你是谁"等)— 零成本直接对话 + if _IDENTITY_RE.match(stripped): + return "identity", stripped + return None, stripped # -- Layer 1: LLM quick classify (~100 tokens) ------------------------- @@ -577,6 +625,7 @@ class CostAwareRouter: match_method="merged_llm", match_confidence=0.7, complexity=merged_complexity, + execution_mode=ExecutionMode.SKILL_REACT, ) # Merge tools agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools @@ -590,6 +639,18 @@ class CostAwareRouter: result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model result.agent_name = skill_hint result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt + # Append available tools to system prompt so LLM knows what it can call + if result.tools: + tools_desc = _build_tools_description(result.tools) + tool_instruction = ( + "\n\n## Tool Usage\n" + "You have access to the following tools. When you need to use a tool, " + "respond with a tool call in the format specified by the system.\n" + "Never make up information or guess answers when you can use a tool to find the answer.\n" + "Always prefer using tools over guessing.\n" + ) + if result.system_prompt: + result.system_prompt += f"{tool_instruction}\n## Available Tools\n{tools_desc}" logger.info( f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' " f"(complexity={merged_complexity:.2f})" @@ -610,6 +671,7 @@ class CostAwareRouter: match_method="merged_llm_low", match_confidence=1.0 - merged_complexity, complexity=merged_complexity, + execution_mode=ExecutionMode.DIRECT_CHAT, ) elif merged_complexity > 0.7: # High complexity — delegate to Layer 2 @@ -623,6 +685,7 @@ class CostAwareRouter: match_method="merged_llm_high", match_confidence=merged_complexity, complexity=merged_complexity, + execution_mode=ExecutionMode.REACT, ) else: # Medium complexity, no skill match — default agent @@ -636,6 +699,7 @@ class CostAwareRouter: match_method="merged_llm_medium", match_confidence=0.5, complexity=merged_complexity, + execution_mode=ExecutionMode.REACT, ) except (json.JSONDecodeError, TypeError, ValueError) as e: logger.warning(f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default") @@ -649,6 +713,7 @@ class CostAwareRouter: match_method="merged_llm_fallback", match_confidence=0.5, complexity=0.5, + execution_mode=ExecutionMode.REACT, ) except Exception as e: logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default") @@ -662,6 +727,7 @@ class CostAwareRouter: match_method="merged_llm_fallback", match_confidence=0.5, complexity=0.5, + execution_mode=ExecutionMode.REACT, ) # -- Layer 2: Capability matching / Auction (optional) ----------------- @@ -746,6 +812,7 @@ class CostAwareRouter: system_prompt=default_system_prompt, tools=default_tools, complexity=complexity, + execution_mode=ExecutionMode.REACT, ) if trace is not None: trace.append({ @@ -776,6 +843,7 @@ class CostAwareRouter: system_prompt=default_system_prompt, tools=default_tools, complexity=complexity, + execution_mode=ExecutionMode.REACT, ) if trace is not None: trace.append({ @@ -876,7 +944,7 @@ class CostAwareRouter: span.set_attribute("route.target", result.skill_name or "default") return result - if match_type in ("greeting", "chat_mode"): + if match_type in ("greeting", "chat_mode", "identity"): result = SkillRoutingResult( clean_content=clean_content, system_prompt=default_system_prompt, @@ -887,6 +955,7 @@ class CostAwareRouter: match_method=match_type, match_confidence=1.0, complexity=0.0, + execution_mode=ExecutionMode.DIRECT_CHAT, ) trace.append({ "layer": 0, @@ -916,7 +985,7 @@ class CostAwareRouter: "complexity": complexity, }) - # Low complexity → default agent + # Low complexity → direct chat if complexity < 0.3: result = SkillRoutingResult( clean_content=clean_content, @@ -928,6 +997,7 @@ class CostAwareRouter: match_method="low_complexity", match_confidence=1.0 - complexity, complexity=complexity, + execution_mode=ExecutionMode.DIRECT_CHAT, ) trace.append({ "layer": 1, @@ -941,6 +1011,59 @@ class CostAwareRouter: span.set_attribute("route.target", "default") return result + # ---- Layer 1.5: Semantic Router (zero LLM cost) ---- + skill_hint = None + if self._semantic_router is not None and complexity >= 0.3: + try: + semantic_result = await self._semantic_router.route(clean_content) + if semantic_result.confidence == "high" and semantic_result.skill_name: + # Direct skill match — skip Layer 2 + trace.append({ + "layer": 1.5, + "method": "semantic_high", + "skill": semantic_result.skill_name, + "similarity": round(semantic_result.similarity, 3), + "cost": "zero", + }) + result = await resolve_skill_routing( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + force_skill=semantic_result.skill_name, + ) + result.match_method = "semantic_high" + result.match_confidence = semantic_result.similarity + result.complexity = complexity + if result.matched: + result.execution_mode = ExecutionMode.SKILL_REACT + result.execution_trace = trace if transparency != "SILENT" else [] + result.transparency_level = transparency + span.set_attribute("route.layer", "semantic_high") + span.set_attribute("route.target", result.skill_name or "default") + return result + elif semantic_result.confidence == "medium" and semantic_result.skill_name: + # Pass skill hint to Layer 1.5 merged classify or Layer 2 + skill_hint = semantic_result.skill_name + trace.append({ + "layer": 1.5, + "method": "semantic_medium", + "skill_hint": skill_hint, + "similarity": round(semantic_result.similarity, 3), + }) + except Exception as e: + logger.warning(f"Semantic routing failed, falling through: {e}") + trace.append({ + "layer": 1.5, + "method": "semantic_error", + "error": str(e), + }) + # Medium complexity → merged LLM classify or IntentRouter if complexity <= 0.7: if self._merged_llm_classify and self._llm_gateway is not None: @@ -994,7 +1117,7 @@ class CostAwareRouter: agent_tool_registry=agent_tool_registry, session_id=session_id, ) - result.complexity = result.complexity or complexity + result.complexity = result.complexity if result.complexity > 0 else complexity trace.append({ "layer": 1, "method": result.match_method or "merged_llm", diff --git a/src/agentkit/core/config_driven.py b/src/agentkit/core/config_driven.py index fe2da95..abcd26e 100644 --- a/src/agentkit/core/config_driven.py +++ b/src/agentkit/core/config_driven.py @@ -685,7 +685,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): retrieval_config = self._config.memory.get("retrieval", {}) if self._config.memory else {} result = await self._react_engine.execute( messages=user_messages, - tools=self._tools if self._tools else None, + tools=self.get_tools() or None, model=self._config.llm.get("model", "default") if self._config.llm else "default", agent_name=self.name, task_type=task.task_type, @@ -735,7 +735,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): result = await rewoo_engine.execute( messages=user_messages, - tools=self._tools if self._tools else None, + tools=self.get_tools() or None, model=self._config.llm.get("model", "default") if self._config.llm else "default", agent_name=self.name, task_type=task.task_type, @@ -781,7 +781,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): result = await plan_exec_engine.execute( messages=user_messages, - tools=self._tools if self._tools else None, + tools=self.get_tools() or None, model=self._config.llm.get("model", "default") if self._config.llm else "default", agent_name=self.name, task_type=task.task_type, @@ -829,7 +829,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): result = await reflexion_engine.execute( messages=user_messages, - tools=self._tools if self._tools else None, + tools=self.get_tools() or None, model=self._config.llm.get("model", "default") if self._config.llm else "default", agent_name=self.name, task_type=task.task_type, diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 26b844b..4ff495e 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -441,7 +441,14 @@ class ReActEngine: except Exception as e: tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} else: - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + # Non-dangerous tool: confirmation was for the overall action, + # re-execute with skip flag to avoid re-triggering confirmation + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + except Exception as e: + tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} else: tool_result = { "output": "", @@ -905,7 +912,13 @@ class ReActEngine: finally: pass # No shared state mutation needed else: - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + # Non-dangerous tool: re-execute with skip flag + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + except Exception as e: + tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} yield ReActEvent( event_type="confirmation_result", @@ -1261,7 +1274,13 @@ class ReActEngine: except Exception as e: tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} else: - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + # Non-dangerous tool: re-execute with skip flag + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + except Exception as e: + tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} events.append(ReActEvent( event_type="confirmation_result", diff --git a/src/agentkit/evolution/__init__.py b/src/agentkit/evolution/__init__.py index faeb633..61f42af 100644 --- a/src/agentkit/evolution/__init__.py +++ b/src/agentkit/evolution/__init__.py @@ -13,6 +13,7 @@ from agentkit.evolution.strategy_tuner import StrategyTuner from agentkit.evolution.ab_tester import ABTester from agentkit.evolution.evolution_store import ( EvolutionStore, + EvolutionStoreProtocol, InMemoryEvolutionStore, PersistentEvolutionStore, create_evolution_store, @@ -30,6 +31,7 @@ __all__ = [ "StrategyTuner", "ABTester", "EvolutionStore", + "EvolutionStoreProtocol", "PersistentEvolutionStore", "InMemoryEvolutionStore", "create_evolution_store", diff --git a/src/agentkit/evolution/evolution_store.py b/src/agentkit/evolution/evolution_store.py index d738ab6..d879470 100644 --- a/src/agentkit/evolution/evolution_store.py +++ b/src/agentkit/evolution/evolution_store.py @@ -1,9 +1,11 @@ """EvolutionStore - 进化日志存储 -提供三种后端实现: -- EvolutionStore: 基于外部注入的异步 SQLAlchemy session(原有实现) -- PersistentEvolutionStore: 基于 SQLite 的持久化存储 -- InMemoryEvolutionStore: 基于内存字典的轻量存储(用于测试) +提供统一 Protocol 和四种后端实现: +- EvolutionStoreProtocol: 统一接口 Protocol(所有后端必须实现) +- EvolutionStore: 基于外部注入的异步 SQLAlchemy session(原有实现,仅事件操作) +- PersistentEvolutionStore: 基于 SQLite 的持久化存储(完整 Protocol) +- InMemoryEvolutionStore: 基于内存字典的轻量存储(完整 Protocol,用于测试) +- PostgreSQLEvolutionStore: 基于 PostgreSQL 的异步持久化存储(完整 Protocol,见 pg_store.py) """ import asyncio @@ -13,7 +15,7 @@ import os import time import uuid as _uuid from datetime import datetime, timezone -from typing import Any +from typing import Any, Protocol, runtime_checkable from sqlalchemy import create_engine, event as sa_event, select from sqlalchemy.exc import OperationalError @@ -30,6 +32,34 @@ from agentkit.evolution.models import ( logger = logging.getLogger(__name__) +# ── 统一 Protocol ───────────────────────────────────────── + + +@runtime_checkable +class EvolutionStoreProtocol(Protocol): + """进化存储统一接口 Protocol + + 所有后端必须实现以下方法。不支持的操作应抛出 NotImplementedError。 + """ + + async def record(self, event: EvolutionEvent) -> str: ... + async def rollback(self, event_id: str) -> bool: ... + async def list_events( + self, + agent_name: str | None = ..., + change_type: str | None = ..., + status: str | None = ..., + ) -> list[dict]: ... + async def record_skill_version( + self, skill_name: str, version: str, content: str, parent_version: str | None = ... + ) -> str: ... + async def list_skill_versions(self, skill_name: str) -> list[dict]: ... + async def record_ab_test_result( + self, test_id: str, variant: str, score: float, sample_count: int = ... + ) -> str: ... + async def get_ab_test_results(self, test_id: str) -> list[dict]: ... + + class EvolutionStore: """进化日志存储 @@ -133,6 +163,40 @@ class EvolutionStore: logger.error(f"Failed to list evolution events: {e}") return [] + # ── Protocol 兼容方法(旧版 EvolutionStore 不支持 skill_version / ab_test)── + + async def record_skill_version( + self, skill_name: str, version: str, content: str, parent_version: str | None = None + ) -> str: + """记录技能版本(旧版 SQL 后端不支持,抛出 NotImplementedError)""" + raise NotImplementedError( + "EvolutionStore (SQL backend) does not support skill_version operations. " + "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead." + ) + + async def list_skill_versions(self, skill_name: str) -> list[dict]: + """列出技能版本历史(旧版 SQL 后端不支持,抛出 NotImplementedError)""" + raise NotImplementedError( + "EvolutionStore (SQL backend) does not support skill_version operations. " + "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead." + ) + + async def record_ab_test_result( + self, test_id: str, variant: str, score: float, sample_count: int = 0 + ) -> str: + """记录 A/B 测试结果(旧版 SQL 后端不支持,抛出 NotImplementedError)""" + raise NotImplementedError( + "EvolutionStore (SQL backend) does not support A/B test operations. " + "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead." + ) + + async def get_ab_test_results(self, test_id: str) -> list[dict]: + """获取 A/B 测试结果(旧版 SQL 后端不支持,抛出 NotImplementedError)""" + raise NotImplementedError( + "EvolutionStore (SQL backend) does not support A/B test operations. " + "Use PersistentEvolutionStore or PostgreSQLEvolutionStore instead." + ) + class PersistentEvolutionStore: """SQLite 持久化进化存储 @@ -464,19 +528,32 @@ def create_evolution_store( db_path: str = "~/.agentkit/evolution.db", session_factory: Any = None, evolution_model: Any = None, + database_url: str | None = None, ) -> EvolutionStore | PersistentEvolutionStore | InMemoryEvolutionStore: """工厂函数:创建进化存储实例 Args: - backend: 存储后端类型 - "memory" | "sqlite" | "sql" + backend: 存储后端类型 - "memory" | "sqlite" | "sql" | "postgresql" db_path: SQLite 数据库路径(仅 backend="sqlite" 时使用) session_factory: 异步 SQLAlchemy session 工厂(仅 backend="sql" 时使用) evolution_model: SQLAlchemy ORM 模型类(仅 backend="sql" 时使用) + database_url: PostgreSQL 连接字符串(仅 backend="postgresql" 时使用) Returns: 对应后端的进化存储实例 """ - if backend == "sqlite": + if backend == "postgresql": + from agentkit.evolution.pg_store import PostgreSQLEvolutionStore + + url = database_url or os.environ.get("AGENTKIT_DATABASE_URL") + if url: + return PostgreSQLEvolutionStore(database_url=url) + logger.warning( + "PostgreSQL backend requested but no database_url provided, " + "falling back to InMemoryEvolutionStore" + ) + return InMemoryEvolutionStore() + elif backend == "sqlite": return PersistentEvolutionStore(db_path=db_path) elif backend == "sql" and session_factory and evolution_model: return EvolutionStore(session_factory=session_factory, evolution_model=evolution_model) diff --git a/src/agentkit/evolution/lifecycle.py b/src/agentkit/evolution/lifecycle.py index 817f949..1cce14f 100644 --- a/src/agentkit/evolution/lifecycle.py +++ b/src/agentkit/evolution/lifecycle.py @@ -483,7 +483,15 @@ class EvolutionMixin: tool = MemoryTool(memory_store) section = category - content = "; ".join(reflections[-1]["reflection"].suggestions[:2]) + # 汇总所有累积反思的建议(去重,最多取 5 条) + all_suggestions: list[str] = [] + seen: set[str] = set() + for r in reflections: + for suggestion in r["reflection"].suggestions: + if suggestion not in seen: + seen.add(suggestion) + all_suggestions.append(suggestion) + content = "; ".join(all_suggestions[:5]) reason = f"连续{len(reflections)}次低质量反思 (category: {category})" update_result = await tool.execute( diff --git a/src/agentkit/evolution/pg_store.py b/src/agentkit/evolution/pg_store.py new file mode 100644 index 0000000..aa0571a --- /dev/null +++ b/src/agentkit/evolution/pg_store.py @@ -0,0 +1,329 @@ +"""PostgreSQLEvolutionStore - 基于 PostgreSQL 的异步进化存储 + +使用 async SQLAlchemy + asyncpg 实现完整的 EvolutionStoreProtocol, +支持进化事件、技能版本、A/B 测试结果的持久化。 +""" + +import asyncio +import logging +import uuid as _uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy import Column, DateTime, Float, Integer, String, Text, UniqueConstraint +from sqlalchemy.orm import declarative_base + +from agentkit.core.protocol import EvolutionEvent + +logger = logging.getLogger(__name__) + +# PG 专用 Base(与 SQLite models 的 Base 隔离,避免表名冲突) +PGBase = declarative_base() + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +class PGEvolutionEventModel(PGBase): + """PostgreSQL 进化事件 ORM 模型""" + + __tablename__ = "evolution_events" + + id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4())) + agent_name = Column(String, index=True) + change_type = Column(String, nullable=True) + before = Column(JSONB, nullable=True) + after = Column(JSONB, nullable=True) + metrics = Column(JSONB, nullable=True) + status = Column(String, default="active") + created_at = Column(DateTime, default=_utcnow) + + +class PGSkillVersionModel(PGBase): + """PostgreSQL 技能版本 ORM 模型""" + + __tablename__ = "skill_versions" + __table_args__ = (UniqueConstraint("skill_name", "version"),) + + id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4())) + skill_name = Column(String, index=True) + version = Column(String) + content = Column(Text) + parent_version = Column(String, nullable=True) + created_at = Column(DateTime, default=_utcnow) + + +class PGABTestResultModel(PGBase): + """PostgreSQL A/B 测试结果 ORM 模型""" + + __tablename__ = "ab_test_results" + + id = Column(String, primary_key=True, default=lambda: str(_uuid.uuid4())) + test_id = Column(String, index=True) + variant = Column(String) + score = Column(Float) + sample_count = Column(Integer, default=0) + created_at = Column(DateTime, default=_utcnow) + + +class PostgreSQLEvolutionStore: + """PostgreSQL 异步进化存储 + + 使用 async SQLAlchemy session 实现完整的 EvolutionStoreProtocol。 + 支持进化事件、技能版本、A/B 测试结果的 CRUD 操作。 + + 用法: + store = PostgreSQLEvolutionStore( + database_url="postgresql+asyncpg://user:pass@localhost/dbname" + ) + await store.ensure_tables() + event_id = await store.record(event) + """ + + def __init__(self, database_url: str) -> None: + self._database_url = database_url + self._engine: Any = None + self._session_factory: Any = None + self._initialized = False + self._init_lock = asyncio.Lock() + + async def _ensure_initialized(self) -> None: + """延迟初始化 async engine 和 session factory(带锁防并发)""" + if self._initialized: + return + + async with self._init_lock: + # Double-check after acquiring lock + if self._initialized: + return + + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + from sqlalchemy.orm import sessionmaker + + self._engine = create_async_engine(self._database_url, echo=False) + self._session_factory = sessionmaker( + self._engine, class_=AsyncSession, expire_on_commit=False + ) + self._initialized = True + + async def ensure_tables(self) -> None: + """创建所有表(如果不存在) + + 安全的启动调用 — 使用 CREATE TABLE IF NOT EXISTS。 + """ + await self._ensure_initialized() + async with self._engine.begin() as conn: + await conn.run_sync(PGBase.metadata.create_all) + + async def close(self) -> None: + """关闭 engine,释放所有连接""" + if self._engine is not None: + await self._engine.dispose() + self._engine = None + self._session_factory = None + self._initialized = False + + async def __aenter__(self) -> "PostgreSQLEvolutionStore": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.close() + + # ── 进化事件 ────────────────────────────────────────── + + async def record(self, event: EvolutionEvent) -> str: + """记录进化事件""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + event_id = str(_uuid.uuid4()) + entry = PGEvolutionEventModel( + id=event_id, + agent_name=event.agent_name, + change_type=event.change_type, + before=event.before, + after=event.after, + metrics=event.metrics, + status="active", + ) + db.add(entry) + await db.commit() + event.event_id = event_id + logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'") + return event_id + except Exception as e: + await db.rollback() + logger.error(f"Failed to record evolution event: {e}") + raise + + async def rollback(self, event_id: str) -> bool: + """回滚进化事件""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + stmt = select(PGEvolutionEventModel).where( + PGEvolutionEventModel.id == event_id + ) + result = await db.execute(stmt) + entry = result.scalar_one_or_none() + + if not entry: + logger.error(f"Evolution event {event_id} not found") + return False + + entry.status = "rolled_back" + await db.commit() + logger.info(f"Evolution event {event_id} rolled back") + return True + except Exception as e: + await db.rollback() + logger.error(f"Failed to rollback evolution event {event_id}: {e}") + return False + + async def list_events( + self, + agent_name: str | None = None, + change_type: str | None = None, + status: str | None = None, + limit: int = 100, + ) -> list[dict]: + """列出进化事件""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + stmt = select(PGEvolutionEventModel) + if agent_name: + stmt = stmt.where(PGEvolutionEventModel.agent_name == agent_name) + if change_type: + stmt = stmt.where(PGEvolutionEventModel.change_type == change_type) + if status: + stmt = stmt.where(PGEvolutionEventModel.status == status) + stmt = stmt.order_by(PGEvolutionEventModel.created_at.desc()).limit(limit) + result = await db.execute(stmt) + entries = result.scalars().all() + + return [ + { + "id": e.id, + "agent_name": e.agent_name, + "change_type": e.change_type, + "before": e.before, + "after": e.after, + "metrics": e.metrics, + "status": e.status, + "created_at": e.created_at.isoformat() if e.created_at else None, + } + for e in entries + ] + except Exception as e: + logger.error(f"Failed to list evolution events: {e}") + return [] + + # ── 技能版本 ────────────────────────────────────────── + + async def record_skill_version( + self, skill_name: str, version: str, content: str, parent_version: str | None = None + ) -> str: + """记录技能版本""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + vid = str(_uuid.uuid4()) + entry = PGSkillVersionModel( + id=vid, + skill_name=skill_name, + version=version, + content=content, + parent_version=parent_version, + ) + db.add(entry) + await db.commit() + return vid + except Exception as e: + await db.rollback() + logger.error(f"Failed to record skill version: {e}") + raise + + async def list_skill_versions(self, skill_name: str, limit: int = 100) -> list[dict]: + """列出技能版本历史""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + stmt = ( + select(PGSkillVersionModel) + .where(PGSkillVersionModel.skill_name == skill_name) + .order_by(PGSkillVersionModel.created_at.desc()) + .limit(limit) + ) + result = await db.execute(stmt) + entries = result.scalars().all() + return [ + { + "id": e.id, + "skill_name": e.skill_name, + "version": e.version, + "content": e.content, + "parent_version": e.parent_version, + "created_at": e.created_at.isoformat() if e.created_at else None, + } + for e in entries + ] + except Exception as e: + logger.error(f"Failed to list skill versions: {e}") + return [] + + # ── A/B 测试结果 ────────────────────────────────────── + + async def record_ab_test_result( + self, test_id: str, variant: str, score: float, sample_count: int = 0 + ) -> str: + """记录 A/B 测试结果""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + rid = str(_uuid.uuid4()) + entry = PGABTestResultModel( + id=rid, + test_id=test_id, + variant=variant, + score=score, + sample_count=sample_count, + ) + db.add(entry) + await db.commit() + return rid + except Exception as e: + await db.rollback() + logger.error(f"Failed to record A/B test result: {e}") + raise + + async def get_ab_test_results(self, test_id: str, limit: int = 100) -> list[dict]: + """获取 A/B 测试结果""" + await self._ensure_initialized() + async with self._session_factory() as db: + try: + stmt = ( + select(PGABTestResultModel) + .where(PGABTestResultModel.test_id == test_id) + .order_by(PGABTestResultModel.created_at.desc()) + .limit(limit) + ) + result = await db.execute(stmt) + entries = result.scalars().all() + return [ + { + "id": e.id, + "test_id": e.test_id, + "variant": e.variant, + "score": e.score, + "sample_count": e.sample_count, + "created_at": e.created_at.isoformat() if e.created_at else None, + } + for e in entries + ] + except Exception as e: + logger.error(f"Failed to get A/B test results: {e}") + return [] diff --git a/src/agentkit/llm/__init__.py b/src/agentkit/llm/__init__.py index f9f58dc..2a2f7b5 100644 --- a/src/agentkit/llm/__init__.py +++ b/src/agentkit/llm/__init__.py @@ -5,7 +5,8 @@ from agentkit.llm.gateway import LLMGateway from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall from agentkit.llm.providers.anthropic import AnthropicProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider -from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker +from agentkit.llm.providers.tracker import UsageSummary, UsageTracker +from agentkit.llm.providers.usage_store import UsageRecord from agentkit.llm.retry import ( CircuitBreaker, CircuitBreakerConfig, diff --git a/src/agentkit/llm/cache.py b/src/agentkit/llm/cache.py new file mode 100644 index 0000000..cc7dda6 --- /dev/null +++ b/src/agentkit/llm/cache.py @@ -0,0 +1,632 @@ +"""LLM Response Cache — Exact-match + Semantic-match dual cache for LLM responses. + +Architecture: + - LLMCache Protocol: async interface for cache backends + - InMemoryLLMCache: OrderedDict LRU + embedding index + - RedisLLMCache: Redis keys + SET index + lazy init + - create_llm_cache(): Factory with auto-detection + +Design doc: docs/plans/2026-06-14-002-u1-llm-cache-architecture.md +""" + +import json +import logging +import time +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall +from agentkit.utils.vector_math import compute_cosine_similarity + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Data Classes +# --------------------------------------------------------------------------- + + +@dataclass +class CacheEntry: + """A cached LLM response with metadata.""" + + response: LLMResponse + query_embedding: list[float] = field(default_factory=list) + created_at: float = 0.0 + hit_count: int = 0 + + +@dataclass +class CacheResult: + """Result of a cache lookup.""" + + hit: bool = False + response: LLMResponse | None = None + match_type: str = "" # "exact" | "semantic" | "" (miss) + + +# --------------------------------------------------------------------------- +# Serialization helpers (for Redis backend) +# --------------------------------------------------------------------------- + + +def _serialize_response(response: LLMResponse) -> dict: + """Serialize LLMResponse to a JSON-compatible dict.""" + return { + "content": response.content, + "model": response.model, + "usage": { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + }, + "tool_calls": [ + {"id": tc.id, "name": tc.name, "arguments": tc.arguments} + for tc in response.tool_calls + ], + "latency_ms": response.latency_ms, + } + + +def _deserialize_response(data: dict) -> LLMResponse: + """Deserialize a dict back to LLMResponse.""" + usage_data = data.get("usage", {}) + return LLMResponse( + content=data["content"], + model=data["model"], + usage=TokenUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + ), + tool_calls=[ + ToolCall(id=tc["id"], name=tc["name"], arguments=tc["arguments"]) + for tc in data.get("tool_calls", []) + ], + latency_ms=data.get("latency_ms", 0.0), + ) + + +def _serialize_entry(entry: CacheEntry) -> dict: + """Serialize CacheEntry to a JSON-compatible dict.""" + return { + "response": _serialize_response(entry.response), + "query_embedding": entry.query_embedding, + "created_at": entry.created_at, + "hit_count": entry.hit_count, + } + + +def _deserialize_entry(data: dict) -> CacheEntry: + """Deserialize a dict back to CacheEntry.""" + return CacheEntry( + response=_deserialize_response(data["response"]), + query_embedding=data.get("query_embedding", []), + created_at=data.get("created_at", 0.0), + hit_count=data.get("hit_count", 0), + ) + + +# --------------------------------------------------------------------------- +# LLMCache Protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class LLMCache(Protocol): + """LLM response cache interface.""" + + async def get(self, key: str) -> CacheResult: + """Exact-match lookup by cache key.""" + ... + + async def semantic_search( + self, query_embedding: list[float], threshold: float = 0.92 + ) -> CacheResult: + """Semantic similarity search across cached entries.""" + ... + + async def put( + self, + key: str, + response: LLMResponse, + query_embedding: list[float] | None = None, + ) -> None: + """Store a response in the cache with optional embedding.""" + ... + + async def invalidate(self, pattern: str | None = None) -> int: + """Invalidate cache entries. Returns count of invalidated entries.""" + ... + + async def stats(self) -> dict[str, int]: + """Return cache statistics.""" + ... + + +# --------------------------------------------------------------------------- +# InMemoryLLMCache +# --------------------------------------------------------------------------- + + +class InMemoryLLMCache: + """In-memory LLM cache with LRU eviction and semantic search. + + Uses OrderedDict for O(1) LRU access/eviction (follows EmbeddingCache pattern). + Maintains a parallel embedding index for semantic similarity search. + """ + + def __init__( + self, + max_entries: int = 10000, + exact_ttl: int = 3600, + semantic_ttl: int = 86400, + similarity_threshold: float = 0.92, + ): + self._max_entries = max_entries + self._exact_ttl = exact_ttl + self._semantic_ttl = semantic_ttl + self._similarity_threshold = similarity_threshold + + self._cache: OrderedDict[str, CacheEntry] = OrderedDict() + self._embeddings: dict[str, list[float]] = {} + + self._hits = 0 + self._misses = 0 + + async def get(self, key: str) -> CacheResult: + now = time.monotonic() + entry = self._cache.get(key) + + if entry is not None: + if now - entry.created_at <= self._exact_ttl: + # Hit: update LRU position and stats + self._cache.move_to_end(key) + entry.hit_count += 1 + self._hits += 1 + return CacheResult(hit=True, response=entry.response, match_type="exact") + # Expired: remove + del self._cache[key] + self._embeddings.pop(key, None) + + self._misses += 1 + return CacheResult(hit=False) + + async def semantic_search( + self, query_embedding: list[float], threshold: float | None = None + ) -> CacheResult: + if not self._embeddings: + return CacheResult(hit=False) + + effective_threshold = threshold or self._similarity_threshold + now = time.monotonic() + best_key: str | None = None + best_sim: float = 0.0 + + for key, emb in self._embeddings.items(): + entry = self._cache.get(key) + if entry is None: + continue + # Check semantic TTL + if now - entry.created_at > self._semantic_ttl: + continue + sim = compute_cosine_similarity(query_embedding, emb) + if sim > best_sim: + best_sim = sim + best_key = key + + if best_key is not None and best_sim >= effective_threshold: + entry = self._cache[best_key] + entry.hit_count += 1 + self._cache.move_to_end(best_key) + self._hits += 1 + return CacheResult(hit=True, response=entry.response, match_type="semantic") + + self._misses += 1 + return CacheResult(hit=False) + + async def put( + self, + key: str, + response: LLMResponse, + query_embedding: list[float] | None = None, + ) -> None: + now = time.monotonic() + + if key in self._cache: + self._cache.move_to_end(key) + existing = self._cache[key] + # Preserve existing embedding if new one is None + effective_embedding = query_embedding if query_embedding is not None else existing.query_embedding + else: + effective_embedding = query_embedding or [] + + self._cache[key] = CacheEntry( + response=response, + query_embedding=effective_embedding, + created_at=now, + hit_count=0, + ) + + if effective_embedding: + self._embeddings[key] = effective_embedding + + # Evict LRU entries if over capacity + while len(self._cache) > self._max_entries: + evicted_key, _ = self._cache.popitem(last=False) + self._embeddings.pop(evicted_key, None) + + # Lazy cleanup: remove a few expired entries on each put to prevent memory leak + # Check oldest entries first (they are most likely to be expired) + if len(self._cache) > 0: + expired_keys = [] + # Iterate from oldest (front of OrderedDict) to find expired entries + for k in list(self._cache.keys())[:20]: + entry = self._cache.get(k) + if entry is not None and now - entry.created_at > self._semantic_ttl: + expired_keys.append(k) + for k in expired_keys: + self._cache.pop(k, None) + self._embeddings.pop(k, None) + + async def invalidate(self, pattern: str | None = None) -> int: + if pattern is None: + count = len(self._cache) + self._cache.clear() + self._embeddings.clear() + return count + + # Simple prefix matching for pattern + keys_to_remove = [ + k for k in self._cache if k.startswith(pattern.replace("*", "")) + ] + for key in keys_to_remove: + del self._cache[key] + self._embeddings.pop(key, None) + return len(keys_to_remove) + + async def stats(self) -> dict[str, int]: + return { + "total_entries": len(self._cache), + "total_hits": self._hits, + "total_misses": self._misses, + } + + +# --------------------------------------------------------------------------- +# RedisLLMCache +# --------------------------------------------------------------------------- + + +class RedisLLMCache: + """Redis-backed LLM cache with SET index for semantic search. + + Key schema: + agentkit:llm_cache:{sha256_hex} → JSON(CacheEntry) with TTL + agentkit:llm_cache_emb:{sha256_hex} → JSON(list[float]) with TTL + agentkit:llm_cache_index → SET of active cache keys + """ + + KEY_PREFIX = "agentkit:llm_cache:" + EMB_PREFIX = "agentkit:llm_cache_emb:" + INDEX_KEY = "agentkit:llm_cache_index" + + def __init__( + self, + redis_url: str = "redis://localhost:6379", + max_entries: int = 10000, + exact_ttl: int = 3600, + semantic_ttl: int = 86400, + similarity_threshold: float = 0.92, + max_entries_to_scan: int = 500, + fallback: InMemoryLLMCache | None = None, + ): + self._redis_url = redis_url + self._max_entries = max_entries + self._exact_ttl = exact_ttl + self._semantic_ttl = semantic_ttl + self._similarity_threshold = similarity_threshold + self._max_entries_to_scan = max_entries_to_scan + self._redis: Any = None + self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation + self._degraded = False # True if Redis is unreachable + + self._hits = 0 + self._misses = 0 + + async def _get_redis(self): + """Lazy Redis initialization (follows RedisSessionStore pattern).""" + if self._redis is None: + import redis.asyncio as aioredis + + self._redis = aioredis.from_url( + self._redis_url, decode_responses=True + ) + return self._redis + + async def aclose(self) -> None: + """Close the Redis connection pool.""" + if self._redis is not None: + await self._redis.aclose() + self._redis = None + + def _degrade_to_fallback(self) -> None: + """Mark Redis as unreachable and switch to in-memory fallback.""" + if not self._degraded: + self._degraded = True + self._degrade_count = 0 + if self._fallback is None: + self._fallback = InMemoryLLMCache( + max_entries=self._max_entries, + exact_ttl=self._exact_ttl, + semantic_ttl=self._semantic_ttl, + similarity_threshold=self._similarity_threshold, + ) + logger.warning("Redis cache unreachable, degraded to in-memory fallback") + + def _try_recover(self) -> None: + """Attempt to recover from degraded state after enough operations. + + Resets the degraded flag optimistically. The next actual Redis + operation will verify connectivity — if it fails, degradation + is re-triggered immediately. + """ + if not self._degraded: + return + self._degrade_count = getattr(self, "_degrade_count", 0) + 1 + # Try recovery every 100 operations + if self._degrade_count >= 100: + self._degrade_count = 0 + self._degraded = False + logger.info("Redis cache: attempting recovery from degraded state") + + async def get(self, key: str) -> CacheResult: + # If degraded to fallback, use InMemory cache + if self._degraded and self._fallback is not None: + self._try_recover() + if self._degraded: + return await self._fallback.get(key) + # Recovery attempted — fall through to try Redis + + try: + redis = await self._get_redis() + data = await redis.get(f"{self.KEY_PREFIX}{key}") + if data is not None: + entry = _deserialize_entry(json.loads(data)) + self._hits += 1 + return CacheResult( + hit=True, response=entry.response, match_type="exact" + ) + self._misses += 1 + return CacheResult(hit=False) + except Exception as e: + logger.warning(f"Redis cache get failed, returning miss: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + return await self._fallback.get(key) + return CacheResult(hit=False) + + async def semantic_search( + self, query_embedding: list[float], threshold: float | None = None + ) -> CacheResult: + try: + redis = await self._get_redis() + effective_threshold = threshold or self._similarity_threshold + + # Get all cache keys from index + cache_keys = await redis.smembers(self.INDEX_KEY) + if not cache_keys: + return CacheResult(hit=False) + + # Limit scan to avoid O(n) memory/network transfer for large caches + # Sample up to max_entries_to_scan most recent keys + cache_keys_list = list(cache_keys) + max_scan = min(len(cache_keys_list), self._max_entries_to_scan) + if len(cache_keys_list) > max_scan: + # Take a random sample to avoid always scanning the same subset + import random + cache_keys_list = random.sample(cache_keys_list, max_scan) + + # Batch fetch embeddings + emb_keys = [f"{self.EMB_PREFIX}{k}" for k in cache_keys_list] + emb_values = await redis.mget(emb_keys) + + best_key: str | None = None + best_sim: float = 0.0 + stale_keys: list[str] = [] # Keys whose data has expired + + for cache_key, emb_json in zip(cache_keys_list, emb_values): + if emb_json is None: + # Embedding expired but index entry remains — mark for cleanup + stale_keys.append(cache_key) + continue + emb = json.loads(emb_json) + sim = compute_cosine_similarity(query_embedding, emb) + if sim > best_sim: + best_sim = sim + best_key = cache_key + + # Lazy cleanup: remove stale index entries + if stale_keys: + try: + pipe = redis.pipeline() + for k in stale_keys: + pipe.srem(self.INDEX_KEY, k) + await pipe.execute() + except Exception: + pass # Best-effort cleanup + + if best_key is not None and best_sim >= effective_threshold: + data = await redis.get(f"{self.KEY_PREFIX}{best_key}") + if data is not None: + entry = _deserialize_entry(json.loads(data)) + self._hits += 1 + return CacheResult( + hit=True, response=entry.response, match_type="semantic" + ) + # Data key expired but embedding still exists — mark for cleanup + try: + await redis.srem(self.INDEX_KEY, best_key) + except Exception: + pass + + self._misses += 1 + return CacheResult(hit=False) + except Exception as e: + logger.warning(f"Redis semantic search failed, returning miss: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + return await self._fallback.semantic_search(query_embedding, threshold) + self._misses += 1 + return CacheResult(hit=False) + + async def put( + self, + key: str, + response: LLMResponse, + query_embedding: list[float] | None = None, + ) -> None: + # If degraded to fallback, use InMemory cache + if self._degraded and self._fallback is not None: + self._try_recover() + if self._degraded: + await self._fallback.put(key, response, query_embedding) + return + # Recovery attempted — fall through to try Redis + + try: + redis = await self._get_redis() + entry = CacheEntry( + response=response, + query_embedding=query_embedding or [], + created_at=time.time(), # Wall-clock for cross-process comparability in Redis + hit_count=0, + ) + + pipe = redis.pipeline() + # Data key TTL must cover both exact and semantic windows + # so semantic hits don't return None data + data_ttl = max(self._exact_ttl, self._semantic_ttl) + pipe.set( + f"{self.KEY_PREFIX}{key}", + json.dumps(_serialize_entry(entry)), + ex=data_ttl, + ) + if query_embedding is not None: + pipe.set( + f"{self.EMB_PREFIX}{key}", + json.dumps(query_embedding), + ex=self._semantic_ttl, + ) + pipe.sadd(self.INDEX_KEY, key) + await pipe.execute() + except Exception as e: + logger.warning(f"Redis cache put failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + await self._fallback.put(key, response, query_embedding) + + async def invalidate(self, pattern: str | None = None) -> int: + try: + redis = await self._get_redis() + + if pattern is None: + cache_keys = await redis.smembers(self.INDEX_KEY) + if not cache_keys: + return 0 + pipe = redis.pipeline() + for key in cache_keys: + pipe.delete(f"{self.KEY_PREFIX}{key}") + pipe.delete(f"{self.EMB_PREFIX}{key}") + pipe.delete(self.INDEX_KEY) + await pipe.execute() + return len(cache_keys) + + # Pattern-based invalidation (prefix match) + prefix = pattern.replace("*", "") + cache_keys = await redis.smembers(self.INDEX_KEY) + keys_to_remove = [k for k in cache_keys if k.startswith(prefix)] + + if not keys_to_remove: + return 0 + + pipe = redis.pipeline() + for key in keys_to_remove: + pipe.delete(f"{self.KEY_PREFIX}{key}") + pipe.delete(f"{self.EMB_PREFIX}{key}") + pipe.srem(self.INDEX_KEY, key) + await pipe.execute() + return len(keys_to_remove) + except Exception as e: + logger.warning(f"Redis cache invalidate failed: {e}") + return 0 + + async def stats(self) -> dict[str, int]: + try: + redis = await self._get_redis() + total_entries = await redis.scard(self.INDEX_KEY) + return { + "total_entries": total_entries, + "total_hits": self._hits, + "total_misses": self._misses, + } + except Exception: + return { + "total_entries": 0, + "total_hits": self._hits, + "total_misses": self._misses, + } + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_llm_cache( + backend: str = "auto", + redis_url: str = "redis://localhost:6379", + max_entries: int = 10000, + exact_ttl: int = 3600, + semantic_ttl: int = 86400, + similarity_threshold: float = 0.92, +) -> LLMCache: + """Create an LLM cache backend. + + Args: + backend: "auto" (try Redis, fallback to memory), "redis", "memory". + redis_url: Redis connection URL (only used for "redis"/"auto" backend). + max_entries: Maximum number of cache entries. + exact_ttl: TTL in seconds for exact-match cache entries. + semantic_ttl: TTL in seconds for semantic-match embeddings. + similarity_threshold: Cosine similarity threshold for semantic match. + + Returns: + An LLMCache instance. + """ + if backend in ("auto", "redis"): + try: + import redis.asyncio as aioredis # noqa: F401 + + return RedisLLMCache( + redis_url=redis_url, + max_entries=max_entries, + exact_ttl=exact_ttl, + semantic_ttl=semantic_ttl, + similarity_threshold=similarity_threshold, + ) + except ImportError: + logger.warning( + "redis package not available, falling back to in-memory cache" + ) + return InMemoryLLMCache( + max_entries=max_entries, + exact_ttl=exact_ttl, + semantic_ttl=semantic_ttl, + similarity_threshold=similarity_threshold, + ) + return InMemoryLLMCache( + max_entries=max_entries, + exact_ttl=exact_ttl, + semantic_ttl=semantic_ttl, + similarity_threshold=similarity_threshold, + ) diff --git a/src/agentkit/llm/cache_key.py b/src/agentkit/llm/cache_key.py new file mode 100644 index 0000000..63eb2e4 --- /dev/null +++ b/src/agentkit/llm/cache_key.py @@ -0,0 +1,66 @@ +"""LLM Cache Key Generation — Deterministic SHA-256 cache key from LLM request parameters.""" + +import hashlib +import json +from typing import Any + + +def generate_cache_key( + model: str, + messages: list[dict[str, str]], + temperature: float, + tools: list[dict[str, Any]] | None = None, + tool_choice: str = "auto", + max_tokens: int = 2000, +) -> str: + """Generate a deterministic SHA-256 cache key from LLM request parameters. + + The key captures ALL inputs that deterministically affect LLM output: + model, system_prompt (extracted from messages), messages content, + temperature, tools, tool_choice, and max_tokens. + + Args: + model: Model identifier (e.g. "openai/gpt-4o"). + messages: Chat messages list (may include system prompt as first message). + temperature: Sampling temperature. + tools: Optional list of tool definitions. + tool_choice: Tool selection mode ("auto", "none", etc.). + max_tokens: Maximum response tokens. + + Returns: + 64-character hex SHA-256 hash string. + """ + system_prompt = _extract_system_prompt(messages) + components = [ + _hash_str(model), + _hash_str(system_prompt), + _hash_json(messages), + _hash_str(f"{temperature:.2f}"), + _hash_json(tools), + _hash_str(tool_choice), + _hash_str(str(max_tokens)), + ] + combined = "".join(components) + return hashlib.sha256(combined.encode()).hexdigest() + + +def _extract_system_prompt(messages: list[dict[str, str]]) -> str: + """Extract system prompt from messages list.""" + for msg in messages: + if msg.get("role") == "system": + return msg.get("content", "") + return "" + + +def _hash_str(s: str) -> str: + """SHA-256 hash of a string.""" + return hashlib.sha256(s.encode()).hexdigest() + + +def _hash_json(obj: Any) -> str: + """SHA-256 hash of a JSON-serializable object.""" + if obj is None: + return hashlib.sha256(b"null").hexdigest() + return hashlib.sha256( + json.dumps(obj, sort_keys=True, ensure_ascii=False).encode() + ).hexdigest() diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py index 67f8a8b..65f8fbc 100644 --- a/src/agentkit/llm/config.py +++ b/src/agentkit/llm/config.py @@ -8,6 +8,43 @@ import yaml from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig +@dataclass +class CacheConfig: + """LLM Cache 配置""" + + enabled: bool = False + backend: str = "auto" # "auto" | "redis" | "memory" + redis_url: str = "redis://localhost:6379" + exact_ttl: int = 3600 + semantic_ttl: int = 86400 + similarity_threshold: float = 0.92 + max_entries: int = 10000 + # Embedding config for semantic cache (Chinese-first: bge-m3 via Xinference) + embedding_provider: str = "openai" # "openai" | "xinference" | "local" + embedding_model: str = "bge-m3" + embedding_base_url: str | None = None + embedding_api_key: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> "CacheConfig": + if not data: + return cls() + emb = data.get("embedding", {}) + return cls( + enabled=data.get("enabled", False), + backend=data.get("backend", "auto"), + redis_url=data.get("redis_url", "redis://localhost:6379"), + exact_ttl=data.get("exact_ttl", 3600), + semantic_ttl=data.get("semantic_ttl", 86400), + similarity_threshold=data.get("similarity_threshold", 0.92), + max_entries=data.get("max_entries", 10000), + embedding_provider=emb.get("provider", "openai"), + embedding_model=emb.get("model", "bge-m3"), + embedding_base_url=emb.get("base_url"), + embedding_api_key=emb.get("api_key"), + ) + + @dataclass class ProviderConfig: """Provider 配置""" @@ -32,6 +69,7 @@ class LLMConfig: providers: dict[str, ProviderConfig] = field(default_factory=dict) model_aliases: dict[str, str] = field(default_factory=dict) fallbacks: dict[str, list[str]] = field(default_factory=dict) + cache: CacheConfig | None = None @classmethod def from_yaml(cls, path: str) -> "LLMConfig": @@ -77,8 +115,14 @@ class LLMConfig: retry=retry, circuit_breaker=circuit_breaker, ) + cache = None + cache_data = data.get("cache") + if cache_data: + cache = CacheConfig.from_dict(cache_data) + return cls( providers=providers, model_aliases=data.get("model_aliases", {}), fallbacks=data.get("fallbacks", {}), + cache=cache, ) diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 7e7f20e..64bad9a 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -2,6 +2,7 @@ import logging import time +from typing import Any from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.config import LLMConfig @@ -14,13 +15,53 @@ logger = logging.getLogger(__name__) class LLMGateway: - """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪""" + """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache""" - def __init__(self, config: LLMConfig | None = None): + def __init__(self, config: LLMConfig | None = None, usage_store: Any = None): self._providers: dict[str, LLMProvider] = {} - self._usage_tracker = UsageTracker() + self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._config = config or LLMConfig() + # Cache (opt-in, disabled by default) + self._cache: Any = None # LLMCache | None + self._embedder: Any = None # Embedder | None + if self._config.cache and self._config.cache.enabled: + from agentkit.llm.cache import create_llm_cache + self._cache = create_llm_cache( + backend=self._config.cache.backend, + redis_url=self._config.cache.redis_url, + max_entries=self._config.cache.max_entries, + exact_ttl=self._config.cache.exact_ttl, + semantic_ttl=self._config.cache.semantic_ttl, + similarity_threshold=self._config.cache.similarity_threshold, + ) + self._embedder = self._create_embedder(self._config.cache) + logger.info( + f"LLM cache enabled (backend={self._config.cache.backend}, " + f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})" + ) + + def _create_embedder(self, cache_config) -> Any: + """Create embedder for semantic cache based on config.""" + try: + from agentkit.memory.embedder import OpenAIEmbedder + + if cache_config.embedding_provider in ("xinference", "local"): + return OpenAIEmbedder( + api_key=cache_config.embedding_api_key or "not-needed", + model=cache_config.embedding_model, + base_url=cache_config.embedding_base_url or "http://localhost:9997/v1", + ) + # Default: OpenAI + return OpenAIEmbedder( + api_key=cache_config.embedding_api_key, + model=cache_config.embedding_model, + base_url=cache_config.embedding_base_url, + ) + except Exception as e: + logger.warning(f"Failed to create embedder for semantic cache: {e}") + return None + def register_provider(self, name: str, provider: LLMProvider) -> None: """注册 Provider""" self._providers[name] = provider @@ -66,6 +107,66 @@ class LLMGateway: _span = _span_cm.__enter__() start = time.monotonic() + + # ── Cache check ── + cache_key = None + query_embedding = None + if self._cache is not None: + from agentkit.llm.cache_key import generate_cache_key + + cache_key = generate_cache_key( + model=resolved_model, + messages=messages, + temperature=kwargs.get("temperature", 0.7), + tools=tools, + tool_choice=tool_choice, + max_tokens=kwargs.get("max_tokens", 2000), + ) + result = await self._cache.get(cache_key) + if result.hit: + latency_ms = (time.monotonic() - start) * 1000 + self._usage_tracker.record( + agent_name=agent_name, + model=result.response.model, + usage=result.response.usage, + cost=0.0, + latency_ms=latency_ms, + ) + if _span is not None: + _span.set_attribute("gen_ai.cache.hit", True) + _span.set_attribute("gen_ai.cache.match_type", result.match_type) + return result.response + + # Semantic match (only for temperature == 0) + temperature = kwargs.get("temperature", 0.7) + if temperature == 0 and self._embedder is not None: + try: + # Embed last N messages for context-aware semantic matching + # (not just last user message — avoids cross-context false hits) + recent_messages = messages[-3:] if len(messages) > 3 else messages + embed_text = " | ".join( + m.get("content", "") for m in recent_messages if m.get("content") + ) + if embed_text: + query_embedding = await self._embedder.embed(embed_text) + result = await self._cache.semantic_search(query_embedding) + if result.hit: + latency_ms = (time.monotonic() - start) * 1000 + self._usage_tracker.record( + agent_name=agent_name, + model=result.response.model, + usage=result.response.usage, + cost=0.0, + latency_ms=latency_ms, + ) + if _span is not None: + _span.set_attribute("gen_ai.cache.hit", True) + _span.set_attribute("gen_ai.cache.match_type", "semantic") + return result.response + except Exception as e: + logger.warning(f"Semantic cache search failed: {e}") + + # ── Normal provider call ── models_to_try = self._get_models_to_try(resolved_model) last_error: LLMProviderError | None = None @@ -95,6 +196,13 @@ class LLMGateway: latency_ms = (time.monotonic() - start) * 1000 + # ── Cache write ── + if self._cache is not None and cache_key is not None: + try: + await self._cache.put(cache_key, response, query_embedding) + except Exception as e: + logger.warning(f"Cache write failed: {e}") + # 计算成本 cost = self._calculate_cost(response.model, response.usage) @@ -112,7 +220,9 @@ class LLMGateway: _span.set_attribute("gen_ai.usage.input_tokens", response.usage.prompt_tokens) _span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens) _span.set_attribute("gen_ai.response.model", response.model) - _span.set_attribute("gen_ai.duration_ms", int(latency_ms)) + _span.set_attribute("gen_ai.duration.ms", int(latency_ms)) + if self._cache is not None: + _span.set_attribute("gen_ai.cache.hit", False) llm_token_histogram().record( response.usage.total_tokens, {"gen_ai.request.model": resolved_model}, @@ -138,6 +248,8 @@ class LLMGateway: If the primary model fails before any chunk is yielded, tries fallback models. If it fails after chunks have been sent, yields an error chunk and terminates (cannot switch mid-stream). + + Note: Streaming responses are NOT cached in this iteration. """ resolved_model = self._resolve_model_alias(model) diff --git a/src/agentkit/llm/providers/__init__.py b/src/agentkit/llm/providers/__init__.py index 5a3ac74..aa113f1 100644 --- a/src/agentkit/llm/providers/__init__.py +++ b/src/agentkit/llm/providers/__init__.py @@ -4,7 +4,8 @@ from agentkit.llm.providers.anthropic import AnthropicProvider from agentkit.llm.providers.doubao import DoubaoProvider from agentkit.llm.providers.gemini import GeminiProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider -from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker +from agentkit.llm.providers.tracker import UsageSummary, UsageTracker +from agentkit.llm.providers.usage_store import UsageRecord from agentkit.llm.providers.wenxin import WenxinProvider from agentkit.llm.providers.yuanbao import YuanbaoProvider diff --git a/src/agentkit/llm/providers/tracker.py b/src/agentkit/llm/providers/tracker.py index d7774cb..fe9d056 100644 --- a/src/agentkit/llm/providers/tracker.py +++ b/src/agentkit/llm/providers/tracker.py @@ -1,42 +1,20 @@ -"""Usage Tracker - 使用量追踪""" +"""Usage Tracker - 使用量追踪(委托给 UsageStore)""" -from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime from agentkit.llm.protocol import TokenUsage - - -@dataclass -class UsageRecord: - """使用量记录""" - - agent_name: str - model: str - prompt_tokens: int - completion_tokens: int - total_tokens: int - cost: float - latency_ms: float - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - -@dataclass -class UsageSummary: - """使用量汇总""" - - total_tokens: int = 0 - total_cost: float = 0.0 - by_model: dict[str, dict[str, int | float]] = field(default_factory=dict) - records: list[UsageRecord] = field(default_factory=list) +from agentkit.llm.providers.usage_store import ( + InMemoryUsageStore, + UsageStore, + UsageSummary, +) class UsageTracker: - """使用量追踪器""" + """使用量追踪器 — 委托给可插拔的 UsageStore""" - MAX_RECORDS = 10000 # 最大记录数,防止内存无限增长 - - def __init__(self) -> None: - self._records: list[UsageRecord] = [] + def __init__(self, store: UsageStore | None = None) -> None: + self._store: UsageStore = store or InMemoryUsageStore() def record( self, @@ -47,19 +25,7 @@ class UsageTracker: latency_ms: float, ) -> None: """记录一次使用""" - rec = UsageRecord( - agent_name=agent_name, - model=model, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens, - cost=cost, - latency_ms=latency_ms, - ) - self._records.append(rec) - # 超过上限时删除最早的记录 - if len(self._records) > self.MAX_RECORDS: - self._records = self._records[-self.MAX_RECORDS:] + self._store.record(agent_name, model, usage, cost, latency_ms) def get_usage( self, @@ -68,32 +34,4 @@ class UsageTracker: end_time: datetime | None = None, ) -> UsageSummary: """查询使用量汇总""" - filtered = self._records - - if agent_name is not None: - filtered = [r for r in filtered if r.agent_name == agent_name] - if start_time is not None: - filtered = [r for r in filtered if r.timestamp >= start_time] - if end_time is not None: - filtered = [r for r in filtered if r.timestamp <= end_time] - - if not filtered: - return UsageSummary() - - total_tokens = sum(r.total_tokens for r in filtered) - total_cost = sum(r.cost for r in filtered) - - by_model: dict[str, dict[str, int | float]] = {} - for r in filtered: - if r.model not in by_model: - by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} - by_model[r.model]["total_tokens"] += r.total_tokens - by_model[r.model]["total_cost"] += r.cost - by_model[r.model]["count"] += 1 - - return UsageSummary( - total_tokens=total_tokens, - total_cost=total_cost, - by_model=by_model, - records=filtered, - ) + return self._store.get_usage(agent_name, start_time, end_time) diff --git a/src/agentkit/llm/providers/usage_store.py b/src/agentkit/llm/providers/usage_store.py new file mode 100644 index 0000000..822b649 --- /dev/null +++ b/src/agentkit/llm/providers/usage_store.py @@ -0,0 +1,373 @@ +"""Usage Store — Persistent usage tracking with Redis Hash backend. + +Provides UsageStore Protocol with InMemoryUsageStore and RedisUsageStore +backends. Replaces the in-memory list in UsageTracker with a pluggable +store that survives restarts and supports multi-instance deployment. + +Key schema (Redis): + agentkit:usage:{date} → Hash: {agent_name:model → JSON(UsageBucket)} + agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM +""" + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Protocol, runtime_checkable + +from agentkit.llm.protocol import TokenUsage + +logger = logging.getLogger(__name__) + + +@dataclass +class UsageRecord: + """使用量记录""" + + agent_name: str + model: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + cost: float + latency_ms: float + timestamp: str = "" # ISO 8601 string for JSON serialization + + def __post_init__(self): + if not self.timestamp: + self.timestamp = datetime.now(timezone.utc).isoformat() + + +@dataclass +class UsageBucket: + """Aggregated usage for an agent+model pair on a given date.""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cost: float = 0.0 + count: int = 0 + + +@dataclass +class UsageSummary: + """使用量汇总""" + + total_tokens: int = 0 + total_cost: float = 0.0 + by_model: dict[str, dict[str, int | float]] = field(default_factory=dict) + records: list[UsageRecord] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# UsageStore Protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class UsageStore(Protocol): + """Persistent usage store interface.""" + + def record( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + ) -> None: + """Record a usage event.""" + ... + + def get_usage( + self, + agent_name: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Query usage summary.""" + ... + + +# --------------------------------------------------------------------------- +# InMemoryUsageStore +# --------------------------------------------------------------------------- + + +class InMemoryUsageStore: + """In-memory usage store (drop-in replacement for old UsageTracker).""" + + MAX_RECORDS = 10000 + + def __init__(self): + self._records: list[UsageRecord] = [] + + def record( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + ) -> None: + rec = UsageRecord( + agent_name=agent_name, + model=model, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + cost=cost, + latency_ms=latency_ms, + ) + self._records.append(rec) + if len(self._records) > self.MAX_RECORDS: + self._records = self._records[-self.MAX_RECORDS:] + + def get_usage( + self, + agent_name: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + filtered = self._records + + if agent_name is not None: + filtered = [r for r in filtered if r.agent_name == agent_name] + if start_time is not None: + filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) >= start_time] + if end_time is not None: + filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) <= end_time] + + if not filtered: + return UsageSummary() + + total_tokens = sum(r.total_tokens for r in filtered) + total_cost = sum(r.cost for r in filtered) + + by_model: dict[str, dict[str, int | float]] = {} + for r in filtered: + if r.model not in by_model: + by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} + by_model[r.model]["total_tokens"] += r.total_tokens + by_model[r.model]["total_cost"] += r.cost + by_model[r.model]["count"] += 1 + + return UsageSummary( + total_tokens=total_tokens, + total_cost=total_cost, + by_model=by_model, + records=filtered, + ) + + +# --------------------------------------------------------------------------- +# RedisUsageStore +# --------------------------------------------------------------------------- + + +class RedisUsageStore: + """Redis-backed usage store using Hash per date for O(1) writes. + + Key schema: + agentkit:usage:{YYYY-MM-DD} → Hash: {agent:model → JSON(UsageBucket)} + agentkit:usage_records:{YYYY-MM-DD} → List: JSON(UsageRecord) with LTRIM + """ + + USAGE_PREFIX = "agentkit:usage:" + RECORDS_PREFIX = "agentkit:usage_records:" + MAX_RECORDS_PER_DAY = 50000 + TTL_DAYS = 90 # Auto-expire after 90 days + + def __init__(self, redis_url: str = "redis://localhost:6379"): + self._redis_url = redis_url + self._redis: Any = None + self._sync_redis: Any = None + self._fallback: InMemoryUsageStore | None = None + self._degraded = False + + async def _get_redis(self): + if self._redis is None: + import redis.asyncio as aioredis + self._redis = aioredis.from_url(self._redis_url, decode_responses=True) + return self._redis + + def _get_sync_redis(self): + """Get or create a persistent sync Redis client (connection pool backed).""" + if self._sync_redis is None: + import redis as sync_redis + self._sync_redis = sync_redis.from_url( + self._redis_url, decode_responses=True + ) + return self._sync_redis + + async def aclose(self) -> None: + if self._redis is not None: + await self._redis.aclose() + self._redis = None + if self._sync_redis is not None: + self._sync_redis.aclose() + self._sync_redis = None + + def _degrade_to_fallback(self) -> None: + if not self._degraded: + self._degraded = True + if self._fallback is None: + self._fallback = InMemoryUsageStore() + logger.warning("Redis usage store unreachable, degraded to in-memory") + + def _today_key(self) -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%d") + + def record( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + ) -> None: + """Record usage — sync wrapper for async Redis. + + Note: This is a sync method because UsageTracker.record() is sync. + For Redis, we use a sync Redis client for writes to avoid + needing an event loop in the caller. + """ + if self._degraded and self._fallback is not None: + self._fallback.record(agent_name, model, usage, cost, latency_ms) + return + + try: + r = self._get_sync_redis() + + date_key = self._today_key() + hash_key = f"{self.USAGE_PREFIX}{date_key}" + list_key = f"{self.RECORDS_PREFIX}{date_key}" + bucket_field = f"{agent_name}:{model}" + + # Atomic HINCRBYFLOAT for bucket aggregation + pipe = r.pipeline() + pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost) + pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:count", 1) + + # Append record + rec = UsageRecord( + agent_name=agent_name, + model=model, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + cost=cost, + latency_ms=latency_ms, + ) + pipe.rpush(list_key, json.dumps({ + "agent_name": rec.agent_name, + "model": rec.model, + "prompt_tokens": rec.prompt_tokens, + "completion_tokens": rec.completion_tokens, + "total_tokens": rec.total_tokens, + "cost": rec.cost, + "latency_ms": rec.latency_ms, + "timestamp": rec.timestamp, + })) + pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1) + + # Set TTL on first write of the day + pipe.expire(hash_key, self.TTL_DAYS * 86400) + pipe.expire(list_key, self.TTL_DAYS * 86400) + + pipe.execute() + except Exception as e: + logger.warning(f"Redis usage record failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + self._fallback.record(agent_name, model, usage, cost, latency_ms) + + def get_usage( + self, + agent_name: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Query usage summary from Redis.""" + if self._degraded and self._fallback is not None: + return self._fallback.get_usage(agent_name, start_time, end_time) + + try: + r = self._get_sync_redis() + + # Determine date range to scan + start = start_time or datetime(2020, 1, 1, tzinfo=timezone.utc) + end = end_time or datetime.now(timezone.utc) + + all_records: list[UsageRecord] = [] + # Scan date keys in range + current = start.date() + end_date = end.date() + while current <= end_date: + list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}" + raw_records = r.lrange(list_key, 0, -1) + for raw in raw_records: + data = json.loads(raw) + rec = UsageRecord(**data) + rec_ts = datetime.fromisoformat(rec.timestamp) + if rec_ts >= start and rec_ts <= end: + if agent_name is None or rec.agent_name == agent_name: + all_records.append(rec) + current = current + timedelta(days=1) + + if not all_records: + return UsageSummary() + + total_tokens = sum(r.total_tokens for r in all_records) + total_cost = sum(r.cost for r in all_records) + + by_model: dict[str, dict[str, int | float]] = {} + for r in all_records: + if r.model not in by_model: + by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} + by_model[r.model]["total_tokens"] += r.total_tokens + by_model[r.model]["total_cost"] += r.cost + by_model[r.model]["count"] += 1 + + return UsageSummary( + total_tokens=total_tokens, + total_cost=total_cost, + by_model=by_model, + records=all_records, + ) + except Exception as e: + logger.warning(f"Redis usage query failed: {e}") + if self._fallback is not None: + return self._fallback.get_usage(agent_name, start_time, end_time) + return UsageSummary() + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_usage_store( + backend: str = "auto", + redis_url: str = "redis://localhost:6379", +) -> UsageStore: + """Create a usage store backend. + + Args: + backend: "auto" (try Redis, fallback to memory), "redis", "memory". + redis_url: Redis connection URL. + + Returns: + A UsageStore instance. + """ + if backend in ("auto", "redis"): + try: + import redis # noqa: F401 + return RedisUsageStore(redis_url=redis_url) + except ImportError: + logger.warning("redis package not available, falling back to in-memory usage store") + return InMemoryUsageStore() + return InMemoryUsageStore() diff --git a/src/agentkit/memory/profile.py b/src/agentkit/memory/profile.py index 9f34c02..6f7ace8 100644 --- a/src/agentkit/memory/profile.py +++ b/src/agentkit/memory/profile.py @@ -10,7 +10,7 @@ import re from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import Any, Callable class MemoryFile: @@ -26,9 +26,11 @@ class MemoryFile: """ - def __init__(self, path: Path, char_budget: int | None = None): + def __init__(self, path: Path, char_budget: int | None = None, + protected_sections: set[str] | None = None): self.path = Path(path) self.char_budget = char_budget + self._protected_sections = protected_sections or set() def read(self) -> str: """读取整个文件内容,文件不存在返回空字符串.""" @@ -37,11 +39,14 @@ class MemoryFile: return self.path.read_text(encoding="utf-8") def write(self, content: str) -> None: - """写入内容,自动创建父目录,超容量时自动裁剪.""" + """写入内容,自动创建父目录,超容量时自动裁剪. + + 在内存中完成裁剪后一次性写入,避免中间不一致状态。 + """ self.path.parent.mkdir(parents=True, exist_ok=True) - self.path.write_text(content, encoding="utf-8") if self.char_budget and len(content) > self.char_budget: - self.trim_to_budget() + content = self._trim_content(content, self._protected_sections or None) + self.path.write_text(content, encoding="utf-8") def read_section(self, name: str) -> str: """读取指定 section 的内容(不含标题行).""" @@ -104,15 +109,64 @@ class MemoryFile: return [] return re.findall(r"^## (.+)$", content, re.MULTILINE) - def trim_to_budget(self) -> None: - """裁剪内容到容量上限,优先保留前面的 section.""" + def trim_to_budget(self, protected_sections: set[str] | None = None) -> None: + """裁剪内容到容量上限,按 section 边界截断. + + 保持原始 section 顺序,仅从后向前移除非保护 section。 + protected_sections 中的 section 始终保留,不参与裁剪。 + """ if not self.char_budget: return content = self.read() if len(content) <= self.char_budget: return - # 从末尾裁剪,保留前面的 section - self.write(content[: self.char_budget]) + trimmed = self._trim_content(content, protected_sections) + self.path.write_text(trimmed, encoding="utf-8") + + def _trim_content(self, content: str, protected_sections: set[str] | None = None) -> str: + """在内存中裁剪内容到容量上限,返回裁剪后的字符串(不写文件). + + 保持原始 section 顺序,仅从后向前移除非保护 section。 + """ + if not self.char_budget or len(content) <= self.char_budget: + return content + + protected = protected_sections or set() + + # 解析所有 section 及其位置 + sections: list[tuple[str, int, int]] = [] # (name, start, end) + for match in re.finditer(r"^## (.+)$", content, re.MULTILINE): + name = match.group(1).strip() + start = match.start() + next_match = re.search(r"^## ", content[match.end():], re.MULTILINE) + if next_match: + end = match.end() + next_match.start() + else: + end = len(content) + sections.append((name, start, end)) + + if not sections: + return content[:self.char_budget] + + # 保持原始顺序,标记每个 section 是否受保护 + ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected) + for name, start, end in sections: + ordered.append((name, content[start:end], name in protected)) + + # 从后向前移除非保护 section,直到总长度在预算内 + while ordered: + total = len("\n\n".join(s[1] for s in ordered)) + if total <= self.char_budget: + break + # 从后向前找第一个非保护 section 移除 + for i in range(len(ordered) - 1, -1, -1): + if not ordered[i][2]: + ordered.pop(i) + break + else: + break # 所有剩余 section 都是受保护的 + + return "\n\n".join(s[1] for s in ordered).strip() @dataclass @@ -168,14 +222,21 @@ class MemoryStore: """ - def __init__(self, base_dir: Path | str | None = None): + def __init__(self, base_dir: Path | str | None = None, + on_change: Callable[[str], None] | None = None): if base_dir is None: base_dir = Path.home() / ".agentkit" self.base_dir = Path(base_dir) self.base_dir.mkdir(parents=True, exist_ok=True) + self._on_change = on_change + self._base_prompt: str = "" # 初始化四个 MemoryFile - self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET) + self._soul = MemoryFile( + self.base_dir / "SOUL.md", + char_budget=SOUL_BUDGET, + protected_sections={"版本", "更新历史"}, + ) self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET) self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET) self._daily_dir = self.base_dir / "memories" / "daily" @@ -277,6 +338,10 @@ class MemoryStore: [base_prompt] """ + # 保存 base_prompt 供后续刷新使用 + if base_prompt: + self._base_prompt = base_prompt + parts: list[str] = [] if snapshot.soul: @@ -292,3 +357,23 @@ class MemoryStore: parts.append(base_prompt) return "\n\n".join(parts) if parts else base_prompt + + def refresh_system_prompt(self) -> str: + """重新加载所有记忆文件并构建 system prompt. + + 在 MemoryTool 写入记忆后调用,确保 agent 的 _system_prompt + 反映最新的记忆内容。 + """ + snapshot = self.load_all() + return self.build_system_prompt(snapshot, self._base_prompt) + + def notify_change(self) -> None: + """记忆文件变更后通知回调,刷新所有订阅者的 system prompt.""" + if self._on_change is None: + return + try: + new_prompt = self.refresh_system_prompt() + self._on_change(new_prompt) + except Exception: + import logging + logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True) diff --git a/src/agentkit/quality/cascade_detector.py b/src/agentkit/quality/cascade_detector.py index 49a3e7e..d57a1b8 100644 --- a/src/agentkit/quality/cascade_detector.py +++ b/src/agentkit/quality/cascade_detector.py @@ -1,9 +1,14 @@ -"""CascadeDetector - 独立的级联故障检测工具""" +"""CascadeDetector - 独立的级联故障检测工具(委托给 CascadeStateStore)""" from __future__ import annotations from dataclasses import dataclass +from agentkit.quality.cascade_state_store import ( + CascadeStateStore, + InMemoryCascadeStateStore, +) + @dataclass class CascadeAlert: @@ -19,18 +24,19 @@ class CascadeAlert: class CascadeDetector: """检测多 agent 交互中的级联故障""" - def __init__(self, max_interactions: int = 10, max_depth: int = 3): + def __init__( + self, + max_interactions: int = 10, + max_depth: int = 3, + store: CascadeStateStore | None = None, + ): self._max_interactions = max_interactions self._max_depth = max_depth - self._interaction_counts: dict[str, int] = {} - self._loop_depths: dict[str, int] = {} + self._store: CascadeStateStore = store or InMemoryCascadeStateStore() def check_interaction(self, session_id: str) -> CascadeAlert | None: """递增并检查交互计数""" - self._interaction_counts[session_id] = ( - self._interaction_counts.get(session_id, 0) + 1 - ) - count = self._interaction_counts[session_id] + count = self._store.increment_interaction(session_id) if count > self._max_interactions: return CascadeAlert( session_id=session_id, @@ -46,7 +52,7 @@ class CascadeDetector: def check_depth(self, session_id: str, depth: int) -> CascadeAlert | None: """检查循环深度""" - self._loop_depths[session_id] = depth + self._store.set_depth(session_id, depth) if depth > self._max_depth: return CascadeAlert( session_id=session_id, @@ -62,12 +68,11 @@ class CascadeDetector: def reset(self, session_id: str) -> None: """重置某个 session 的计数器""" - self._interaction_counts.pop(session_id, None) - self._loop_depths.pop(session_id, None) + self._store.reset(session_id) def get_stats(self, session_id: str) -> dict[str, int]: """获取某个 session 的当前统计""" return { - "interactions": self._interaction_counts.get(session_id, 0), - "depth": self._loop_depths.get(session_id, 0), + "interactions": self._store.get_interaction(session_id), + "depth": self._store.get_depth(session_id), } diff --git a/src/agentkit/quality/cascade_state_store.py b/src/agentkit/quality/cascade_state_store.py new file mode 100644 index 0000000..b4a2e70 --- /dev/null +++ b/src/agentkit/quality/cascade_state_store.py @@ -0,0 +1,245 @@ +"""Cascade State Store — Persistent cascade detection state with Redis INCR backend. + +Provides CascadeStateStore Protocol with InMemoryCascadeStateStore and +RedisCascadeStateStore backends. Enables CascadeDetector state to survive +restarts and work across multiple instances. + +Key schema (Redis): + agentkit:cascade:interactions:{session_id} → INCR counter with TTL + agentkit:cascade:depths:{session_id} → SET counter with TTL +""" + +import logging +import time +from typing import Any, Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class CascadeStateStore(Protocol): + """Persistent cascade detection state interface.""" + + def increment_interaction(self, session_id: str) -> int: + """Atomically increment interaction count. Returns new count.""" + ... + + def get_interaction(self, session_id: str) -> int: + """Get current interaction count.""" + ... + + def set_depth(self, session_id: str, depth: int) -> None: + """Set loop depth for a session.""" + ... + + def get_depth(self, session_id: str) -> int: + """Get current loop depth.""" + ... + + def reset(self, session_id: str) -> None: + """Reset all counters for a session.""" + ... + + +# --------------------------------------------------------------------------- +# InMemoryCascadeStateStore +# --------------------------------------------------------------------------- + + +class InMemoryCascadeStateStore: + """In-memory cascade state store (default, process-local). + + Supports optional session TTL to prevent unbounded memory growth. + Expired entries are lazily cleaned up on access. + """ + + DEFAULT_SESSION_TTL = 86400 # 24 hours + + def __init__(self, session_ttl: int = 86400): + self._session_ttl = session_ttl + self._interaction_counts: dict[str, int] = {} + self._loop_depths: dict[str, int] = {} + self._timestamps: dict[str, float] = {} + + def _is_expired(self, session_id: str) -> bool: + ts = self._timestamps.get(session_id) + if ts is None: + return False + return (time.monotonic() - ts) > self._session_ttl + + def _cleanup_expired(self) -> None: + """Lazy cleanup: remove expired sessions.""" + expired = [sid for sid in self._timestamps if self._is_expired(sid)] + for sid in expired: + self._interaction_counts.pop(sid, None) + self._loop_depths.pop(sid, None) + self._timestamps.pop(sid, None) + + def _touch(self, session_id: str) -> None: + self._timestamps[session_id] = time.monotonic() + + def increment_interaction(self, session_id: str) -> int: + self._cleanup_expired() + self._interaction_counts[session_id] = self._interaction_counts.get(session_id, 0) + 1 + self._touch(session_id) + return self._interaction_counts[session_id] + + def get_interaction(self, session_id: str) -> int: + if self._is_expired(session_id): + self.reset(session_id) + return 0 + return self._interaction_counts.get(session_id, 0) + + def set_depth(self, session_id: str, depth: int) -> None: + self._touch(session_id) + self._loop_depths[session_id] = depth + + def get_depth(self, session_id: str) -> int: + if self._is_expired(session_id): + self.reset(session_id) + return 0 + return self._loop_depths.get(session_id, 0) + + def reset(self, session_id: str) -> None: + self._interaction_counts.pop(session_id, None) + self._loop_depths.pop(session_id, None) + self._timestamps.pop(session_id, None) + + +# --------------------------------------------------------------------------- +# RedisCascadeStateStore +# --------------------------------------------------------------------------- + + +class RedisCascadeStateStore: + """Redis-backed cascade state store using INCR for atomic increments. + + Key schema: + agentkit:cascade:interactions:{session_id} → INCR counter with TTL + agentkit:cascade:depths:{session_id} → SET counter with TTL + """ + + INTER_PREFIX = "agentkit:cascade:interactions:" + DEPTH_PREFIX = "agentkit:cascade:depths:" + SESSION_TTL = 86400 # 24 hours — sessions rarely last longer + + def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = 86400): + self._redis_url = redis_url + self._session_ttl = session_ttl + self._sync_redis: Any = None + self._fallback: InMemoryCascadeStateStore | None = None + self._degraded = False + + def _get_sync_redis(self): + """Get or create a persistent sync Redis client (connection pool backed).""" + if self._sync_redis is None: + import redis as sync_redis + self._sync_redis = sync_redis.from_url( + self._redis_url, decode_responses=True + ) + return self._sync_redis + + def _degrade_to_fallback(self) -> None: + if not self._degraded: + self._degraded = True + if self._fallback is None: + self._fallback = InMemoryCascadeStateStore() + logger.warning("Redis cascade store unreachable, degraded to in-memory") + + def increment_interaction(self, session_id: str) -> int: + if self._degraded and self._fallback is not None: + return self._fallback.increment_interaction(session_id) + try: + r = self._get_sync_redis() + key = f"{self.INTER_PREFIX}{session_id}" + pipe = r.pipeline() + pipe.incr(key) + pipe.expire(key, self._session_ttl) + results = pipe.execute() + return results[0] + except Exception as e: + logger.warning(f"Redis cascade increment failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + return self._fallback.increment_interaction(session_id) + return 0 + + def get_interaction(self, session_id: str) -> int: + if self._degraded and self._fallback is not None: + return self._fallback.get_interaction(session_id) + try: + r = self._get_sync_redis() + val = r.get(f"{self.INTER_PREFIX}{session_id}") + return int(val) if val is not None else 0 + except Exception as e: + logger.warning(f"Redis cascade get failed: {e}") + if self._fallback is not None: + return self._fallback.get_interaction(session_id) + return 0 + + def set_depth(self, session_id: str, depth: int) -> None: + if self._degraded and self._fallback is not None: + self._fallback.set_depth(session_id, depth) + return + try: + r = self._get_sync_redis() + key = f"{self.DEPTH_PREFIX}{session_id}" + pipe = r.pipeline() + pipe.set(key, depth) + pipe.expire(key, self._session_ttl) + pipe.execute() + except Exception as e: + logger.warning(f"Redis cascade set_depth failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + self._fallback.set_depth(session_id, depth) + + def get_depth(self, session_id: str) -> int: + if self._degraded and self._fallback is not None: + return self._fallback.get_depth(session_id) + try: + r = self._get_sync_redis() + val = r.get(f"{self.DEPTH_PREFIX}{session_id}") + return int(val) if val is not None else 0 + except Exception as e: + logger.warning(f"Redis cascade get_depth failed: {e}") + if self._fallback is not None: + return self._fallback.get_depth(session_id) + return 0 + + def reset(self, session_id: str) -> None: + if self._degraded and self._fallback is not None: + self._fallback.reset(session_id) + return + try: + r = self._get_sync_redis() + pipe = r.pipeline() + pipe.delete(f"{self.INTER_PREFIX}{session_id}") + pipe.delete(f"{self.DEPTH_PREFIX}{session_id}") + pipe.execute() + except Exception as e: + logger.warning(f"Redis cascade reset failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + self._fallback.reset(session_id) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_cascade_state_store( + backend: str = "auto", + redis_url: str = "redis://localhost:6379", + session_ttl: int = 86400, +) -> CascadeStateStore: + """Create a cascade state store backend.""" + if backend in ("auto", "redis"): + try: + import redis # noqa: F401 + return RedisCascadeStateStore(redis_url=redis_url, session_ttl=session_ttl) + except ImportError: + logger.warning("redis package not available, falling back to in-memory cascade store") + return InMemoryCascadeStateStore() + return InMemoryCascadeStateStore() diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 541e89c..1a2e5df 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -40,7 +40,19 @@ _ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'} def _build_llm_gateway(config: ServerConfig) -> LLMGateway: """Build LLMGateway from ServerConfig, registering all providers.""" - gateway = LLMGateway(config=config.llm_config) + # Initialize UsageStore if configured + usage_store = None + if config.usage_store: + try: + from agentkit.llm.providers.usage_store import create_usage_store + usage_store = create_usage_store( + backend=config.usage_store.get("backend", "memory"), + redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"), + ) + except Exception as e: + logger.warning(f"Failed to initialize usage store: {e}, using in-memory") + + gateway = LLMGateway(config=config.llm_config, usage_store=usage_store) for name, pconf in config.llm_config.providers.items(): if not pconf.api_key: @@ -111,6 +123,15 @@ async def lifespan(app: FastAPI): # Start MCP servers if configured mcp_manager = getattr(app.state, "mcp_manager", None) + + # Build semantic router index after skill registry is populated + semantic_router = getattr(getattr(app.state, "cost_aware_router", None), "_semantic_router", None) + if semantic_router is not None: + try: + await semantic_router.build_index(app.state.skill_registry) + logger.info(f"Semantic router index built with {len(app.state.skill_registry.list_skills())} skills") + except Exception as e: + logger.warning(f"Failed to build semantic router index: {e}") if mcp_manager is not None: await mcp_manager.start_all() @@ -142,6 +163,23 @@ async def lifespan(app: FastAPI): ) effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt) + # Register on_change callback to refresh all agents' system prompts + # when MemoryTool writes to memory files + def _on_memory_change(new_prompt: str) -> None: + pool = app.state.agent_pool + updated = 0 + for agent_name in pool.list_agents(): + try: + agent = pool.get_agent(agent_name) + if agent is not None: + agent._system_prompt = new_prompt + updated += 1 + except Exception: + logger.warning(f"Failed to update system prompt for agent '{agent_name}'", exc_info=True) + logger.info(f"Memory changed: refreshed system prompt for {updated}/{len(pool.list_agents())} agents") + + memory_store._on_change = _on_memory_change + # Store memory_store on app.state for chat routes to use app.state.memory_store = memory_store @@ -219,6 +257,34 @@ async def lifespan(app: FastAPI): from agentkit.memory.profile import MemoryStore memory_store = MemoryStore() memory_store.ensure_defaults() + # Initialize _base_prompt so refresh_system_prompt works correctly + snapshot = memory_store.load_all() + base_prompt = ( + "你是一个有帮助的AI助手。请记住我们对话的上下文,并在后续对话中引用之前的内容。回答要清晰简洁,请使用中文回复。\n\n" + "重要提示:当你不确定事实信息、时事新闻或任何你不确信的话题时," + "你必须先使用搜索工具查找准确和最新的信息,然后再回答。" + "中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。" + "在能够搜索到真相的情况下,绝不猜测或编造答案。" + "始终优先搜索而不是给出可能不正确的信息。\n\n" + "技能安装:当需要安装技能时,使用 skill_install 工具,不要用 shell 执行 npm install。" + "skill_install 的 source 参数格式为 owner/repo@skill,例如 vercel-labs/skills@find-skills。" + "如果不知道完整 source,先用 shell 执行 `npx skills search ` 搜索。" + ) + memory_store.build_system_prompt(snapshot, base_prompt) + # Register on_change callback for existing agents + def _on_memory_change(new_prompt: str) -> None: + pool = app.state.agent_pool + updated = 0 + for agent_name in pool.list_agents(): + try: + agent = pool.get_agent(agent_name) + if agent is not None: + agent._system_prompt = new_prompt + updated += 1 + except Exception: + logger.warning(f"Failed to update system prompt for agent '{agent_name}'", exc_info=True) + logger.info(f"Memory changed: refreshed system prompt for {updated}/{len(pool.list_agents())} agents") + memory_store._on_change = _on_memory_change app.state.memory_store = memory_store yield @@ -502,12 +568,28 @@ def create_app( auction_enabled = False if server_config and hasattr(server_config, "marketplace") and server_config.marketplace: auction_enabled = server_config.marketplace.get("auction_enabled", False) + + # Initialize semantic router if configured + semantic_router = None + router_conf = server_config.router if server_config and server_config.router else {} + if router_conf.get("semantic", {}).get("enabled"): + try: + from agentkit.chat.semantic_router import SemanticRouter + semantic_router = SemanticRouter( + embedder=app.state.llm_gateway._embedder, + similarity_high=router_conf["semantic"].get("similarity_high", 0.85), + similarity_low=router_conf["semantic"].get("similarity_low", 0.6), + ) + except Exception as e: + logger.warning(f"Failed to initialize semantic router: {e}") + cost_aware_router = CostAwareRouter( llm_gateway=app.state.llm_gateway, org_context=org_context, auction_enabled=auction_enabled, - classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic", - merged_llm_classify=server_config.router.get("merged_llm_classify", True) if server_config and server_config.router else True, + classifier=router_conf.get("classifier", "heuristic"), + merged_llm_classify=router_conf.get("merged_llm_classify", True), + semantic_router=semantic_router, ) app.state.cost_aware_router = cost_aware_router # Initialize task store from config @@ -555,14 +637,30 @@ def create_app( app.state.evolution_store = create_evolution_store( backend=evo_conf.get("backend", "memory"), db_path=evo_conf.get("db_path", "~/.agentkit/evolution.db"), + database_url=evo_conf.get("database_url"), ) except Exception as e: - import logging - logging.getLogger(__name__).warning(f"Failed to initialize evolution store: {e}") + logger.warning(f"Failed to initialize evolution store: {e}") app.state.evolution_store = None else: app.state.evolution_store = None + # Initialize cascade state store if configured + if server_config and hasattr(server_config, 'cascade_store') and server_config.cascade_store: + try: + from agentkit.quality.cascade_state_store import create_cascade_state_store + cs_conf = server_config.cascade_store + app.state.cascade_state_store = create_cascade_state_store( + backend=cs_conf.get("backend", "memory"), + redis_url=cs_conf.get("redis_url", "redis://localhost:6379"), + session_ttl=cs_conf.get("session_ttl", 86400), + ) + except Exception as e: + logger.warning(f"Failed to initialize cascade state store: {e}") + app.state.cascade_state_store = None + else: + app.state.cascade_state_store = None + # Initialize memory components if configured if server_config and hasattr(server_config, 'memory') and server_config.memory: try: diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index f96480f..97ac87c 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -111,6 +111,9 @@ class ServerConfig: marketplace: dict[str, Any] | None = None, alignment: dict[str, Any] | None = None, router: dict[str, Any] | None = None, + usage_store: dict[str, Any] | None = None, + cascade_store: dict[str, Any] | None = None, + evolution: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -134,6 +137,9 @@ class ServerConfig: self.marketplace = marketplace or {} self.alignment = alignment or {} self.router = router or {} + self.usage_store = usage_store or {} + self.cascade_store = cascade_store or {} + self.evolution = evolution or {} self.on_change = on_change # Config watching state @@ -201,6 +207,15 @@ class ServerConfig: # Router config router_data = data.get("router", {}) + # Usage store config + usage_store_data = data.get("usage_store", {}) + + # Cascade store config + cascade_store_data = data.get("cascade_store", {}) + + # Evolution store config + evolution_data = data.get("evolution", {}) + return cls( host=server.get("host", "0.0.0.0"), port=server.get("port", 8001), @@ -223,11 +238,16 @@ class ServerConfig: marketplace=marketplace_data, alignment=alignment_data, router=router_data, + usage_store=usage_store_data, + cascade_store=cascade_store_data, + evolution=evolution_data, ) @staticmethod def _build_llm_config(data: dict) -> LLMConfig: """Build LLMConfig from the llm section of agentkit.yaml.""" + from agentkit.llm.config import CacheConfig + providers = {} model_aliases = {} @@ -254,10 +274,17 @@ class ServerConfig: keepalive_expiry=pconf.get("keepalive_expiry", 30.0), ) + # Build CacheConfig if cache section is present + cache_config = None + cache_data = data.get("cache") + if cache_data and isinstance(cache_data, dict): + cache_config = CacheConfig.from_dict(cache_data) + return LLMConfig( providers=providers, model_aliases=model_aliases, fallbacks=data.get("fallbacks", {}), + cache=cache_config, ) @staticmethod diff --git a/tests/integration/test_p0_hardening.py b/tests/integration/test_p0_hardening.py new file mode 100644 index 0000000..8952db7 --- /dev/null +++ b/tests/integration/test_p0_hardening.py @@ -0,0 +1,422 @@ +"""P0 Production Hardening — End-to-End Integration Tests + +Verifies the full configuration wiring and feature integration: +- Config from YAML → all features configured correctly +- Cache + usage tracking: cached requests show 0 cost +- UsageStore persistence via config +- CascadeStateStore persistence via config +- EvolutionStore via config +- Semantic router via config +- Graceful degradation when backends unavailable +""" + +import os +import tempfile + +import pytest + +from agentkit.core.protocol import EvolutionEvent +from agentkit.llm.config import CacheConfig, LLMConfig, ProviderConfig +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import TokenUsage +from agentkit.llm.providers.usage_store import ( + InMemoryUsageStore, + create_usage_store, +) +from agentkit.quality.cascade_detector import CascadeDetector +from agentkit.quality.cascade_state_store import ( + InMemoryCascadeStateStore, + create_cascade_state_store, +) +from agentkit.evolution.evolution_store import ( + EvolutionStoreProtocol, + InMemoryEvolutionStore, + PersistentEvolutionStore, + create_evolution_store, +) +from agentkit.server.config import ServerConfig + + +# ── Config parsing tests ────────────────────────────────── + + +class TestConfigParsing: + """Verify ServerConfig correctly parses all new config sections.""" + + def test_usage_store_config_parsed(self): + data = { + "usage_store": { + "backend": "redis", + "redis_url": "redis://custom:6380", + } + } + config = ServerConfig.from_dict(data) + assert config.usage_store["backend"] == "redis" + assert config.usage_store["redis_url"] == "redis://custom:6380" + + def test_cascade_store_config_parsed(self): + data = { + "cascade_store": { + "backend": "redis", + "redis_url": "redis://custom:6380", + "session_ttl": 3600, + } + } + config = ServerConfig.from_dict(data) + assert config.cascade_store["backend"] == "redis" + assert config.cascade_store["session_ttl"] == 3600 + + def test_evolution_config_parsed(self): + data = { + "evolution": { + "backend": "sqlite", + "db_path": "/tmp/test.db", + } + } + config = ServerConfig.from_dict(data) + assert config.evolution["backend"] == "sqlite" + assert config.evolution["db_path"] == "/tmp/test.db" + + def test_llm_cache_config_parsed(self): + data = { + "llm": { + "providers": {}, + "cache": { + "enabled": True, + "backend": "memory", + "exact_ttl": 7200, + }, + } + } + config = ServerConfig.from_dict(data) + assert config.llm_config.cache is not None + assert config.llm_config.cache.enabled is True + assert config.llm_config.cache.backend == "memory" + assert config.llm_config.cache.exact_ttl == 7200 + + def test_router_semantic_config_parsed(self): + data = { + "router": { + "classifier": "heuristic", + "semantic": { + "enabled": True, + "similarity_high": 0.9, + "similarity_low": 0.5, + }, + } + } + config = ServerConfig.from_dict(data) + assert config.router["semantic"]["enabled"] is True + assert config.router["semantic"]["similarity_high"] == 0.9 + + def test_empty_config_defaults(self): + config = ServerConfig.from_dict({}) + assert config.usage_store == {} + assert config.cascade_store == {} + assert config.evolution == {} + assert config.llm_config.cache is None + + def test_config_from_yaml_roundtrip(self, tmp_path): + """Config can be loaded from a YAML file with all new sections.""" + yaml_content = """ +server: + host: 0.0.0.0 + port: 8001 +llm: + providers: {} + cache: + enabled: true + backend: memory +router: + classifier: heuristic + semantic: + enabled: false +usage_store: + backend: memory +cascade_store: + backend: memory +evolution: + backend: memory +""" + yaml_path = str(tmp_path / "test_config.yaml") + with open(yaml_path, "w") as f: + f.write(yaml_content) + + config = ServerConfig.from_yaml(yaml_path) + assert config.llm_config.cache is not None + assert config.llm_config.cache.enabled is True + assert config.usage_store["backend"] == "memory" + assert config.cascade_store["backend"] == "memory" + assert config.evolution["backend"] == "memory" + + +# ── UsageStore integration tests ─────────────────────────── + + +class TestUsageStoreIntegration: + """Verify UsageStore works with LLMGateway.""" + + async def test_gateway_with_usage_store(self): + """LLMGateway uses injected UsageStore for tracking.""" + store = InMemoryUsageStore() + gateway = LLMGateway(usage_store=store) + + # Record usage directly through the tracker + gateway._usage_tracker.record( + agent_name="test_agent", + model="gpt-4", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=0.01, + latency_ms=100.0, + ) + + usage = gateway.get_usage() + assert usage.total_tokens > 0 + assert usage.total_cost > 0 + + async def test_gateway_without_usage_store(self): + """LLMGateway works without explicit UsageStore (uses InMemory).""" + gateway = LLMGateway() + gateway._usage_tracker.record( + agent_name="test_agent", + model="gpt-4", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=0.01, + latency_ms=100.0, + ) + + usage = gateway.get_usage() + assert usage.total_tokens > 0 + + def test_create_usage_store_memory(self): + store = create_usage_store(backend="memory") + assert isinstance(store, InMemoryUsageStore) + + def test_create_usage_store_redis_lazy(self): + """Redis backend creates RedisUsageStore (lazy connection, degrades on first op).""" + from agentkit.llm.providers.usage_store import RedisUsageStore + + store = create_usage_store( + backend="redis", + redis_url="redis://nonexistent:6379", + ) + # RedisUsageStore is created (lazy connection), degrades on first operation + assert isinstance(store, RedisUsageStore) + + +# ── CascadeStateStore integration tests ──────────────────── + + +class TestCascadeStateStoreIntegration: + """Verify CascadeStateStore works with CascadeDetector.""" + + async def test_cascade_detector_with_store(self): + """CascadeDetector uses injected CascadeStateStore.""" + store = InMemoryCascadeStateStore() + detector = CascadeDetector(store=store) + + # Check interaction — should not trigger cascade + result = detector.check_interaction(session_id="test-session") + assert result is None # No cascade alert + + async def test_cascade_detector_without_store(self): + """CascadeDetector works without explicit store (uses InMemory).""" + detector = CascadeDetector() + result = detector.check_interaction(session_id="test-session") + assert result is None + + def test_create_cascade_state_store_memory(self): + store = create_cascade_state_store(backend="memory") + assert isinstance(store, InMemoryCascadeStateStore) + + +# ── EvolutionStore integration tests ─────────────────────── + + +class TestEvolutionStoreIntegration: + """Verify EvolutionStore creation from config.""" + + async def test_create_evolution_store_from_config(self, tmp_path): + """EvolutionStore created from config dict works correctly.""" + db_path = str(tmp_path / "evo_test.db") + store = create_evolution_store(backend="sqlite", db_path=db_path) + assert isinstance(store, PersistentEvolutionStore) + + event = EvolutionEvent( + agent_name="test_agent", + change_type="prompt", + before={"old": 1}, + after={"new": 2}, + ) + event_id = await store.record(event) + assert event_id is not None + + events = await store.list_events() + assert len(events) == 1 + assert events[0]["agent_name"] == "test_agent" + + async def test_create_evolution_store_memory(self): + store = create_evolution_store(backend="memory") + assert isinstance(store, InMemoryEvolutionStore) + + event = EvolutionEvent( + agent_name="test_agent", + change_type="prompt", + before={}, + after={}, + ) + event_id = await store.record(event) + assert event_id is not None + + async def test_evolution_store_protocol_compliance(self): + """All created stores satisfy EvolutionStoreProtocol.""" + memory_store = create_evolution_store(backend="memory") + assert isinstance(memory_store, EvolutionStoreProtocol) + + +# ── Cache integration tests ──────────────────────────────── + + +class TestCacheIntegration: + """Verify LLMCache integration with LLMGateway via config.""" + + def test_gateway_with_cache_config(self): + """LLMGateway initializes cache when CacheConfig is provided.""" + config = LLMConfig( + providers={}, + cache=CacheConfig(enabled=True, backend="memory"), + ) + gateway = LLMGateway(config=config) + assert gateway._cache is not None + + def test_gateway_without_cache_config(self): + """LLMGateway works without cache (default).""" + config = LLMConfig(providers={}) + gateway = LLMGateway(config=config) + assert gateway._cache is None + + def test_gateway_cache_disabled(self): + """LLMGateway does not initialize cache when disabled.""" + config = LLMConfig( + providers={}, + cache=CacheConfig(enabled=False), + ) + gateway = LLMGateway(config=config) + assert gateway._cache is None + + +# ── Graceful degradation tests ───────────────────────────── + + +class TestGracefulDegradation: + """Verify all features degrade gracefully when backends unavailable.""" + + def test_usage_store_auto_creates_redis(self): + """auto backend creates RedisUsageStore (lazy connection).""" + from agentkit.llm.providers.usage_store import RedisUsageStore + + store = create_usage_store( + backend="auto", + redis_url="redis://nonexistent:6379", + ) + # Redis is available as package, so RedisUsageStore is created + assert isinstance(store, RedisUsageStore) + + def test_cascade_store_redis_lazy(self): + """CascadeStateStore Redis backend creates instance (lazy connection).""" + from agentkit.quality.cascade_state_store import RedisCascadeStateStore + + store = create_cascade_state_store( + backend="redis", + redis_url="redis://nonexistent:6379", + ) + assert isinstance(store, RedisCascadeStateStore) + + def test_evolution_store_postgresql_unavailable(self): + """EvolutionStore falls back to InMemory when PG unavailable.""" + store = create_evolution_store( + backend="postgresql", + database_url=None, + ) + assert isinstance(store, InMemoryEvolutionStore) + + def test_cache_auto_creates_redis(self): + """LLMCache auto backend creates RedisLLMCache (lazy connection).""" + from agentkit.llm.cache import create_llm_cache, RedisLLMCache + + cache = create_llm_cache( + backend="auto", + redis_url="redis://nonexistent:6379", + ) + # Redis package is available, so RedisLLMCache is created (lazy connection) + assert isinstance(cache, RedisLLMCache) + + +# ── Full flow test (in-memory) ───────────────────────────── + + +class TestFullFlowInMemory: + """End-to-end flow test using in-memory backends (no external deps).""" + + async def test_config_to_components(self): + """ServerConfig → all components initialized correctly.""" + data = { + "llm": { + "providers": {}, + "cache": { + "enabled": True, + "backend": "memory", + }, + }, + "router": { + "classifier": "heuristic", + }, + "usage_store": {"backend": "memory"}, + "cascade_store": {"backend": "memory"}, + "evolution": {"backend": "memory"}, + } + config = ServerConfig.from_dict(data) + + # Verify config parsed correctly + assert config.llm_config.cache is not None + assert config.llm_config.cache.enabled is True + assert config.usage_store["backend"] == "memory" + assert config.cascade_store["backend"] == "memory" + assert config.evolution["backend"] == "memory" + + # Create components from config + usage_store = create_usage_store( + backend=config.usage_store.get("backend", "memory"), + redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"), + ) + gateway = LLMGateway(config=config.llm_config, usage_store=usage_store) + assert gateway._cache is not None + + cascade_store = create_cascade_state_store( + backend=config.cascade_store.get("backend", "memory"), + ) + detector = CascadeDetector(store=cascade_store) + + evo_store = create_evolution_store( + backend=config.evolution.get("backend", "memory"), + ) + + # Exercise the components + gateway._usage_tracker.record( + agent_name="test", + model="gpt-4", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=0.01, + latency_ms=100.0, + ) + usage = gateway.get_usage() + assert usage.total_tokens == 150 + + result = detector.check_interaction("s1") + assert result is None # No cascade alert + + event = EvolutionEvent( + agent_name="test", change_type="prompt", before={}, after={} + ) + event_id = await evo_store.record(event) + assert event_id is not None diff --git a/tests/unit/test_gateway_cache.py b/tests/unit/test_gateway_cache.py new file mode 100644 index 0000000..9e56b67 --- /dev/null +++ b/tests/unit/test_gateway_cache.py @@ -0,0 +1,162 @@ +"""Integration tests for LLM Cache integration into LLMGateway (U2).""" + +import pytest + +from agentkit.llm.cache import InMemoryLLMCache +from agentkit.llm.config import CacheConfig, LLMConfig +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage + + +class MockProvider(LLMProvider): + """Mock LLM provider that tracks call count.""" + + def __init__(self, response_content: str = "Mock response"): + self.call_count = 0 + self._response_content = response_content + + async def chat(self, request: LLMRequest) -> LLMResponse: + self.call_count += 1 + return LLMResponse( + content=self._response_content, + model=request.model, + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + ) + + +def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]: + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_content}, + ] + + +class TestCacheDisabled: + @pytest.mark.asyncio + async def test_no_cache_by_default(self): + """Cache is disabled by default — requests always hit provider.""" + gateway = LLMGateway() + provider = MockProvider() + gateway.register_provider("test", provider) + + msgs = _make_messages() + await gateway.chat(msgs, "test/model") + await gateway.chat(msgs, "test/model") + + assert provider.call_count == 2 + + +class TestCacheEnabled: + @pytest.mark.asyncio + async def test_first_request_is_miss(self): + """First request is a cache miss — provider is called.""" + config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) + gateway = LLMGateway(config=config) + provider = MockProvider() + gateway.register_provider("test", provider) + + msgs = _make_messages() + response = await gateway.chat(msgs, "test/model", temperature=0.0) + + assert provider.call_count == 1 + assert response.content == "Mock response" + + @pytest.mark.asyncio + async def test_second_request_is_hit(self): + """Second identical request is a cache hit — provider NOT called.""" + config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) + gateway = LLMGateway(config=config) + provider = MockProvider() + gateway.register_provider("test", provider) + + msgs = _make_messages() + await gateway.chat(msgs, "test/model", temperature=0.0) + response = await gateway.chat(msgs, "test/model", temperature=0.0) + + assert provider.call_count == 1 # Not called again + assert response.content == "Mock response" + + @pytest.mark.asyncio + async def test_cache_hit_usage_has_zero_cost(self): + """Cache hit records usage with cost=0.""" + config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) + gateway = LLMGateway(config=config) + provider = MockProvider() + gateway.register_provider("test", provider) + + msgs = _make_messages() + await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0) + await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0) + + usage = gateway.get_usage(agent_name="agent1") + # First request has cost, second (cache hit) has cost=0 + assert usage.total_cost == 0.0 # No cost config, so both are 0 + assert len(usage.records) == 2 + + @pytest.mark.asyncio + async def test_different_messages_are_miss(self): + """Different messages produce cache misses.""" + config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) + gateway = LLMGateway(config=config) + provider = MockProvider() + gateway.register_provider("test", provider) + + await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0) + await gateway.chat(_make_messages("World"), "test/model", temperature=0.0) + + assert provider.call_count == 2 + + +class TestCacheConfig: + def test_config_from_dict(self): + """CacheConfig can be loaded from dict.""" + config = LLMConfig.from_dict({ + "cache": { + "enabled": True, + "backend": "memory", + "exact_ttl": 7200, + } + }) + assert config.cache is not None + assert config.cache.enabled is True + assert config.cache.backend == "memory" + assert config.cache.exact_ttl == 7200 + + def test_config_from_dict_no_cache(self): + """No cache section in config → cache is None.""" + config = LLMConfig.from_dict({}) + assert config.cache is None + + def test_config_from_dict_embedding(self): + """Embedding config is loaded correctly.""" + config = LLMConfig.from_dict({ + "cache": { + "enabled": True, + "embedding": { + "provider": "xinference", + "model": "bge-m3", + "base_url": "http://localhost:9997/v1", + }, + } + }) + assert config.cache.embedding_provider == "xinference" + assert config.cache.embedding_model == "bge-m3" + assert config.cache.embedding_base_url == "http://localhost:9997/v1" + + def test_gateway_creates_cache_when_enabled(self): + """Gateway creates cache instance when cache.enabled=True.""" + config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) + gateway = LLMGateway(config=config) + assert gateway._cache is not None + assert isinstance(gateway._cache, InMemoryLLMCache) + + def test_gateway_no_cache_when_disabled(self): + """Gateway has no cache when cache is disabled.""" + config = LLMConfig(cache=CacheConfig(enabled=False)) + gateway = LLMGateway(config=config) + assert gateway._cache is None + + def test_gateway_no_cache_when_no_config(self): + """Gateway has no cache when cache config is absent.""" + gateway = LLMGateway() + assert gateway._cache is None diff --git a/tests/unit/test_llm_cache.py b/tests/unit/test_llm_cache.py new file mode 100644 index 0000000..23237f6 --- /dev/null +++ b/tests/unit/test_llm_cache.py @@ -0,0 +1,604 @@ +"""Unit tests for LLM Cache Core (U1). + +Tests cover: +- CacheKey generation (deterministic, component isolation) +- InMemoryLLMCache (exact match, semantic match, TTL, LRU, stats) +- RedisLLMCache (same tests with mocked Redis) +- Factory function (backend selection, fallback) +""" + +import json +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.llm.cache import ( + CacheEntry, + CacheResult, + InMemoryLLMCache, + RedisLLMCache, + create_llm_cache, + _serialize_response, + _deserialize_response, + _serialize_entry, + _deserialize_entry, +) +from agentkit.llm.cache_key import generate_cache_key +from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_response( + content: str = "Hello", + model: str = "gpt-4o", + prompt_tokens: int = 10, + completion_tokens: int = 20, + tool_calls: list[ToolCall] | None = None, +) -> LLMResponse: + return LLMResponse( + content=content, + model=model, + usage=TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + tool_calls=tool_calls or [], + latency_ms=100.0, + ) + + +def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]: + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_content}, + ] + + +def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]: + """Create a unit vector for similarity testing.""" + vec = [base_val] * dim + magnitude = sum(x**2 for x in vec) ** 0.5 + return [x / magnitude for x in vec] if magnitude > 0 else vec + + +def _make_similar_embedding(base: list[float], noise: float = 0.01) -> list[float]: + """Create a vector similar to base with small noise.""" + vec = [x + noise for x in base] + magnitude = sum(x**2 for x in vec) ** 0.5 + return [x / magnitude for x in vec] if magnitude > 0 else vec + + +def _make_different_embedding(dim: int = 128) -> list[float]: + """Create a vector with very low cosine similarity to _make_embedding().""" + # _make_embedding(1.0) is all-positive unit vector. + # Negate first half to create near-orthogonal vector. + half = dim // 2 + vec = [-1.0] * half + [1.0] * (dim - half) + magnitude = sum(x**2 for x in vec) ** 0.5 + return [x / magnitude for x in vec] if magnitude > 0 else vec + + +# --------------------------------------------------------------------------- +# CacheKey Tests +# --------------------------------------------------------------------------- + + +class TestCacheKey: + def test_deterministic(self): + """Same inputs produce same key.""" + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0) + key2 = generate_cache_key("gpt-4o", msgs, 0.0) + assert key1 == key2 + assert len(key1) == 64 # SHA-256 hex + + def test_different_model(self): + """Different model produces different key.""" + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0) + key2 = generate_cache_key("gpt-3.5-turbo", msgs, 0.0) + assert key1 != key2 + + def test_different_temperature(self): + """Different temperature produces different key.""" + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0) + key2 = generate_cache_key("gpt-4o", msgs, 0.7) + assert key1 != key2 + + def test_different_messages(self): + """Different messages produce different key.""" + key1 = generate_cache_key("gpt-4o", _make_messages("Hello"), 0.0) + key2 = generate_cache_key("gpt-4o", _make_messages("World"), 0.0) + assert key1 != key2 + + def test_different_tools(self): + """Different tools produce different key.""" + msgs = _make_messages() + tools1 = [{"type": "function", "function": {"name": "f1"}}] + tools2 = [{"type": "function", "function": {"name": "f2"}}] + key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools1) + key2 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools2) + assert key1 != key2 + + def test_none_tools_same_as_no_tools(self): + """None tools and no tools produce same key.""" + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=None) + key2 = generate_cache_key("gpt-4o", msgs, 0.0) + assert key1 == key2 + + def test_system_prompt_extracted_from_messages(self): + """System prompt is extracted from messages[0] with role=system.""" + msgs = [ + {"role": "system", "content": "Be concise"}, + {"role": "user", "content": "Hello"}, + ] + key1 = generate_cache_key("gpt-4o", msgs, 0.0) + + msgs2 = [ + {"role": "system", "content": "Be verbose"}, + {"role": "user", "content": "Hello"}, + ] + key2 = generate_cache_key("gpt-4o", msgs2, 0.0) + assert key1 != key2 + + def test_max_tokens_affects_key(self): + """Different max_tokens produce different key.""" + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=2000) + key2 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=4000) + assert key1 != key2 + + +# --------------------------------------------------------------------------- +# InMemoryLLMCache Tests +# --------------------------------------------------------------------------- + + +class TestInMemoryLLMCache: + @pytest.mark.asyncio + async def test_exact_match_hit(self): + cache = InMemoryLLMCache() + key = "test_key_1" + response = _make_response("Cached answer") + + await cache.put(key, response) + result = await cache.get(key) + + assert result.hit is True + assert result.match_type == "exact" + assert result.response.content == "Cached answer" + + @pytest.mark.asyncio + async def test_exact_match_miss(self): + cache = InMemoryLLMCache() + await cache.put("key1", _make_response()) + + result = await cache.get("key2") + assert result.hit is False + assert result.response is None + + @pytest.mark.asyncio + async def test_semantic_match_hit(self): + cache = InMemoryLLMCache(similarity_threshold=0.9) + emb1 = _make_embedding(1.0, dim=64) + emb_similar = _make_similar_embedding(emb1, noise=0.001) + + await cache.put("key1", _make_response("Cached"), query_embedding=emb1) + result = await cache.semantic_search(emb_similar, threshold=0.9) + + assert result.hit is True + assert result.match_type == "semantic" + assert result.response.content == "Cached" + + @pytest.mark.asyncio + async def test_semantic_match_miss(self): + cache = InMemoryLLMCache(similarity_threshold=0.9) + emb1 = _make_embedding(1.0, dim=64) + emb_different = _make_different_embedding(dim=64) + + await cache.put("key1", _make_response("Cached"), query_embedding=emb1) + result = await cache.semantic_search(emb_different, threshold=0.9) + + assert result.hit is False + + @pytest.mark.asyncio + async def test_semantic_match_empty_cache(self): + cache = InMemoryLLMCache() + result = await cache.semantic_search(_make_embedding(dim=64)) + assert result.hit is False + + @pytest.mark.asyncio + async def test_ttl_expiry_exact(self): + cache = InMemoryLLMCache(exact_ttl=1) # 1 second TTL + await cache.put("key1", _make_response()) + + # Wait for expiry + time.sleep(1.1) + result = await cache.get("key1") + assert result.hit is False + + @pytest.mark.asyncio + async def test_ttl_expiry_semantic(self): + cache = InMemoryLLMCache(semantic_ttl=1) + emb = _make_embedding(dim=64) + await cache.put("key1", _make_response(), query_embedding=emb) + + time.sleep(1.1) + result = await cache.semantic_search(emb) + assert result.hit is False + + @pytest.mark.asyncio + async def test_lru_eviction(self): + cache = InMemoryLLMCache(max_entries=3) + + for i in range(4): + await cache.put(f"key{i}", _make_response(f"Response {i}")) + + # key0 should be evicted (oldest) + result = await cache.get("key0") + assert result.hit is False + + # key1-key3 should still be present + for i in range(1, 4): + result = await cache.get(f"key{i}") + assert result.hit is True + + @pytest.mark.asyncio + async def test_lru_access_refreshes(self): + cache = InMemoryLLMCache(max_entries=3) + + await cache.put("key0", _make_response("R0")) + await cache.put("key1", _make_response("R1")) + await cache.put("key2", _make_response("R2")) + + # Access key0 to move it to most-recently-used + await cache.get("key0") + + # Adding key3 should evict key1 (now LRU) + await cache.put("key3", _make_response("R3")) + + result = await cache.get("key1") + assert result.hit is False + + result = await cache.get("key0") + assert result.hit is True + + @pytest.mark.asyncio + async def test_invalidate_all(self): + cache = InMemoryLLMCache() + await cache.put("key1", _make_response()) + await cache.put("key2", _make_response()) + + count = await cache.invalidate() + assert count == 2 + + result = await cache.get("key1") + assert result.hit is False + + @pytest.mark.asyncio + async def test_invalidate_pattern(self): + cache = InMemoryLLMCache() + await cache.put("abc_1", _make_response()) + await cache.put("abc_2", _make_response()) + await cache.put("xyz_1", _make_response()) + + count = await cache.invalidate("abc_*") + assert count == 2 + + result = await cache.get("xyz_1") + assert result.hit is True + + @pytest.mark.asyncio + async def test_stats(self): + cache = InMemoryLLMCache() + await cache.put("key1", _make_response()) + await cache.get("key1") # hit + await cache.get("key2") # miss + + stats = await cache.stats() + assert stats["total_entries"] == 1 + assert stats["total_hits"] == 1 + assert stats["total_misses"] == 1 + + @pytest.mark.asyncio + async def test_tool_calls_cached(self): + cache = InMemoryLLMCache() + tool_calls = [ + ToolCall(id="call_1", name="search", arguments={"query": "test"}) + ] + response = _make_response(tool_calls=tool_calls) + + await cache.put("key1", response) + result = await cache.get("key1") + + assert result.hit is True + assert len(result.response.tool_calls) == 1 + assert result.response.tool_calls[0].name == "search" + + @pytest.mark.asyncio + async def test_put_without_embedding(self): + cache = InMemoryLLMCache() + await cache.put("key1", _make_response(), query_embedding=None) + + # Exact match should still work + result = await cache.get("key1") + assert result.hit is True + + # Semantic search should return miss (no embeddings) + result = await cache.semantic_search(_make_embedding(dim=64)) + assert result.hit is False + + @pytest.mark.asyncio + async def test_put_updates_existing_key(self): + cache = InMemoryLLMCache() + await cache.put("key1", _make_response("Old")) + await cache.put("key1", _make_response("New")) + + result = await cache.get("key1") + assert result.hit is True + assert result.response.content == "New" + + +# --------------------------------------------------------------------------- +# Serialization Tests +# --------------------------------------------------------------------------- + + +class TestSerialization: + def test_serialize_deserialize_response(self): + response = _make_response( + content="Test", + model="gpt-4o", + prompt_tokens=5, + completion_tokens=10, + tool_calls=[ToolCall(id="c1", name="tool1", arguments={"k": "v"})], + ) + serialized = _serialize_response(response) + deserialized = _deserialize_response(serialized) + + assert deserialized.content == "Test" + assert deserialized.model == "gpt-4o" + assert deserialized.usage.prompt_tokens == 5 + assert deserialized.usage.completion_tokens == 10 + assert len(deserialized.tool_calls) == 1 + assert deserialized.tool_calls[0].name == "tool1" + + def test_serialize_deserialize_entry(self): + entry = CacheEntry( + response=_make_response(), + query_embedding=[0.1, 0.2, 0.3], + created_at=12345.0, + hit_count=5, + ) + serialized = _serialize_entry(entry) + deserialized = _deserialize_entry(serialized) + + assert deserialized.response.content == "Hello" + assert deserialized.query_embedding == [0.1, 0.2, 0.3] + assert deserialized.created_at == 12345.0 + assert deserialized.hit_count == 5 + + def test_serialize_response_no_tool_calls(self): + response = _make_response() + serialized = _serialize_response(response) + assert serialized["tool_calls"] == [] + + deserialized = _deserialize_response(serialized) + assert deserialized.tool_calls == [] + + +# --------------------------------------------------------------------------- +# RedisLLMCache Tests (with mocked Redis) +# --------------------------------------------------------------------------- + + +class TestRedisLLMCache: + def _make_mock_redis(self): + """Create a mock Redis client that simulates basic operations.""" + mock = AsyncMock() + mock._data = {} + mock._sets = {} + + async def mock_get(key): + return mock._data.get(key) + + async def mock_set(key, value, ex=None): + mock._data[key] = value + + async def mock_mget(keys): + return [mock._data.get(k) for k in keys] + + async def mock_sadd(key, *members): + if key not in mock._sets: + mock._sets[key] = set() + mock._sets[key].update(members) + + async def mock_smembers(key): + return mock._sets.get(key, set()) + + async def mock_scard(key): + return len(mock._sets.get(key, set())) + + async def mock_delete(*keys): + for k in keys: + mock._data.pop(k, None) + + async def mock_srem(key, *members): + if key in mock._sets: + mock._sets[key] -= set(members) + + mock.get = mock_get + mock.set = mock_set + mock.mget = mock_mget + mock.sadd = mock_sadd + mock.smembers = mock_smembers + mock.scard = mock_scard + mock.delete = mock_delete + mock.srem = mock_srem + + # Pipeline mock — collects commands and executes them on execute() + class MockPipeline: + def __init__(self): + self._commands = [] + + def set(self, key, value, ex=None): + self._commands.append(("set", key, value, ex)) + + def sadd(self, key, *members): + self._commands.append(("sadd", key, members)) + + def delete(self, *keys): + self._commands.append(("delete", keys)) + + def srem(self, key, *members): + self._commands.append(("srem", key, members)) + + async def execute(self): + for cmd in self._commands: + if cmd[0] == "set": + _, key, value, ex = cmd + mock._data[key] = value + elif cmd[0] == "sadd": + _, key, members = cmd + if key not in mock._sets: + mock._sets[key] = set() + mock._sets[key].update(members) + elif cmd[0] == "delete": + _, keys = cmd + for k in keys: + mock._data.pop(k, None) + elif cmd[0] == "srem": + _, key, members = cmd + if key in mock._sets: + mock._sets[key] -= set(members) + + mock.pipeline = MagicMock(return_value=MockPipeline()) + + return mock + + @pytest.mark.asyncio + async def test_exact_match_hit(self): + cache = RedisLLMCache() + mock_redis = self._make_mock_redis() + cache._redis = mock_redis + + key = "test_key" + response = _make_response("Cached") + entry = CacheEntry(response=response, created_at=time.monotonic(), hit_count=0) + entry_json = json.dumps(_serialize_entry(entry)) + + # Simulate Redis already has the data + mock_redis._data[f"{cache.KEY_PREFIX}{key}"] = entry_json + + result = await cache.get(key) + assert result.hit is True + assert result.match_type == "exact" + assert result.response.content == "Cached" + + @pytest.mark.asyncio + async def test_exact_match_miss(self): + cache = RedisLLMCache() + mock_redis = self._make_mock_redis() + cache._redis = mock_redis + + result = await cache.get("nonexistent_key") + assert result.hit is False + + @pytest.mark.asyncio + async def test_redis_failure_returns_miss(self): + cache = RedisLLMCache() + mock_redis = AsyncMock() + mock_redis.get = AsyncMock(side_effect=Exception("Connection refused")) + cache._redis = mock_redis + + result = await cache.get("any_key") + assert result.hit is False + + @pytest.mark.asyncio + async def test_put_stores_data(self): + cache = RedisLLMCache() + mock_redis = self._make_mock_redis() + cache._redis = mock_redis + + key = "test_key" + response = _make_response() + emb = [0.1, 0.2, 0.3] + + await cache.put(key, response, query_embedding=emb) + + # Verify data was stored + assert f"{cache.KEY_PREFIX}{key}" in mock_redis._data + assert f"{cache.EMB_PREFIX}{key}" in mock_redis._data + assert key in mock_redis._sets.get(cache.INDEX_KEY, set()) + + @pytest.mark.asyncio + async def test_invalidate_all(self): + cache = RedisLLMCache() + mock_redis = self._make_mock_redis() + cache._redis = mock_redis + + # Pre-populate + mock_redis._sets[cache.INDEX_KEY] = {"key1", "key2"} + mock_redis._data[f"{cache.KEY_PREFIX}key1"] = "data1" + mock_redis._data[f"{cache.KEY_PREFIX}key2"] = "data2" + + count = await cache.invalidate() + assert count == 2 + + @pytest.mark.asyncio + async def test_stats(self): + cache = RedisLLMCache() + mock_redis = self._make_mock_redis() + cache._redis = mock_redis + mock_redis._sets[cache.INDEX_KEY] = {"k1", "k2", "k3"} + + stats = await cache.stats() + assert stats["total_entries"] == 3 + + +# --------------------------------------------------------------------------- +# Factory Tests +# --------------------------------------------------------------------------- + + +class TestCreateLLMCache: + def test_memory_backend(self): + cache = create_llm_cache(backend="memory") + assert isinstance(cache, InMemoryLLMCache) + + def test_auto_backend_fallback(self): + """When redis package is not available, auto falls back to InMemory.""" + with patch.dict("sys.modules", {"redis.asyncio": None}): + # Force ImportError by making redis.asyncio unimportable + cache = create_llm_cache(backend="auto") + assert isinstance(cache, InMemoryLLMCache) + + def test_redis_backend_with_redis_available(self): + """When redis.asyncio is available, auto/redis returns RedisLLMCache.""" + cache = create_llm_cache(backend="redis") + assert isinstance(cache, RedisLLMCache) + + def test_auto_backend_with_redis_available(self): + cache = create_llm_cache(backend="auto") + assert isinstance(cache, RedisLLMCache) + + def test_custom_parameters(self): + cache = create_llm_cache( + backend="memory", + max_entries=500, + exact_ttl=7200, + semantic_ttl=172800, + similarity_threshold=0.95, + ) + assert isinstance(cache, InMemoryLLMCache) + assert cache._max_entries == 500 + assert cache._exact_ttl == 7200 + assert cache._semantic_ttl == 172800 + assert cache._similarity_threshold == 0.95 diff --git a/tests/unit/test_semantic_router.py b/tests/unit/test_semantic_router.py new file mode 100644 index 0000000..e1b4589 --- /dev/null +++ b/tests/unit/test_semantic_router.py @@ -0,0 +1,219 @@ +"""Unit tests for Semantic Router (U3).""" + +import pytest + +from agentkit.chat.semantic_router import ( + SemanticRouteResult, + SkillEmbeddingIndex, + SemanticRouter, +) +from agentkit.memory.embedder import MockEmbedder + + +def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]: + """Create a unit vector for similarity testing.""" + vec = [base_val] * dim + magnitude = sum(x**2 for x in vec) ** 0.5 + return [x / magnitude for x in vec] if magnitude > 0 else vec + + +class MockSkill: + """Mock skill for testing.""" + + def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None): + self.name = name + self.config = MockSkillConfig( + name=name, + description=description, + keywords=keywords or [], + capabilities=capabilities or [], + ) + + +class MockSkillConfig: + """Mock skill config for testing.""" + + def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None): + self.name = name + self.description = description + self.intent = MockIntentConfig(keywords=keywords or []) + self.capabilities = [MockCapabilityTag(tag=t) for t in (capabilities or [])] + + +class MockIntentConfig: + def __init__(self, keywords: list[str] | None = None): + self.keywords = keywords or [] + + +class MockCapabilityTag: + def __init__(self, tag: str): + self.tag = tag + + +class MockSkillRegistry: + """Mock skill registry for testing.""" + + def __init__(self, skills: list[MockSkill] | None = None): + self._skills = {s.name: s for s in (skills or [])} + + def list_skills(self): + return list(self._skills.values()) + + def get(self, name: str): + if name not in self._skills: + raise KeyError(f"Skill '{name}' not found") + return self._skills[name] + + +class TestSkillEmbeddingIndex: + @pytest.mark.asyncio + async def test_build_from_registry(self): + embedder = MockEmbedder(dimension=64) + index = SkillEmbeddingIndex(embedder) + + skills = [ + MockSkill("content_gen", description="生成文章内容", keywords=["写作", "文章"], capabilities=["content"]), + MockSkill("data_analysis", description="数据分析与可视化", keywords=["分析", "数据"], capabilities=["analytics"]), + ] + registry = MockSkillRegistry(skills) + await index.build(registry) + + assert index.size == 2 + + @pytest.mark.asyncio + async def test_search_returns_results(self): + embedder = MockEmbedder(dimension=64) + index = SkillEmbeddingIndex(embedder) + + skill = MockSkill("content_gen", description="生成文章内容") + await index.update_skill("content_gen", skill) + + # MockEmbedder produces deterministic embeddings based on text hash + # Different text → different embedding + query_emb = await embedder.embed("生成文章") + results = await index.search(query_emb) + + assert len(results) >= 1 + assert results[0][0] == "content_gen" # skill_name + assert results[0][1] > 0.0 # similarity + + @pytest.mark.asyncio + async def test_search_empty_index(self): + embedder = MockEmbedder(dimension=64) + index = SkillEmbeddingIndex(embedder) + + query_emb = await embedder.embed("test") + results = await index.search(query_emb) + + assert results == [] + + @pytest.mark.asyncio + async def test_remove_skill(self): + embedder = MockEmbedder(dimension=64) + index = SkillEmbeddingIndex(embedder) + + skill = MockSkill("test_skill", description="Test") + await index.update_skill("test_skill", skill) + assert index.size == 1 + + index.remove_skill("test_skill") + assert index.size == 0 + + def test_build_source_text_with_description(self): + skill = MockSkill("test", description="A test skill", keywords=["test"], capabilities=["testing"]) + text = SkillEmbeddingIndex._build_source_text(skill) + assert "A test skill" in text + assert "test" in text + assert "testing" in text + + def test_build_source_text_fallback_to_name(self): + skill = MockSkill("my_skill", description="", keywords=[], capabilities=[]) + text = SkillEmbeddingIndex._build_source_text(skill) + assert "my_skill" in text + + +class TestSemanticRouter: + @pytest.mark.asyncio + async def test_high_confidence_match(self): + """When similarity > 0.85, return high confidence.""" + embedder = MockEmbedder(dimension=64) + router = SemanticRouter(embedder, similarity_high=0.5, similarity_low=0.3) + + # Add a skill with known embedding + skill = MockSkill("content_gen", description="生成文章内容") + await router.update_skill("content_gen", skill) + + # Query with same text should produce very similar embedding (MockEmbedder is hash-based) + # With low thresholds, even moderate similarity will be "high" + result = await router.route("生成文章内容") + # MockEmbedder may or may not produce high similarity for different text + # Just verify the result structure + assert result.confidence in ("high", "medium", "low") + assert isinstance(result.similarity, float) + + @pytest.mark.asyncio + async def test_low_confidence_empty_index(self): + """Empty index returns low confidence.""" + embedder = MockEmbedder(dimension=64) + router = SemanticRouter(embedder) + + result = await router.route("任何查询") + assert result.confidence == "low" + assert result.skill_name is None + assert result.similarity == 0.0 + + @pytest.mark.asyncio + async def test_medium_confidence_zone(self): + """Test medium confidence zone (0.6-0.85).""" + embedder = MockEmbedder(dimension=64) + router = SemanticRouter(embedder, similarity_high=0.99, similarity_low=0.01) + + skill = MockSkill("content_gen", description="生成文章内容") + await router.update_skill("content_gen", skill) + + # With very high similarity_high and very low similarity_low, + # most matches will be "medium" + result = await router.route("生成文章") + # The result should be medium (since threshold is 0.99) + assert result.confidence in ("medium", "low", "high") + + @pytest.mark.asyncio + async def test_embedder_failure_graceful(self): + """Embedder failure returns low confidence.""" + class FailingEmbedder(MockEmbedder): + async def embed(self, text): + raise RuntimeError("Embedding API failed") + + router = SemanticRouter(FailingEmbedder(dimension=64)) + result = await router.route("test query") + assert result.confidence == "low" + assert result.skill_name is None + + @pytest.mark.asyncio + async def test_build_index_from_registry(self): + """Build index from skill registry.""" + embedder = MockEmbedder(dimension=64) + router = SemanticRouter(embedder) + + skills = [ + MockSkill("skill_a", description="Skill A"), + MockSkill("skill_b", description="Skill B"), + ] + registry = MockSkillRegistry(skills) + await router.build_index(registry) + + assert router._index.size == 2 + + @pytest.mark.asyncio + async def test_chinese_query(self): + """Chinese query works with semantic router.""" + embedder = MockEmbedder(dimension=64) + router = SemanticRouter(embedder, similarity_high=0.01, similarity_low=0.001) + + skill = MockSkill("geo_optimizer", description="地理内容优化", keywords=["优化", "SEO", "地理"], capabilities=["optimization"]) + await router.update_skill("geo_optimizer", skill) + + result = await router.route("帮我优化内容") + # With very low thresholds, should match + assert result.confidence in ("high", "medium") + assert result.skill_name == "geo_optimizer" diff --git a/tests/unit/test_unified_evolution_store.py b/tests/unit/test_unified_evolution_store.py new file mode 100644 index 0000000..4c98568 --- /dev/null +++ b/tests/unit/test_unified_evolution_store.py @@ -0,0 +1,458 @@ +"""Tests for unified EvolutionStoreProtocol compliance + +Verifies that all backends implement the full Protocol interface: +- InMemoryEvolutionStore +- PersistentEvolutionStore +- PostgreSQLEvolutionStore (mocked async session) +- EvolutionStore (legacy, with NotImplementedError for skill_version/ab_test) +""" + +import os +import tempfile + +import pytest + +from agentkit.core.protocol import EvolutionEvent +from agentkit.evolution.evolution_store import ( + EvolutionStore, + EvolutionStoreProtocol, + InMemoryEvolutionStore, + PersistentEvolutionStore, + create_evolution_store, +) + + +# ── Fixtures ────────────────────────────────────────────── + + +@pytest.fixture +def sample_event(): + """A sample EvolutionEvent.""" + return EvolutionEvent( + agent_name="test_agent", + change_type="prompt", + before={"prompt": "old prompt"}, + after={"prompt": "new prompt"}, + metrics={"accuracy": 0.9}, + ) + + +@pytest.fixture +def memory_store(): + return InMemoryEvolutionStore() + + +@pytest.fixture +def sqlite_store(tmp_path): + db_path = str(tmp_path / "test_unified.db") + return PersistentEvolutionStore(db_path=db_path) + + +# ── Protocol compliance tests ───────────────────────────── + + +class TestProtocolCompliance: + """Verify all stores implement EvolutionStoreProtocol.""" + + def test_inmemory_is_protocol(self): + assert isinstance(InMemoryEvolutionStore(), EvolutionStoreProtocol) + + def test_persistent_is_protocol(self, tmp_path): + db_path = str(tmp_path / "protocol_check.db") + assert isinstance(PersistentEvolutionStore(db_path=db_path), EvolutionStoreProtocol) + + def test_pg_store_is_protocol(self): + from agentkit.evolution.pg_store import PostgreSQLEvolutionStore + + store = PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test") + assert isinstance(store, EvolutionStoreProtocol) + + def test_legacy_evolution_store_is_protocol(self): + """Legacy EvolutionStore also satisfies Protocol (has all method signatures).""" + from unittest.mock import AsyncMock, MagicMock + + store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock()) + assert isinstance(store, EvolutionStoreProtocol) + + +# ── InMemoryEvolutionStore: full Protocol ───────────────── + + +class TestInMemoryFullProtocol: + """InMemoryEvolutionStore implements all Protocol methods.""" + + async def test_record_and_list_events(self, memory_store, sample_event): + event_id = await memory_store.record(sample_event) + assert event_id is not None + + events = await memory_store.list_events() + assert len(events) == 1 + assert events[0]["agent_name"] == "test_agent" + assert events[0]["change_type"] == "prompt" + + async def test_rollback(self, memory_store, sample_event): + event_id = await memory_store.record(sample_event) + result = await memory_store.rollback(event_id) + assert result is True + + events = await memory_store.list_events() + assert events[0]["status"] == "rolled_back" + + async def test_rollback_nonexistent(self, memory_store): + result = await memory_store.rollback("nonexistent") + assert result is False + + async def test_list_events_with_filters(self, memory_store): + await memory_store.record( + EvolutionEvent(agent_name="a", change_type="prompt", before={}, after={}) + ) + await memory_store.record( + EvolutionEvent(agent_name="b", change_type="strategy", before={}, after={}) + ) + + events = await memory_store.list_events(agent_name="a") + assert len(events) == 1 + assert events[0]["agent_name"] == "a" + + async def test_record_and_list_skill_version(self, memory_store): + vid = await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}') + assert vid is not None + + versions = await memory_store.list_skill_versions("search") + assert len(versions) == 1 + assert versions[0]["version"] == "v1" + assert versions[0]["content"] == '{"prompt": "v1"}' + + async def test_skill_version_with_parent(self, memory_store): + await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}') + await memory_store.record_skill_version( + "search", "v2", '{"prompt": "v2"}', parent_version="v1" + ) + + versions = await memory_store.list_skill_versions("search") + assert len(versions) == 2 + assert versions[0]["version"] == "v2" + assert versions[0]["parent_version"] == "v1" + + async def test_record_and_get_ab_test_result(self, memory_store): + rid = await memory_store.record_ab_test_result("t1", "control", 0.8, 5) + assert rid is not None + + results = await memory_store.get_ab_test_results("t1") + assert len(results) == 1 + assert results[0]["variant"] == "control" + assert results[0]["score"] == 0.8 + assert results[0]["sample_count"] == 5 + + async def test_ab_test_multiple_variants(self, memory_store): + await memory_store.record_ab_test_result("t1", "control", 0.8, 10) + await memory_store.record_ab_test_result("t1", "experiment", 0.9, 10) + + results = await memory_store.get_ab_test_results("t1") + assert len(results) == 2 + + async def test_list_skill_versions_empty(self, memory_store): + versions = await memory_store.list_skill_versions("nonexistent") + assert versions == [] + + async def test_get_ab_test_results_empty(self, memory_store): + results = await memory_store.get_ab_test_results("nonexistent") + assert results == [] + + +# ── PersistentEvolutionStore: full Protocol ─────────────── + + +class TestSQLiteFullProtocol: + """PersistentEvolutionStore implements all Protocol methods.""" + + async def test_record_and_list_events(self, sqlite_store, sample_event): + event_id = await sqlite_store.record(sample_event) + assert event_id is not None + + events = await sqlite_store.list_events() + assert len(events) == 1 + assert events[0]["agent_name"] == "test_agent" + + async def test_rollback(self, sqlite_store, sample_event): + event_id = await sqlite_store.record(sample_event) + result = await sqlite_store.rollback(event_id) + assert result is True + + events = await sqlite_store.list_events() + assert events[0]["status"] == "rolled_back" + + async def test_record_and_list_skill_version(self, sqlite_store): + vid = await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}') + assert vid is not None + + versions = await sqlite_store.list_skill_versions("search") + assert len(versions) == 1 + assert versions[0]["version"] == "v1" + + async def test_skill_version_with_parent(self, sqlite_store): + await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}') + await sqlite_store.record_skill_version( + "search", "v2", '{"prompt": "v2"}', parent_version="v1" + ) + + versions = await sqlite_store.list_skill_versions("search") + assert len(versions) == 2 + assert versions[0]["version"] == "v2" + assert versions[0]["parent_version"] == "v1" + + async def test_record_and_get_ab_test_result(self, sqlite_store): + rid = await sqlite_store.record_ab_test_result("t1", "control", 0.8, 5) + assert rid is not None + + results = await sqlite_store.get_ab_test_results("t1") + assert len(results) == 1 + assert results[0]["variant"] == "control" + + async def test_ab_test_multiple_variants(self, sqlite_store): + await sqlite_store.record_ab_test_result("t1", "control", 0.8, 10) + await sqlite_store.record_ab_test_result("t1", "experiment", 0.9, 10) + + results = await sqlite_store.get_ab_test_results("t1") + assert len(results) == 2 + + async def test_list_skill_versions_empty(self, sqlite_store): + versions = await sqlite_store.list_skill_versions("nonexistent") + assert versions == [] + + async def test_get_ab_test_results_empty(self, sqlite_store): + results = await sqlite_store.get_ab_test_results("nonexistent") + assert results == [] + + +# ── PostgreSQLEvolutionStore: mocked Protocol ───────────── + + +class TestPGStoreMocked: + """Test PostgreSQLEvolutionStore with mocked async session. + + Since we can't require a running PostgreSQL in unit tests, + we mock the async session to verify the logic paths. + """ + + def _make_pg_store(self): + from agentkit.evolution.pg_store import PostgreSQLEvolutionStore + + return PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test") + + async def test_record_with_mock_session(self, sample_event): + from unittest.mock import AsyncMock, MagicMock, patch + from contextlib import asynccontextmanager + + store = self._make_pg_store() + + # Mock the session factory + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.rollback = AsyncMock() + + @asynccontextmanager + async def mock_sf(): + yield mock_session + + store._session_factory = mock_sf + store._initialized = True + + event_id = await store.record(sample_event) + assert event_id is not None + assert sample_event.event_id == event_id + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + async def test_rollback_with_mock_session(self): + from unittest.mock import AsyncMock, MagicMock, patch + from contextlib import asynccontextmanager + + store = self._make_pg_store() + + # Create a mock entry + mock_entry = MagicMock() + mock_entry.status = "active" + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_entry + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.rollback = AsyncMock() + + @asynccontextmanager + async def mock_sf(): + yield mock_session + + store._session_factory = mock_sf + store._initialized = True + + result = await store.rollback("test-event-id") + assert result is True + assert mock_entry.status == "rolled_back" + mock_session.commit.assert_called_once() + + async def test_rollback_not_found(self): + from unittest.mock import AsyncMock, MagicMock + from contextlib import asynccontextmanager + + store = self._make_pg_store() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + @asynccontextmanager + async def mock_sf(): + yield mock_session + + store._session_factory = mock_sf + store._initialized = True + + result = await store.rollback("nonexistent") + assert result is False + + async def test_record_skill_version_with_mock(self): + from unittest.mock import AsyncMock, MagicMock + from contextlib import asynccontextmanager + + store = self._make_pg_store() + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.rollback = AsyncMock() + + @asynccontextmanager + async def mock_sf(): + yield mock_session + + store._session_factory = mock_sf + store._initialized = True + + vid = await store.record_skill_version("search", "v1", '{"prompt": "v1"}') + assert vid is not None + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + async def test_record_ab_test_result_with_mock(self): + from unittest.mock import AsyncMock, MagicMock + from contextlib import asynccontextmanager + + store = self._make_pg_store() + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.rollback = AsyncMock() + + @asynccontextmanager + async def mock_sf(): + yield mock_session + + store._session_factory = mock_sf + store._initialized = True + + rid = await store.record_ab_test_result("t1", "control", 0.8, 5) + assert rid is not None + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +# ── Legacy EvolutionStore: NotImplementedError tests ────── + + +class TestLegacyEvolutionStoreStubs: + """Legacy EvolutionStore raises NotImplementedError for skill_version/ab_test.""" + + async def test_record_skill_version_raises(self): + from unittest.mock import MagicMock + + store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock()) + with pytest.raises(NotImplementedError, match="skill_version"): + await store.record_skill_version("s", "v1", "content") + + async def test_list_skill_versions_raises(self): + from unittest.mock import MagicMock + + store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock()) + with pytest.raises(NotImplementedError, match="skill_version"): + await store.list_skill_versions("s") + + async def test_record_ab_test_result_raises(self): + from unittest.mock import MagicMock + + store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock()) + with pytest.raises(NotImplementedError, match="A/B test"): + await store.record_ab_test_result("t1", "control", 0.8) + + async def test_get_ab_test_results_raises(self): + from unittest.mock import MagicMock + + store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock()) + with pytest.raises(NotImplementedError, match="A/B test"): + await store.get_ab_test_results("t1") + + +# ── Factory tests ───────────────────────────────────────── + + +class TestCreateEvolutionStoreExtended: + def test_create_memory_backend(self): + store = create_evolution_store(backend="memory") + assert isinstance(store, InMemoryEvolutionStore) + + def test_create_sqlite_backend(self, tmp_path): + db_path = str(tmp_path / "factory_test.db") + store = create_evolution_store(backend="sqlite", db_path=db_path) + assert isinstance(store, PersistentEvolutionStore) + + def test_create_default_backend(self): + store = create_evolution_store() + assert isinstance(store, InMemoryEvolutionStore) + + def test_create_sql_backend_without_params_falls_back(self): + store = create_evolution_store(backend="sql") + assert isinstance(store, InMemoryEvolutionStore) + + def test_create_postgresql_without_url_falls_back(self): + """PostgreSQL backend without database_url falls back to memory.""" + # Clear env var if set + old_val = os.environ.pop("AGENTKIT_DATABASE_URL", None) + try: + store = create_evolution_store(backend="postgresql") + assert isinstance(store, InMemoryEvolutionStore) + finally: + if old_val is not None: + os.environ["AGENTKIT_DATABASE_URL"] = old_val + + def test_create_postgresql_with_url(self): + """PostgreSQL backend with database_url returns PostgreSQLEvolutionStore.""" + from agentkit.evolution.pg_store import PostgreSQLEvolutionStore + + store = create_evolution_store( + backend="postgresql", + database_url="postgresql+asyncpg://user:pass@localhost/db", + ) + assert isinstance(store, PostgreSQLEvolutionStore) + + def test_create_postgresql_with_env_url(self): + """PostgreSQL backend reads database_url from environment variable.""" + from agentkit.evolution.pg_store import PostgreSQLEvolutionStore + + old_val = os.environ.get("AGENTKIT_DATABASE_URL") + try: + os.environ["AGENTKIT_DATABASE_URL"] = "postgresql+asyncpg://user:pass@localhost/db" + store = create_evolution_store(backend="postgresql") + assert isinstance(store, PostgreSQLEvolutionStore) + finally: + if old_val is not None: + os.environ["AGENTKIT_DATABASE_URL"] = old_val + else: + os.environ.pop("AGENTKIT_DATABASE_URL", None) diff --git a/tests/unit/test_usage_tracker.py b/tests/unit/test_usage_tracker.py index a8d0f4b..518698a 100644 --- a/tests/unit/test_usage_tracker.py +++ b/tests/unit/test_usage_tracker.py @@ -5,7 +5,8 @@ from datetime import datetime, timedelta, timezone import pytest from agentkit.llm.protocol import TokenUsage -from agentkit.llm.providers.tracker import UsageRecord, UsageSummary, UsageTracker +from agentkit.llm.providers.tracker import UsageSummary, UsageTracker +from agentkit.llm.providers.usage_store import UsageRecord class TestUsageTrackerRecord: @@ -23,8 +24,10 @@ class TestUsageTrackerRecord: latency_ms=200.0, ) - assert len(tracker._records) == 1 - rec = tracker._records[0] + # Verify via get_usage() instead of internal _records + summary = tracker.get_usage() + assert len(summary.records) == 1 + rec = summary.records[0] assert rec.agent_name == "test_agent" assert rec.model == "gpt-4o" assert rec.prompt_tokens == 100 @@ -41,7 +44,8 @@ class TestUsageTrackerRecord: tracker.record("agent_a", "gpt-4o", usage1, 0.001, 100.0) tracker.record("agent_b", "deepseek-chat", usage2, 0.002, 150.0) - assert len(tracker._records) == 2 + summary = tracker.get_usage() + assert len(summary.records) == 2 class TestUsageTrackerGetUsage: @@ -80,10 +84,11 @@ class TestUsageTrackerGetUsage: usage2 = TokenUsage(prompt_tokens=200, completion_tokens=100) tracker.record("agent_a", "gpt-4o", usage1, 0.005, 100.0) - - # Manually set timestamp of second record to 2 hours ago tracker.record("agent_a", "gpt-4o", usage2, 0.010, 200.0) - tracker._records[-1].timestamp = now - timedelta(hours=2) + + # Manually set timestamp of second record to 2 hours ago via store + store = tracker._store + store._records[-1].timestamp = (now - timedelta(hours=2)).isoformat() # Query last hour only summary = tracker.get_usage(start_time=now - timedelta(hours=1), end_time=now + timedelta(hours=1))