refactor(llm+memory+client): remove Any from type signatures

Eliminate 172 Any usages across llm/, memory/, client/ via:
- TypeAlias (MetadataValue, MetadataDict, RAGSearchResult, etc.)
- object for arbitrary dict/value types
- TYPE_CHECKING Protocol for Redis/Quota/RAG/Graph services
- TYPE_CHECKING import + string annotations for forward refs
- Remove unused Any imports (18 F401 fixed)

Tests: 253 passed (llm 21 failures are pre-existing litellm env issue)
ruff: All checks passed
This commit is contained in:
chiguyong 2026-07-01 02:03:51 +08:00
parent aa6367ff9f
commit 34a89c4873
37 changed files with 905 additions and 601 deletions

View File

@ -34,12 +34,19 @@ import os
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable
from typing import Callable, TypeAlias
import httpx
logger = logging.getLogger(__name__)
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
# 服务器返回的 skills/workflows 列表元素是 dictmodel_dump/to_dict 输出),
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
# ── Defaults ──────────────────────────────────────────────────────────
@ -100,8 +107,8 @@ class ConfigSync:
# In-memory cache (mirrors the SQLite cache for fast access)
self._version: str | None = None
self._skills: list[dict[str, Any]] = []
self._workflows: list[dict[str, Any]] = []
self._skills: list[SkillConfigDict] = []
self._workflows: list[WorkflowConfigDict] = []
self._last_synced_at: str | None = None
# ── Lifecycle ─────────────────────────────────────────────────
@ -232,15 +239,15 @@ class ConfigSync:
"""Return the current cached config version hash."""
return self._version
def get_skills(self) -> list[dict[str, Any]]:
def get_skills(self) -> list[SkillConfigDict]:
"""Return the cached skill configs."""
return list(self._skills)
def get_workflows(self) -> list[dict[str, Any]]:
def get_workflows(self) -> list[WorkflowConfigDict]:
"""Return the cached workflow configs."""
return list(self._workflows)
def get_all(self) -> dict[str, Any]:
def get_all(self) -> SyncedConfigPayload:
"""Return all cached configs as a single dict."""
return {
"version": self._version,
@ -249,14 +256,14 @@ class ConfigSync:
"synced_at": self._last_synced_at,
}
def get_skill(self, name: str) -> dict[str, Any] | None:
def get_skill(self, name: str) -> SkillConfigDict | None:
"""Return a single skill config by name, or ``None``."""
for skill in self._skills:
if skill.get("name") == name:
return skill
return None
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None:
def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
"""Return a single workflow config by ID, or ``None``."""
for wf in self._workflows:
if wf.get("workflow_id") == workflow_id:
@ -281,7 +288,7 @@ class ConfigSync:
conn.executescript(_CACHE_SCHEMA)
conn.commit()
def _save_to_cache(self, data: dict[str, Any]) -> None:
def _save_to_cache(self, data: SyncedConfigPayload) -> None:
"""Save the synced configs to the local SQLite cache."""
now = datetime.now(timezone.utc).isoformat()
with sqlite3.connect(str(self.cache_db_path)) as conn:

View File

@ -14,7 +14,7 @@ import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity
@ -25,6 +25,52 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的第三方对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _RedisLike(Protocol):
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
async def get(self, key: str) -> bytes | str | None: ...
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
async def smembers(self, key: str) -> set[bytes | str]: ...
async def sadd(self, name: str, *values: str) -> int: ...
async def srem(self, name: str, *values: str) -> int: ...
async def scard(self, name: str) -> int: ...
async def delete(self, *names: str) -> int: ...
def pipeline(self) -> "_RedisPipelineLike": ...
class _RedisPipelineLike(Protocol):
"""Redis pipeline 最小契约。"""
def get(self, key: str) -> "_RedisPipelineLike": ...
def set(
self, key: str, value: bytes | str, ex: int | None = None
) -> "_RedisPipelineLike": ...
def delete(self, *names: str) -> "_RedisPipelineLike": ...
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
async def execute(self) -> list[object]: ...
# ---------------------------------------------------------------------------
# Data Classes
# ---------------------------------------------------------------------------
@ -328,7 +374,7 @@ class RedisLLMCache:
self._semantic_ttl = semantic_ttl
self._similarity_threshold = similarity_threshold
self._max_entries_to_scan = max_entries_to_scan
self._redis: Any = None
self._redis: _RedisLike | None = None
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
self._degraded = False # True if Redis is unreachable
@ -691,7 +737,7 @@ class LitellmCacheManager:
def __init__(self, config: LitellmCacheConfig):
self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例
self._cache_instance: object | None = None # litellm.caching.Cache 实例
self._hits = 0
self._misses = 0
@ -709,7 +755,7 @@ class LitellmCacheManager:
litellm.cache = None
self._cache_instance = None
def _create_cache_instance(self) -> Any:
def _create_cache_instance(self) -> object:
"""根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试RedisSemanticCache RedisCache InMemoryCache

View File

@ -2,14 +2,13 @@
import hashlib
import json
from typing import Any
def generate_cache_key(
model: str,
messages: list[dict[str, str]],
temperature: float,
tools: list[dict[str, Any]] | None = None,
tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto",
max_tokens: int = 2000,
user_id: str | None = None,

View File

@ -3,12 +3,12 @@
import json
import logging
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
if TYPE_CHECKING:
from agentkit.channels.secrets import SecretsStore
from agentkit.channels.secrets import SecretEntry, SecretsStore
logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class ProviderConfig:
api_key: str
base_url: str
models: dict[str, dict[str, Any]] = field(default_factory=dict)
models: dict[str, dict[str, object]] = field(default_factory=dict)
type: str = "openai" # "openai" | "anthropic" | "gemini"
max_tokens: int = 4096 # Anthropic: default max_tokens
timeout: float = 120.0 # Anthropic: request timeout
@ -168,18 +168,18 @@ class ProviderConfig:
return f"llm:provider:{self.type}:api_key"
@staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str:
def _encode_secret_entry(entry: object, key: str) -> str:
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
# entry 是 SecretEntry pydantic 模型,有 model_dump()
if hasattr(entry, "model_dump"):
data = entry.model_dump()
data = entry.model_dump() # type: ignore[attr-defined]
else:
data = dict(entry)
data = dict(entry) # type: ignore[call-overload]
data["key"] = key
return json.dumps(data)
@staticmethod
def _decode_secret_entry(encoded: str) -> Any:
def _decode_secret_entry(encoded: str) -> "SecretEntry":
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
from agentkit.channels.secrets import SecretEntry

View File

@ -5,7 +5,7 @@ import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Protocol
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import llm_token_histogram
if TYPE_CHECKING:
from agentkit.llm.cache import LitellmCacheManager
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _QuotaServiceLike(Protocol):
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
async def is_model_allowed(
self, db: Path, department_id: str, model: str
) -> tuple[bool, str]: ...
async def check_quota(
self,
db: Path,
department_id: str,
quota_type: str,
period: str,
current: float,
) -> tuple[bool, str]: ...
async def get_quota(
self, db: Path, department_id: str, quota_type: str, period: str
) -> dict[str, object] | None: ...
class QuotaExceededError(Exception):
"""Raised when a department's LLM quota is exceeded.
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
department_id: str,
quota_type: str,
period: str,
limit: Any,
current: Any,
limit: object,
current: object,
) -> None:
self.department_id = department_id
self.quota_type = quota_type
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
class LLMGateway:
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
self._providers: dict[str, LLMProvider] = {}
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
self._config = config or LLMConfig()
# Cache (U17 — LiteLLM 缓存管理器opt-in默认禁用)
self._cache_manager: Any = None # LitellmCacheManager | None
self._cache_manager: "LitellmCacheManager | None" = None
if self._config.cache and self._config.cache.enabled:
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
@ -601,7 +632,7 @@ class LLMGateway:
async def _check_quota_value(
self,
quota_service: Any,
quota_service: _QuotaServiceLike,
db: Path,
dept_id: str,
period: str,

View File

@ -27,12 +27,11 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
流程
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
store = SecretsStore() # master key 从 env 加载
async def _run() -> dict[str, dict[str, Any]]:
report: dict[str, dict[str, Any]] = {}
async def _run() -> dict[str, dict[str, object]]:
report: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items():
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
report[name] = {"status": "skipped", "source": pconf.api_key_source}
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
report = asyncio.run(_run())
# 写回 YAML更新 llm.providers 段,保留其它段
providers_out: dict[str, dict[str, Any]] = {}
providers_out: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items():
entry: dict[str, Any] = {
entry: dict[str, object] = {
"type": pconf.type,
"base_url": pconf.base_url,
"models": pconf.models,

View File

@ -2,7 +2,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass
@ -23,7 +22,7 @@ class ToolCall:
id: str
name: str
arguments: dict[str, Any]
arguments: dict[str, object]
@dataclass
@ -32,7 +31,7 @@ class LLMRequest:
messages: list[dict[str, str]]
model: str
tools: list[dict[str, Any]] | None = None
tools: list[dict[str, object]] | None = None
tool_choice: str = "auto"
temperature: float = 0.7
max_tokens: int = 2000
@ -42,13 +41,13 @@ class LLMRequest:
self,
messages: list[dict[str, str]],
model: str,
tools: list[dict[str, Any]] | None = None,
tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto",
temperature: float = 0.7,
max_tokens: int = 2000,
timeout: float | None = None,
cache: dict[str, Any] | None = None,
**kwargs: Any,
cache: dict[str, object] | None = None,
**kwargs: object,
):
self.messages = messages
self.model = model
@ -59,7 +58,7 @@ class LLMRequest:
self.timeout = timeout
self._extra = kwargs
# U17 — LiteLLM cache 参数cache_key 或 no-cache透传到 litellm.acompletion
self._cache: dict[str, Any] | None = cache
self._cache: dict[str, object] | None = cache
@dataclass

View File

@ -3,7 +3,6 @@
import json
import logging
import time
from typing import Any
import httpx
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
"content-type": "application/json",
}
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]:
def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Anthropic 格式
Returns:
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
- None: system 消息
"""
system_prompt: str | list[dict[str, Any]] | None = None
anthropic_messages: list[dict[str, Any]] = []
system_prompt: str | list[dict[str, object]] | None = None
anthropic_messages: list[dict[str, object]] = []
for msg in messages:
role = msg.get("role", "")
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
# 检查是否有 tool_calls (OpenAI 格式)
tool_calls = msg.get("tool_calls")
if tool_calls:
blocks: list[dict[str, Any]] = []
blocks: list[dict[str, object]] = []
# 如果有文本内容,先添加文本块
if content:
blocks.append({"type": "text", "text": content})
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {"raw": arguments}
blocks.append({
blocks.append(
{
"type": "tool_use",
"id": tc.get("id", ""),
"name": func.get("name", ""),
"input": arguments,
})
}
)
anthropic_messages.append({"role": "assistant", "content": blocks})
else:
anthropic_messages.append({
anthropic_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": content}],
})
}
)
continue
if role == "user":
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
if msg.get("tool_call_id"):
tool_result_blocks: list[dict[str, Any]] = []
tool_result_blocks: list[dict[str, object]] = []
tool_content = msg.get("content", "")
# tool_result 的 content 可以是字符串或内容块列表
if isinstance(tool_content, str):
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
else:
tool_result_blocks.append({"type": "text", "text": str(tool_content)})
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{
"content": [
{
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": tool_result_blocks,
}],
})
}
],
}
)
else:
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{"type": "text", "text": content}],
})
}
)
continue
if role == "tool":
# OpenAI 格式中独立的 tool 消息
tool_content = msg.get("content", "")
if isinstance(tool_content, str):
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}]
result_content: list[dict[str, object]] | str = [
{"type": "text", "text": tool_content}
]
elif isinstance(tool_content, list):
result_content = tool_content
else:
result_content = [{"type": "text", "text": str(tool_content)}]
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{
"content": [
{
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": result_content,
}],
})
}
],
}
)
return system_prompt, anthropic_messages
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Anthropic tool 格式"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
anthropic_tools.append({
anthropic_tools.append(
{
"name": func.get("name", ""),
"description": func.get("description", ""),
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
})
"input_schema": func.get(
"parameters", {"type": "object", "properties": {}}
),
}
)
return anthropic_tools
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
if tool_choice == "auto":
return {"type": "auto"}
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
return {"type": "tool", "name": tool_choice}
return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Anthropic 响应转换为 LLMResponse"""
content_blocks = data.get("content", [])
text_parts: list[str] = []
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
if block_type == "text":
text_parts.append(block.get("text", ""))
elif block_type == "tool_use":
tool_calls.append(ToolCall(
tool_calls.append(
ToolCall(
id=block.get("id", ""),
name=block.get("name", ""),
arguments=block.get("input", {}),
))
)
)
usage_data = data.get("usage", {})
usage = TokenUsage(
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"model": request.model,
"max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages,
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"model": request.model,
"max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages,
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks."""
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
accumulated_tool_calls: dict[str, dict[str, Any]] = {}
accumulated_tool_calls: dict[str, dict[str, object]] = {}
current_tool_id: str | None = None
current_tool_name: str | None = None
current_tool_input_json: str = ""
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
# Finalize current tool call if any
if current_tool_id is not None:
try:
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {}
arguments = (
json.loads(current_tool_input_json) if current_tool_input_json else {}
)
except json.JSONDecodeError:
arguments = {"raw": current_tool_input_json}
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
error_msg = error_info.get("message", "Stream error")
raise LLMProviderError("anthropic", error_msg)
def get_model_info(self) -> dict[str, Any]:
def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息"""
return {
"provider": "anthropic",

View File

@ -8,7 +8,6 @@ API火山引擎 OpenAI 兼容接口
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
api_key: str,
base_url: str = DOUBAO_DEFAULT_BASE_URL,
default_model: str = "doubao-pro-32k",
**kwargs: Any,
**kwargs: object,
):
super().__init__(
api_key=api_key,

View File

@ -3,7 +3,6 @@
import json
import logging
import time
from typing import Any
import httpx
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Gemini 格式
Returns:
(system_instruction, contents)
"""
system_instruction: dict[str, Any] | None = None
contents: list[dict[str, Any]] = []
system_instruction: dict[str, object] | None = None
contents: list[dict[str, object]] = []
for msg in messages:
role = msg.get("role", "")
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
tool_name = parsed.get("name", "")
except (json.JSONDecodeError, AttributeError):
pass
contents.append({
contents.append(
{
"role": "user",
"parts": [{
"parts": [
{
"functionResponse": {
"name": tool_name,
"response": {
"content": content,
},
},
}],
})
}
],
}
)
else:
contents.append({
contents.append(
{
"role": "user",
"parts": [{"text": content}],
})
}
)
continue
if role == "assistant":
tool_calls = msg.get("tool_calls")
if tool_calls:
parts: list[dict[str, Any]] = []
parts: list[dict[str, object]] = []
if content:
parts.append({"text": content})
for tc in tool_calls:
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {"raw": arguments}
parts.append({
parts.append(
{
"functionCall": {
"name": func.get("name", ""),
"args": arguments,
},
})
}
)
contents.append({"role": "model", "parts": parts})
else:
contents.append({
contents.append(
{
"role": "model",
"parts": [{"text": content}],
})
}
)
continue
if role == "tool":
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
tool_name = msg.get("name", "")
tool_content = msg.get("content", "")
contents.append({
contents.append(
{
"role": "user",
"parts": [{
"parts": [
{
"functionResponse": {
"name": tool_name,
"response": {
"content": tool_content,
},
},
}],
})
}
],
}
)
return system_instruction, contents
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Gemini functionDeclarations"""
declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
declarations.append({
declarations.append(
{
"name": func.get("name", ""),
"description": func.get("description", ""),
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
})
}
)
if not declarations:
return []
return [{"functionDeclarations": declarations}]
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
if tool_choice == "auto":
return {"functionCallingConfig": {"mode": "AUTO"}}
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
return {"functionCallingConfig": {"mode": "NONE"}}
return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Gemini 响应转换为 LLMResponse"""
candidates = data.get("candidates", [])
text_parts: list[str] = []
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
text_parts.append(part["text"])
elif "functionCall" in part:
fc = part["functionCall"]
tool_calls.append(ToolCall(
tool_calls.append(
ToolCall(
id=f"call_{tool_call_index}",
name=fc.get("name", ""),
arguments=fc.get("args", {}),
))
)
)
tool_call_index += 1
usage_metadata = data.get("usageMetadata", {})
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"contents": contents,
"generationConfig": {
"temperature": request.temperature,
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"contents": contents,
"generationConfig": {
"temperature": request.temperature,
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks."""
accumulated_tool_calls: list[dict[str, Any]] = []
accumulated_tool_calls: list[dict[str, object]] = []
model = request.model or self._model
async for line in response.aiter_lines():
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
)
elif "functionCall" in part:
fc = part["functionCall"]
accumulated_tool_calls.append({
accumulated_tool_calls.append(
{
"id": f"call_{len(accumulated_tool_calls)}",
"name": fc.get("name", ""),
"arguments": fc.get("args", {}),
})
}
)
# Check for finish reason
finish_reason = candidates[0].get("finishReason", "")
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
)
accumulated_tool_calls = []
def get_model_info(self) -> dict[str, Any]:
def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息"""
return {
"provider": "gemini",

View File

@ -26,8 +26,7 @@ import inspect
import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
from collections.abc import AsyncGenerator, Iterable
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import (
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
api_key: str,
base_url: str | None = None,
provider_type: str = "openai",
**default_kwargs: Any,
**default_kwargs: object,
) -> None:
self._model_prefix = model_prefix
self._api_key = api_key
self._base_url = base_url or None # 空字符串视作未设置
self._provider_type = provider_type
self._default_kwargs: dict[str, Any] = dict(default_kwargs)
self._default_kwargs: dict[str, object] = dict(default_kwargs)
async def chat(self, request: LLMRequest) -> LLMResponse:
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
kwargs = self._build_kwargs(request, stream=True)
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
accumulated_tool_calls: dict[int, dict[str, object]] = {}
final_usage: TokenUsage | None = None
final_model: str = request.model
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
# 内部辅助
# ------------------------------------------------------------------
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
kwargs: dict[str, Any] = {
kwargs: dict[str, object] = {
"model": f"{self._model_prefix}{request.model}",
"messages": request.messages,
"temperature": request.temperature,
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
def _parse_response(
self,
response: Any,
response: object,
request_model: str,
latency_ms: float,
) -> LLMResponse:
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
def _parse_stream_chunk(
self,
chunk: Any,
chunk: object,
request_model: str,
accumulated_tool_calls: dict[int, dict[str, Any]],
accumulated_tool_calls: dict[int, dict[str, object]],
) -> StreamChunk:
"""解析单个流式 chunk非 final。累计 tool_calls 到传入字典。"""
choices = getattr(chunk, "choices", None) or []
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
def _finalize_tool_calls(
self,
accumulated: dict[int, dict[str, Any]],
accumulated: dict[int, dict[str, object]],
) -> list[ToolCall]:
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
tool_calls: list[ToolCall] = []
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
# ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
"""解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。"""
result: list[ToolCall] = []
for tc in raw_tool_calls:
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
return result
def _parse_usage(usage_obj: Any) -> TokenUsage:
def _parse_usage(usage_obj: object) -> TokenUsage:
"""解析 usage 对象OpenAI CompletionUsage 或 dict"""
prompt = getattr(usage_obj, "prompt_tokens", None)
completion = getattr(usage_obj, "completion_tokens", None)
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
def _accumulate_stream_tool_calls(
raw_tool_calls: Any,
accumulated: dict[int, dict[str, Any]],
raw_tool_calls: Iterable[object],
accumulated: dict[int, dict[str, object]],
) -> None:
"""累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。
@ -364,7 +363,7 @@ def create_litellm_provider(
provider_type: str,
api_key: str,
base_url: str | None = None,
**kwargs: Any,
**kwargs: object,
) -> LitellmProvider:
"""根据 provider_type 创建 LitellmProvider 实例。

View File

@ -18,7 +18,7 @@ import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from agentkit.llm.protocol import TokenUsage
@ -294,8 +294,8 @@ class RedisUsageStore:
def __init__(self, redis_url: str = "redis://localhost:6379"):
self._redis_url = redis_url
self._redis: Any = None
self._sync_redis: Any = None
self._redis: object | None = None
self._sync_redis: object | None = None
self._fallback: InMemoryUsageStore | None = None
self._degraded = False
self._health_check_task: asyncio.Task[None] | None = None
@ -687,7 +687,7 @@ class RedisUsageStore:
@staticmethod
def _read_list(
r: Any,
r: object,
list_key: str,
start: datetime,
end: datetime,

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import time
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
secret_key: str | None = None,
base_url: str = WENXIN_DEFAULT_BASE_URL,
default_model: str = "ernie-4.5-turbo-128k",
**kwargs: Any,
**kwargs: object,
):
# If AK/SK provided, use token-based auth
self._access_key = access_key

View File

@ -8,7 +8,6 @@ API腾讯云 OpenAI 兼容接口
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
base_url: str = YUANBAO_DEFAULT_BASE_URL,
default_model: str = "hunyuan-turbos-latest",
enable_enhancement: bool = False,
**kwargs: Any,
**kwargs: object,
):
self._enable_enhancement = enable_enhancement
super().__init__(

View File

@ -3,7 +3,6 @@
import json
import logging
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any
import httpx
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
"Content-Type": "application/json",
}
def _build_payload(self, request: LLMRequest) -> dict[str, Any]:
def _build_payload(self, request: LLMRequest) -> dict[str, object]:
"""Convert LLMRequest to server API payload."""
return {
"messages": request.messages,
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
return str(body["error"])
return str(body)
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse:
def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
"""Parse server response JSON into an LLMResponse."""
usage_data = data.get("usage") or {}
usage = TokenUsage(
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
latency_ms=data.get("latency_ms", 0.0),
)
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk:
def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
"""Parse a single SSE data payload into a StreamChunk."""
usage: TokenUsage | None = None
usage_data = data.get("usage")
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
if response.status_code == 502:
await response.aread()
detail = self._extract_error_detail(response)
raise LLMProviderError(
"remote", f"Server LLM gateway error: {detail}"
)
raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
if response.status_code != 200:
await response.aread()
raise LLMProviderError(

View File

@ -5,7 +5,7 @@ import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable
from typing import Callable
from agentkit.core.exceptions import LLMProviderError
@ -20,9 +20,7 @@ class RetryConfig:
base_delay: float = 1.0
max_delay: float = 30.0
exponential_base: float = 2.0
retryable_status_codes: set[int] = field(
default_factory=lambda: {429, 500, 502, 503, 529}
)
retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
class CircuitState(Enum):
@ -69,7 +67,7 @@ class RetryPolicy:
def __init__(self, config: RetryConfig | None = None):
self._config = config or RetryConfig()
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn with retry on retryable errors."""
last_error: Exception | None = None
@ -84,7 +82,7 @@ class RetryPolicy:
raise
delay = min(
self._config.base_delay * (self._config.exponential_base ** attempt),
self._config.base_delay * (self._config.exponential_base**attempt),
self._config.max_delay,
)
logger.warning(
@ -142,7 +140,7 @@ class CircuitBreaker:
f"after {self._failure_count} failures"
)
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn through the circuit breaker."""
current_state = self.state
@ -158,6 +156,6 @@ class CircuitBreaker:
result = await fn(*args, **kwargs)
self._on_success()
return result
except Exception as e:
except Exception:
self._on_failure()
raise

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any
import httpx
@ -95,8 +94,7 @@ class KBAdapter(ABC):
async def delete_by_id(self, id: str) -> bool:
"""按文档 ID 删除(子类可覆盖)"""
logger.warning(
f"{self.__class__.__name__} does not support delete_by_id; "
f"id '{id}' skipped"
f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
)
return False
@ -127,8 +125,7 @@ class KBAdapter(ABC):
async def get_document(self, doc_id: str) -> Document | None:
"""按 ID 获取单个文档(子类可覆盖)"""
logger.warning(
f"{self.__class__.__name__} does not support get_document; "
f"doc_id '{doc_id}' not found"
f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
)
return None
@ -156,5 +153,5 @@ class KBAdapter(ABC):
async def __aenter__(self) -> KBAdapter:
return self
async def __aexit__(self, *args: Any) -> None:
async def __aexit__(self, *args: object) -> None:
await self.close()

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import logging
from typing import Any
import httpx
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
)
self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._username = username
self._api_token = api_token
self._space_keys = space_keys or []
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
"""创建 Confluence API HTTP 客户端"""
import base64
credentials = base64.b64encode(
f"{self._username}:{self._api_token}".encode()
).decode()
credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
return httpx.AsyncClient(
base_url=self._base_url,
headers={
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
space_filter = " OR ".join(
f'space = "{_escape_cql(key)}"' for key in self._space_keys
)
cql = f'{cql} AND ({space_filter})'
cql = f"{cql} AND ({space_filter})"
resp = await client.get(
"/rest/api/content/search",
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "")
# Strip HTML tags for plain text content
import re
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
results.append(
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"Confluence search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"Confluence search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "")
import re
content = re.sub(r"<[^>]+>", "", body) if body else ""
return Document(
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
source_type="confluence",
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.error(f"Confluence list_sources error: {e}")
return [

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import logging
import time
from typing import Any
from typing import TypeAlias
import httpx
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__)
# 飞书搜索请求 payloadsearch_key/page_size/wiki_space_ids — 值为 str|int|list[str]。
FeishuSearchPayload: TypeAlias = dict[str, object]
class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
self._app_secret = app_secret
self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._space_ids = space_ids or []
self._access_token: str | None = None
self._token_expiry: float = 0.0
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
self._client = None
return self._access_token
else:
logger.error(
f"Feishu auth failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
)
logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
return None
except Exception as e:
logger.error(f"Feishu auth error: {e}")
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client()
try:
payload: dict[str, Any] = {
payload: FeishuSearchPayload = {
"search_key": query,
"page_size": top_k,
}
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
if data.get("code") != 0:
logger.error(
f"Feishu search failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
)
return []
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"Feishu search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"Feishu search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client()
try:
resp = await client.get(
f"/wiki/v2/spaces/get_node",
"/wiki/v2/spaces/get_node",
params={"token": doc_id},
)
resp.raise_for_status()
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
source_type="feishu",
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.error(f"Feishu list_sources error: {e}")
return [

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import logging
from typing import Any
import httpx
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
)
self._endpoint_url = endpoint_url.rstrip("/")
if not is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
)
self._auth_config = auth_config or {}
self._extra_headers = headers or {}
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic":
import base64
username = self._auth_config.get("username", "")
password = self._auth_config.get("password", "")
if username and password:
credentials = base64.b64encode(
f"{username}:{password}".encode()
).decode()
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
elif auth_type == "api_key":
key_name = self._auth_config.get("header_name", "X-API-Key")
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"GenericHTTP search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"GenericHTTP search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"GenericHTTP ingest HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"GenericHTTP ingest HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
document_count=item.get("document_count", 0),
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")

View File

@ -3,19 +3,29 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from typing import TypeAlias
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
MetadataValue: TypeAlias = str | int | float | bool | None
MetadataDict: TypeAlias = dict[str, MetadataValue]
MemoryValue: TypeAlias = (
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
)
@dataclass
class MemoryItem:
"""记忆条目"""
key: str
value: Any
metadata: dict[str, Any] = field(default_factory=dict)
value: object
metadata: MetadataDict = field(default_factory=dict)
score: float = 1.0
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def to_dict(self) -> dict:
def to_dict(self) -> dict[str, object]:
return {
"key": self.key,
"value": self.value,
@ -35,7 +45,7 @@ class Memory(ABC):
"""
@abstractmethod
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储记忆"""
...
@ -45,7 +55,9 @@ class Memory(ABC):
...
@abstractmethod
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索"""
...
@ -54,7 +66,7 @@ class Memory(ABC):
"""删除记忆"""
...
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
"""批量存储"""
for key, value, metadata in items:
await self.store(key, value, metadata)

View File

@ -11,10 +11,18 @@ import logging
import re
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__)
# 分块元数据source_doc/position/char_count/chunking_strategy/heading/heading_level
# — 全部为原始标量str/int
ChunkMetadata: TypeAlias = MetadataDict
# _split_by_headings 返回的节段结构。
SectionInfo: TypeAlias = dict[str, str | int]
@dataclass
class Chunk:
@ -22,7 +30,7 @@ class Chunk:
chunk_id: str
content: str
metadata: dict[str, Any] = field(default_factory=dict)
metadata: ChunkMetadata = field(default_factory=dict)
def __post_init__(self) -> None:
if "source_doc" not in self.metadata:
@ -30,7 +38,7 @@ class Chunk:
if "position" not in self.metadata:
self.metadata["position"] = 0
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return {
"chunk_id": self.chunk_id,
"content": self.content,
@ -57,7 +65,9 @@ class TextChunker:
separator: 优先分割符
"""
if chunk_overlap >= chunk_size:
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})")
raise ValueError(
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
)
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._separator = separator
@ -66,7 +76,7 @@ class TextChunker:
self,
text: str,
source_doc_id: str = "",
metadata: dict[str, Any] | None = None,
metadata: ChunkMetadata | None = None,
) -> list[Chunk]:
"""将文本分块
@ -96,11 +106,13 @@ class TextChunker:
chunk_meta = dict(base_meta)
chunk_meta["position"] = i
chunk_meta["char_count"] = len(chunk_text)
chunks.append(Chunk(
chunks.append(
Chunk(
chunk_id=str(uuid.uuid4()),
content=chunk_text,
metadata=chunk_meta,
))
)
)
return chunks
@ -142,7 +154,9 @@ class TextChunker:
overlap_text[overlap_start:], segments
)
current = overlap_segments
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1)
current_len = sum(len(s) for s in current) + len(self._separator) * max(
0, len(current) - 1
)
current.append(segment)
current_len += seg_len + len(self._separator)
@ -214,7 +228,7 @@ class StructuralChunker:
self,
text: str,
source_doc_id: str = "",
metadata: dict[str, Any] | None = None,
metadata: ChunkMetadata | None = None,
) -> list[Chunk]:
"""将文本按结构分块
@ -266,23 +280,25 @@ class StructuralChunker:
chunk_meta["heading"] = heading
chunk_meta["heading_level"] = level
chunk_meta["char_count"] = len(content)
chunks.append(Chunk(
chunks.append(
Chunk(
chunk_id=str(uuid.uuid4()),
content=content,
metadata=chunk_meta,
))
)
)
position += 1
return chunks
def _split_by_headings(self, text: str) -> list[dict[str, Any]]:
def _split_by_headings(self, text: str) -> list[SectionInfo]:
"""按标题分割 Markdown 文本
Returns:
列表每项包含 heading, content, level
"""
lines = text.split("\n")
sections: list[dict[str, Any]] = []
sections: list[SectionInfo] = []
current_heading = ""
current_level = 0
current_lines: list[str] = []
@ -296,11 +312,13 @@ class StructuralChunker:
if current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append({
sections.append(
{
"heading": current_heading,
"content": content,
"level": current_level,
})
}
)
# 开始新节
current_heading = match.group(2).strip()
@ -313,18 +331,22 @@ class StructuralChunker:
if current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append({
sections.append(
{
"heading": current_heading,
"content": content,
"level": current_level,
})
}
)
# 如果没有标题结构,整体作为一个块
if not sections:
sections.append({
sections.append(
{
"heading": "",
"content": text.strip(),
"level": 0,
})
}
)
return sections

View File

@ -9,10 +9,14 @@ from __future__ import annotations
import hashlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING
from agentkit.memory.base import MetadataDict
from agentkit.memory.embedder import EmbeddingCache
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
@ -24,7 +28,7 @@ class ContextualChunk:
context_prefix: str
enhanced_content: str
chunk_index: int
metadata: dict[str, Any]
metadata: MetadataDict
@property
def content(self) -> str:
@ -65,7 +69,7 @@ class ContextualChunker:
def __init__(
self,
llm_gateway: Any = None,
llm_gateway: LLMGateway | None = None,
cache: EmbeddingCache | None = None,
batch_size: int = 8,
max_context_length: int = 200,
@ -90,7 +94,7 @@ class ContextualChunker:
self,
document: str,
chunks: list[str],
metadata: dict[str, Any] | None = None,
metadata: MetadataDict | None = None,
) -> list[ContextualChunk]:
"""为文档块添加上下文前缀
@ -134,7 +138,7 @@ class ContextualChunker:
document: str,
chunks: list[str],
start_index: int,
metadata: dict[str, Any] | None,
metadata: MetadataDict | None,
) -> list[ContextualChunk]:
"""处理一批文档块"""
results: list[ContextualChunk] = []

View File

@ -12,7 +12,9 @@ import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__)
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
MAX_ROWS_PER_SHEET = 10_000
MAX_CELL_CHARS = 10_000
# 文档元数据source/format/parser/page_count/table_count/sheet_count/row_count/
# heading_count/created_at/title/truncated — 全部为原始标量。
DocumentMetadata: TypeAlias = MetadataDict
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
@dataclass
class Document:
@ -31,7 +38,7 @@ class Document:
doc_id: str
title: str
content: str
metadata: dict[str, Any] = field(default_factory=dict)
metadata: DocumentMetadata = field(default_factory=dict)
def __post_init__(self) -> None:
if "source" not in self.metadata:
@ -43,7 +50,7 @@ class Document:
if "created_at" not in self.metadata:
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return {
"doc_id": self.doc_id,
"title": self.title,
@ -136,12 +143,14 @@ class DocumentLoader:
parser = parsers.get(doc_format)
if parser is None:
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text")
logger.warning(
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
)
parser = self._parse_text
text, extra_meta = parser(content, filename)
metadata: dict[str, Any] = {
metadata: DocumentMetadata = {
"source": filename,
"format": doc_format,
"created_at": datetime.now(timezone.utc).isoformat(),
@ -159,7 +168,7 @@ class DocumentLoader:
metadata=metadata,
)
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
"""解析 PDF 文件
优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本
@ -215,7 +224,7 @@ class DocumentLoader:
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
return self._parse_text(content, filename)
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Word 文件
使用 python-docx回退到纯文本
@ -259,7 +268,7 @@ class DocumentLoader:
logger.warning(f"python-docx parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Excel 文件
使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格
@ -313,7 +322,7 @@ class DocumentLoader:
finally:
wb.close()
text = "\n".join(sections).strip()
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "openpyxl",
"sheet_count": sheet_count,
"row_count": total_rows,
@ -328,7 +337,7 @@ class DocumentLoader:
logger.warning(f"openpyxl parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
"""解析 Markdown 文件
使用 mistune如果可用否则直接读取文本
@ -347,7 +356,7 @@ class DocumentLoader:
title = line_stripped.lstrip("#").strip()
break
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "markdown",
}
if title:
@ -362,7 +371,7 @@ class DocumentLoader:
return text, meta
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_html(self, content: bytes, filename: str) -> ParseResult:
"""解析 HTML 文件
使用 BeautifulSoup 提取文本回退到纯文本
@ -388,7 +397,7 @@ class DocumentLoader:
if soup.title and soup.title.string:
title = soup.title.string.strip()
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "beautifulsoup",
}
if title:
@ -402,7 +411,7 @@ class DocumentLoader:
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_text(self, content: bytes, filename: str) -> ParseResult:
"""解析纯文本文件"""
try:
text = content.decode("utf-8")

View File

@ -1,12 +1,17 @@
"""Embedder 接口与实现 - 文本向量化"""
from __future__ import annotations
import hashlib
import logging
import os
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__)
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
self._model = model
self._base_url = base_url
self._dimension = 1536 # text-embedding-3-small 默认维度
self._client: Any = None
self._client: httpx.AsyncClient | None = None
self._cache = cache
def _get_client(self):
def _get_client(self) -> httpx.AsyncClient:
"""Lazily create and reuse a single httpx.AsyncClient."""
if self._client is None:
import httpx
self._client = httpx.AsyncClient(timeout=30.0)
return self._client

View File

@ -1,17 +1,22 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
from __future__ import annotations
import json
import logging
import math
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
def __init__(
self,
session_factory: Any,
episodic_model: Any,
session_factory: object,
episodic_model: object,
embedder: Embedder | None = None,
decay_rate: float = 0.01,
alpha: float = 0.7,
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
self._pgvector_enabled = pgvector_enabled
self._table_name = table_name
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储任务经验"""
async with self._session_factory() as db:
try:
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
embedding = None
if self._embedder:
if isinstance(value, dict):
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
text = (
value.get("output_summary", "")
or value.get("input_summary", "")
or json.dumps(value, ensure_ascii=False)[:500]
)
else:
text = str(value)
embedding = await self._embedder.embed(text)
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
logger.error(f"Failed to retrieve episodic memory: {e}")
return None
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
async def _retrieve_pgvector(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""使用 pgvector ``<=>`` 算符检索最相似条目"""
sql = text(
f"SELECT * FROM {self._table_name} "
f"ORDER BY embedding <=> :query_vec "
f"LIMIT :lim"
)
sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
row = result.mappings().first()
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
created_at=row.get("created_at") or datetime.now(timezone.utc),
)
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
async def _retrieve_client_side(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""客户端 O(N) cosine similarity 检索(回退路径)"""
Model = self._episodic_model
from sqlalchemy import select
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
created_at=best_item.created_at or datetime.now(timezone.utc),
)
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
async def search(
self,
query: str,
top_k: int = 5,
filters: MetadataDict | None = None,
search_multiplier: int = 5,
) -> list[MemoryItem]:
"""语义检索相似历史案例
Args:
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
async def _search_pgvector(
self,
db: Any,
db: AsyncSession,
query: str,
top_k: int,
filters: dict[str, Any] | None,
filters: MetadataDict | None,
search_multiplier: int,
) -> list[MemoryItem]:
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
fetch_limit = top_k * search_multiplier
where_clauses = []
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit}
params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
filters = filters or {}
if filters.get("agent_name"):
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
items = []
for row in rows:
row_embedding = row.get("embedding")
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0
age_hours = (
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
if row.get("created_at")
else 0
)
decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (row.get("quality_score") or 0.5) * decay
@ -266,7 +285,8 @@ class EpisodicMemory(Memory):
else:
score = time_decay_score
items.append(MemoryItem(
items.append(
MemoryItem(
key=str(row.get("id", "")),
value={
"input_summary": row.get("input_summary", ""),
@ -278,21 +298,24 @@ class EpisodicMemory(Memory):
metadata={
"agent_name": row.get("agent_name", ""),
"task_type": row.get("task_type", ""),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"created_at": row["created_at"].isoformat()
if row.get("created_at")
else None,
},
score=score,
created_at=row.get("created_at") or datetime.now(timezone.utc),
))
)
)
items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k]
async def _search_client_side(
self,
db: Any,
db: AsyncSession,
query: str,
top_k: int,
filters: dict[str, Any] | None,
filters: MetadataDict | None,
search_multiplier: int,
) -> list[MemoryItem]:
"""客户端 O(N) cosine similarity 检索(回退路径)"""
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
filters = filters or {}
from sqlalchemy import select
stmt = select(Model)
if filters.get("agent_name"):
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
# 计算得分并构建 MemoryItem
items = []
for entry in entries:
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
age_hours = (
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
if entry.created_at
else 0
)
decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (entry.quality_score or 0.5) * decay
@ -333,7 +361,8 @@ class EpisodicMemory(Memory):
else:
score = time_decay_score
items.append(MemoryItem(
items.append(
MemoryItem(
key=str(entry.id),
value={
"input_summary": entry.input_summary,
@ -349,14 +378,17 @@ class EpisodicMemory(Memory):
},
score=score,
created_at=entry.created_at or datetime.now(timezone.utc),
))
)
)
items.sort(key=lambda x: x.score, reverse=True)
if len(items) < top_k:
logger.warning(
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
len(items), top_k, search_multiplier,
len(items),
top_k,
search_multiplier,
)
return items[:top_k]
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
"""删除指定经验"""
async with self._session_factory() as db:
try:
from sqlalchemy import select, delete as sql_delete
from sqlalchemy import delete as sql_delete
import uuid
Model = self._episodic_model
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))

View File

@ -3,13 +3,26 @@
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
"""
from __future__ import annotations
import logging
from typing import Any
from typing import TYPE_CHECKING, TypeAlias
import httpx
from agentkit.memory.base import MetadataDict
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
# 标准化检索结果id/content/score/source/document_id/document_title/
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
RAGSearchResult: TypeAlias = dict[str, object]
# ingest() 写入的文档负载title/content/source_type/metadata。
RAGIngestPayload: TypeAlias = dict[str, object]
class HttpRAGService:
"""HTTP 客户端,调用业务系统的知识库检索 API
@ -39,7 +52,7 @@ class HttpRAGService:
knowledge_base_ids: list[str] | None = None,
timeout: int = 30,
contextual_chunking: bool = False,
llm_gateway: Any = None,
llm_gateway: LLMGateway | None = None,
):
"""
Args:
@ -74,7 +87,7 @@ class HttpRAGService:
query: str,
knowledge_base_ids: list[str] | None = None,
top_k: int = 5,
) -> list[dict[str, Any]]:
) -> list[RAGSearchResult]:
"""语义检索知识库
Args:
@ -113,7 +126,8 @@ class HttpRAGService:
normalized = []
for r in results:
if isinstance(r, dict):
normalized.append({
normalized.append(
{
"id": r.get("chunk_id", r.get("id", "")),
"content": r.get("content", ""),
"score": float(r.get("score", 0.0)),
@ -121,11 +135,14 @@ class HttpRAGService:
"document_id": r.get("document_id", ""),
"document_title": r.get("document_title", ""),
"metadata": r.get("metadata", {}),
})
}
)
return normalized
except httpx.HTTPStatusError as e:
logger.error(f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}")
logger.error(
f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except httpx.RequestError as e:
logger.error(f"RAG search request error: {e}")
@ -141,7 +158,7 @@ class HttpRAGService:
top_k: int = 5,
use_rerank: bool = True,
use_compression: bool = False,
) -> list[dict[str, Any]]:
) -> list[RAGSearchResult]:
"""增强语义检索知识库(支持 rerank 和 compression
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口
@ -169,7 +186,7 @@ class HttpRAGService:
}
client = self._get_client()
all_results: list[dict[str, Any]] = []
all_results: list[RAGSearchResult] = []
for kb_id in kb_ids:
try:
@ -189,7 +206,8 @@ class HttpRAGService:
# 标准化
for r in results:
if isinstance(r, dict):
all_results.append({
all_results.append(
{
"id": r.get("chunk_id", r.get("id", "")),
"content": r.get("content", ""),
"score": float(r.get("score", 0.0)),
@ -198,19 +216,17 @@ class HttpRAGService:
"document_title": r.get("document_title", ""),
"knowledge_base_id": kb_id,
"metadata": r.get("metadata", {}),
})
}
)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
# This KB doesn't support enhanced search — fall back to
# standard search for THIS KB only, not all KBs.
logger.info(
f"Enhanced search not available for KB {kb_id}, "
f"using standard search"
)
std_result = await self.search(
query, knowledge_base_ids=[kb_id], top_k=top_k
f"Enhanced search not available for KB {kb_id}, using standard search"
)
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
all_results.extend(std_result)
else:
logger.error(
@ -232,9 +248,9 @@ class HttpRAGService:
async def ingest(
self,
key: str,
value: Any,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
value: object,
metadata: MetadataDict | None = None,
) -> dict[str, object] | None:
"""写入文档到知识库(可选操作)
When contextual_chunking is enabled and llm_gateway is configured,
@ -308,5 +324,5 @@ class HttpRAGService:
async def __aenter__(self) -> "HttpRAGService":
return self
async def __aexit__(self, *args: Any) -> None:
async def __aexit__(self, *args: object) -> None:
await self.close()

View File

@ -11,7 +11,11 @@ from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Protocol, runtime_checkable
from typing import Protocol, TypeAlias, runtime_checkable
from agentkit.memory.base import MetadataDict
KBMetadata: TypeAlias = MetadataDict
@dataclass
@ -22,7 +26,7 @@ class Document:
content: str
title: str = ""
source_id: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
metadata: KBMetadata = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@ -34,7 +38,7 @@ class QueryResult:
source_id: str
source_name: str
score: float
metadata: dict[str, Any] = field(default_factory=dict)
metadata: KBMetadata = field(default_factory=dict)
doc_id: str = ""
title: str = ""

View File

@ -11,25 +11,32 @@ from __future__ import annotations
import json
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from typing import TYPE_CHECKING, TypeAlias
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import (
Document,
KnowledgeBase,
QueryResult,
SourceInfo,
)
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
logger = logging.getLogger(__name__)
# InMemoryLocalRAGService 内部存储的文档元信息结构。
# 字段title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
InMemoryDocInfo: TypeAlias = dict[str, object]
# 内部 chunk 存储结构content/embedding/metadata/source_doc_id。
InMemoryChunkInfo: TypeAlias = dict[str, object]
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
"""将 document_loader.Document 转换为 knowledge_base.Document"""
@ -53,7 +60,7 @@ class LocalRAGService:
def __init__(
self,
session_factory: Any,
session_factory: object,
embedder: Embedder,
chunk_size: int = 1000,
chunk_overlap: int = 200,
@ -75,10 +82,14 @@ class LocalRAGService:
self._chunk_overlap = chunk_overlap
self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
raise ValueError(
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
)
self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表
@ -136,9 +147,7 @@ class LocalRAGService:
try:
from sqlalchemy import text as sql_text
sql = sql_text(
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
)
sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
await db.execute(sql, {"doc_id": id})
await db.commit()
return True
@ -171,20 +180,15 @@ class LocalRAGService:
sources = []
for row in rows:
meta = {}
if row.get("doc_metadata"):
try:
meta = json.loads(row["doc_metadata"])
except (json.JSONDecodeError, TypeError):
pass
sources.append(SourceInfo(
sources.append(
SourceInfo(
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
source_type=row.get("doc_format", "local"),
document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
))
)
)
return sources
except Exception as e:
logger.error(f"Failed to list sources: {e}")
@ -271,7 +275,7 @@ class LocalRAGService:
async def _query_pgvector(
self,
db: Any,
db: AsyncSession,
query_embedding: list[float],
top_k: int,
) -> list[QueryResult]:
@ -286,10 +290,13 @@ class LocalRAGService:
f"LIMIT :lim"
)
result = await db.execute(sql, {
result = await db.execute(
sql,
{
"query_vec": str(query_embedding),
"lim": top_k,
})
},
)
rows = result.mappings().all()
results = []
@ -306,7 +313,8 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError):
pass
results.append(QueryResult(
results.append(
QueryResult(
content=row["content"],
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
@ -314,13 +322,14 @@ class LocalRAGService:
metadata=chunk_meta,
doc_id=row["source_doc_id"],
title=row.get("source_title", ""),
))
)
)
return results
async def _query_client_side(
self,
db: Any,
db: AsyncSession,
query_embedding: list[float],
top_k: int,
) -> list[QueryResult]:
@ -363,7 +372,8 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError):
pass
candidates.append(QueryResult(
candidates.append(
QueryResult(
content=row["content"],
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
@ -371,7 +381,8 @@ class LocalRAGService:
metadata=chunk_meta,
doc_id=row["source_doc_id"],
title=row.get("source_title", ""),
))
)
)
candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k]
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
"""
self._embedder = embedder
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
# 内存存储
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[
str, InMemoryDocInfo
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表
@ -459,7 +474,8 @@ class InMemoryLocalRAGService:
source_doc_id = chunk_data["source_doc_id"]
doc_info = self._documents.get(source_doc_id, {})
candidates.append(QueryResult(
candidates.append(
QueryResult(
content=chunk_data["content"],
source_id=source_doc_id,
source_name=doc_info.get("title", ""),
@ -467,7 +483,8 @@ class InMemoryLocalRAGService:
metadata=chunk_data.get("metadata", {}),
doc_id=source_doc_id,
title=doc_info.get("title", ""),
))
)
)
candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k]
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
"""列出已摄取的文档"""
sources = []
for doc_id, doc_info in self._documents.items():
sources.append(SourceInfo(
sources.append(
SourceInfo(
source_id=doc_id,
source_name=doc_info["title"],
source_type=doc_info.get("format", "local"),
document_count=len(doc_info.get("chunk_ids", [])),
last_updated=doc_info.get("created_at"),
))
)
)
return sources
async def health_check(self) -> bool:

View File

@ -13,7 +13,6 @@ import asyncio
import hashlib
import logging
from dataclasses import replace
from typing import Any
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
@ -186,15 +185,13 @@ class MultiSourceRetriever:
Returns:
所有源的检索结果列表已应用权重
"""
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
try:
results = await kb.query(query, top_k=top_k)
# 应用权重
weight = (weights or {}).get(name, 1.0)
return [
replace(r, score=r.score * weight, source_name=name)
for r in results
]
return [replace(r, score=r.score * weight, source_name=name) for r in results]
except Exception as e:
logger.error(f"Query failed for source '{name}': {e}")
return []

View File

@ -7,10 +7,10 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable
from typing import Callable
class MemoryFile:
@ -26,8 +26,9 @@ class MemoryFile:
"""
def __init__(self, path: Path, char_budget: int | None = None,
protected_sections: set[str] | None = None):
def __init__(
self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
):
self.path = Path(path)
self.char_budget = char_budget
self._protected_sections = protected_sections or set()
@ -138,7 +139,7 @@ class MemoryFile:
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
name = match.group(1).strip()
start = match.start()
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE)
next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
if next_match:
end = match.end() + next_match.start()
else:
@ -146,7 +147,7 @@ class MemoryFile:
sections.append((name, start, end))
if not sections:
return content[:self.char_budget]
return content[: self.char_budget]
# 保持原始顺序,标记每个 section 是否受保护
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
@ -222,8 +223,9 @@ class MemoryStore:
"""
def __init__(self, base_dir: Path | str | None = None,
on_change: Callable[[str], None] | None = None):
def __init__(
self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
):
if base_dir is None:
base_dir = Path.home() / ".agentkit"
self.base_dir = Path(base_dir)
@ -238,7 +240,9 @@ class MemoryStore:
protected_sections={"版本", "更新历史"},
)
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
self._memory = MemoryFile(
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
)
self._daily_dir = self.base_dir / "memories" / "daily"
self._daily_dir.mkdir(parents=True, exist_ok=True)
@ -376,4 +380,5 @@ class MemoryStore:
self._on_change(new_prompt)
except Exception:
import logging
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)

View File

@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
"""
_FILLER_WORDS_CN: list[str] = [
"帮我", "", "一下", "分析", "看看", "告诉我", "想知道", "请问",
"帮我",
"",
"一下",
"分析",
"看看",
"告诉我",
"想知道",
"请问",
]
_FILLER_WORDS_EN: list[str] = [
"please", "can you", "help me", "could you", "i want to", "i need to",
"please",
"can you",
"help me",
"could you",
"i want to",
"i need to",
]
def __init__(
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
self._synonyms = synonyms or {}
self._max_sub_queries = max_sub_queries
# Pre-compile filler patterns
self._filler_patterns_cn = [
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
]
self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
self._filler_patterns_en = [
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
]
@ -166,7 +176,9 @@ def create_query_transformer(
"""工厂函数:根据策略创建查询改写器"""
if strategy == "llm":
if llm_gateway is None:
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp")
logger.warning(
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
)
return NoOpQueryTransformer()
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
elif strategy == "rule":

View File

@ -7,11 +7,11 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from agentkit.memory.base import MemoryItem
from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
RetrievalEvaluation,
)
if TYPE_CHECKING:
# 避免与 retriever.py 形成运行时循环导入。
from agentkit.memory.retriever import MemoryRetriever
logger = logging.getLogger(__name__)
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
def __init__(
self,
retriever: Any, # MemoryRetriever
retriever: MemoryRetriever,
scorer: RelevanceScorer | None = None,
query_transformer: QueryTransformerBase | None = None,
max_retries: int = 3,
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
) -> RAGLoopResult:
"""执行带自纠正的检索
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
while retry_count <= self._max_retries:
# RETRIEVE
items = await self._retriever.retrieve(
current_query, top_k=top_k, token_budget=token_budget,
filters=filters, _skip_correction=True,
current_query,
top_k=top_k,
token_budget=token_budget,
filters=filters,
_skip_correction=True,
)
# EVALUATE
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
# CORRECT — rewrite query and retry
retry_count += 1
if retry_count <= self._max_retries:
current_query = await self._rewrite_query(
query, current_query, evaluation
)
current_query = await self._rewrite_query(query, current_query, evaluation)
continue
# DEGRADE — exceeded max retries
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
# Degraded result: filter to relevant items and mark low confidence
relevant_items = [
s.item
for s in evaluation.scores
if s.verdict != RelevanceVerdict.INCORRECT
s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
]
result_items = relevant_items if relevant_items else items

View File

@ -6,11 +6,9 @@
from __future__ import annotations
import logging
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem
@ -120,9 +118,7 @@ class RelevanceScorer:
reason=reason,
)
def evaluate(
self, query: str, items: list[MemoryItem]
) -> RetrievalEvaluation:
def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
"""评估一次检索的整体质量"""
if not items:
return RetrievalEvaluation(
@ -134,9 +130,7 @@ class RelevanceScorer:
)
scores = [self.score_item(query, item) for item in items]
relevant_count = sum(
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
)
relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
avg_score = sum(s.score for s in scores) / len(scores)
# Overall verdict based on average score and relevant ratio

View File

@ -7,19 +7,16 @@ from __future__ import annotations
import asyncio
import logging
import math
from dataclasses import replace
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.query_transformer import QueryTransformerBase
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
from agentkit.memory.relevance_scorer import RelevanceScorer
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult
from agentkit.memory.knowledge_base import KnowledgeBase
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
from agentkit.tools.base import Tool
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
Chinese characters typically use 1-2 tokens each.
English words typically use 1 token each.
"""
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
non_cjk = text
for c in text:
if '\u4e00' <= c <= '\u9fff':
non_cjk = non_cjk.replace(c, ' ')
if "\u4e00" <= c <= "\u9fff":
non_cjk = non_cjk.replace(c, " ")
word_count = len(non_cjk.split())
return cjk_count * 2 + word_count
@ -89,7 +86,7 @@ class MemoryRetriever:
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
_skip_correction: bool = False,
sources: list[str] | None = None,
source_weights: dict[str, float] | None = None,
@ -121,9 +118,7 @@ class MemoryRetriever:
query, top_k=top_k, token_budget=token_budget, filters=filters
)
if result.degraded:
logger.warning(
f"RAG self-correction degraded after {result.total_retries} retries"
)
logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
return result.items
# Query transformation
if self._query_transformer is not None:
@ -139,9 +134,7 @@ class MemoryRetriever:
# Sub-query search in parallel
if sub_queries:
sub_tasks = [
self._search_layers(sq, top_k, filters) for sq in sub_queries
]
sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
for result in sub_results:
if isinstance(result, Exception):
@ -178,7 +171,7 @@ class MemoryRetriever:
self,
query: str,
top_k: int = 5,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
) -> list[MemoryItem]:
"""Search all configured memory layers with a single query"""
tasks = []
@ -237,7 +230,8 @@ class MemoryRetriever:
# QueryResult → MemoryItem
items = []
for r in kb_results:
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.source_id,
value=r.content,
metadata={
@ -248,7 +242,8 @@ class MemoryRetriever:
"document_title": r.title,
},
score=r.score,
))
)
)
# Token 预算管理
selected = []
@ -318,7 +313,9 @@ class MemoryRetriever:
if source == "rag":
kb_type = item.metadata.get("kb_type", "知识库")
document_title = item.metadata.get("document_title", "未知文档")
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
return (
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
)
elif source == "graph":
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
elif source == "episodic":
@ -330,7 +327,7 @@ class MemoryRetriever:
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
async def store_episode(
self, key: str, value: Any, metadata: dict[str, Any] | None = None
self, key: str, value: object, metadata: MetadataDict | None = None
) -> None:
"""Store an episode into episodic memory if available.
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
items = await self._retriever.retrieve(query, top_k=5)
results = []
for item in items:
results.append({
results.append(
{
"content": item.value,
"score": item.score,
"source": item.metadata.get("source", "unknown"),
"document_title": item.metadata.get("document_title", ""),
})
}
)
return {"query": query, "results": results, "call_count": self._call_count}
except Exception as e:
return {"error": str(e), "results": []}

View File

@ -3,14 +3,45 @@
适配器模式对接外部 RAG 服务和知识图谱
"""
import logging
from typing import Any
from __future__ import annotations
from agentkit.memory.base import Memory, MemoryItem
import logging
from typing import TYPE_CHECKING, Protocol
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
if TYPE_CHECKING:
from agentkit.memory.http_rag import RAGSearchResult
logger = logging.getLogger(__name__)
class _RAGServiceLike(Protocol):
"""RAG 检索服务最小接口契约duck-typed"""
async def search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
) -> list[RAGSearchResult]: ...
async def enhanced_search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
use_rerank: bool = ...,
use_compression: bool = ...,
) -> list[RAGSearchResult]: ...
class _GraphServiceLike(Protocol):
"""知识图谱服务最小接口契约duck-typed"""
async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
class SemanticMemory(Memory):
"""Semantic Memory - 知识库检索
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
def __init__(
self,
rag_service: Any = None,
graph_service: Any = None,
rag_service: _RAGServiceLike | None = None,
graph_service: _GraphServiceLike | None = None,
knowledge_base_ids: list[str] | None = None,
search_mode: str = "standard",
use_rerank: bool = True,
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
self._use_compression = use_compression
self._kb_weights = kb_weights
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
if self._rag_service and hasattr(self._rag_service, 'ingest'):
if self._rag_service and hasattr(self._rag_service, "ingest"):
await self._rag_service.ingest(key, value, metadata)
else:
logger.warning("SemanticMemory.store: no RAG service configured for writing")
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
"""按 key 精确检索Semantic Memory 通常不按 key 检索)"""
return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索知识库"""
items = []
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
if self._rag_service:
try:
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
if self._search_mode == "enhanced" and hasattr(
self._rag_service, "enhanced_search"
):
results = await self._rag_service.enhanced_search(
query,
knowledge_base_ids=kb_ids,
@ -73,14 +108,17 @@ class SemanticMemory(Memory):
use_compression=self._use_compression,
)
else:
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
results = await self._rag_service.search(
query, knowledge_base_ids=kb_ids, top_k=top_k
)
for r in results:
kb_id = r.get("knowledge_base_id", "")
score = r.get("score", 0.0)
# Apply per-KB weights
if self._kb_weights and kb_id in self._kb_weights:
score *= self._kb_weights[kb_id]
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
@ -90,7 +128,8 @@ class SemanticMemory(Memory):
"knowledge_base_id": kb_id,
},
score=score,
))
)
)
except Exception as e:
logger.error(f"RAG search failed: {e}")
@ -99,7 +138,8 @@ class SemanticMemory(Memory):
try:
graph_results = await self._graph_service.query(query, depth=2)
for r in graph_results[:top_k]:
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
@ -108,7 +148,8 @@ class SemanticMemory(Memory):
"relations": r.get("relations", []),
},
score=r.get("score", 0.0),
))
)
)
except Exception as e:
logger.error(f"Graph search failed: {e}")

View File

@ -3,11 +3,10 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
def _make_key(self, key: str) -> str:
return f"{self._key_prefix}:{key}"
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
redis_key = self._make_key(key)
item = MemoryItem(
key=key,
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=item_dict.get("score", 1.0),
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc),
created_at=datetime.fromisoformat(item_dict["created_at"])
if item_dict.get("created_at")
else datetime.now(timezone.utc),
)
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
pattern = self._make_key(f"{query}*")
keys = []
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
data = await self._redis.get(key)
if data:
item_dict = json.loads(data)
items.append(MemoryItem(
items.append(
MemoryItem(
key=item_dict["key"],
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=1.0,
created_at=datetime.now(timezone.utc),
))
)
)
return items
async def delete(self, key: str) -> bool: