refactor(llm+memory+client): remove Any from type signatures

Eliminate 172 Any usages across llm/, memory/, client/ via:
- TypeAlias (MetadataValue, MetadataDict, RAGSearchResult, etc.)
- object for arbitrary dict/value types
- TYPE_CHECKING Protocol for Redis/Quota/RAG/Graph services
- TYPE_CHECKING import + string annotations for forward refs
- Remove unused Any imports (18 F401 fixed)

Tests: 253 passed (the 21 failures under llm/ are a pre-existing litellm environment issue)
ruff: All checks passed
This commit is contained in:
chiguyong 2026-07-01 02:03:51 +08:00
parent aa6367ff9f
commit 34a89c4873
37 changed files with 905 additions and 601 deletions

View File

@ -34,12 +34,19 @@ import os
import sqlite3 import sqlite3
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Callable, TypeAlias
import httpx import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
# 服务器返回的 skills/workflows 列表元素是 dict(model_dump/to_dict 输出),
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
# ── Defaults ────────────────────────────────────────────────────────── # ── Defaults ──────────────────────────────────────────────────────────
@ -100,8 +107,8 @@ class ConfigSync:
# In-memory cache (mirrors the SQLite cache for fast access) # In-memory cache (mirrors the SQLite cache for fast access)
self._version: str | None = None self._version: str | None = None
self._skills: list[dict[str, Any]] = [] self._skills: list[SkillConfigDict] = []
self._workflows: list[dict[str, Any]] = [] self._workflows: list[WorkflowConfigDict] = []
self._last_synced_at: str | None = None self._last_synced_at: str | None = None
# ── Lifecycle ───────────────────────────────────────────────── # ── Lifecycle ─────────────────────────────────────────────────
@ -232,15 +239,15 @@ class ConfigSync:
"""Return the current cached config version hash.""" """Return the current cached config version hash."""
return self._version return self._version
def get_skills(self) -> list[dict[str, Any]]: def get_skills(self) -> list[SkillConfigDict]:
"""Return the cached skill configs.""" """Return the cached skill configs."""
return list(self._skills) return list(self._skills)
def get_workflows(self) -> list[dict[str, Any]]: def get_workflows(self) -> list[WorkflowConfigDict]:
"""Return the cached workflow configs.""" """Return the cached workflow configs."""
return list(self._workflows) return list(self._workflows)
def get_all(self) -> dict[str, Any]: def get_all(self) -> SyncedConfigPayload:
"""Return all cached configs as a single dict.""" """Return all cached configs as a single dict."""
return { return {
"version": self._version, "version": self._version,
@ -249,14 +256,14 @@ class ConfigSync:
"synced_at": self._last_synced_at, "synced_at": self._last_synced_at,
} }
def get_skill(self, name: str) -> dict[str, Any] | None: def get_skill(self, name: str) -> SkillConfigDict | None:
"""Return a single skill config by name, or ``None``.""" """Return a single skill config by name, or ``None``."""
for skill in self._skills: for skill in self._skills:
if skill.get("name") == name: if skill.get("name") == name:
return skill return skill
return None return None
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None: def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
"""Return a single workflow config by ID, or ``None``.""" """Return a single workflow config by ID, or ``None``."""
for wf in self._workflows: for wf in self._workflows:
if wf.get("workflow_id") == workflow_id: if wf.get("workflow_id") == workflow_id:
@ -281,7 +288,7 @@ class ConfigSync:
conn.executescript(_CACHE_SCHEMA) conn.executescript(_CACHE_SCHEMA)
conn.commit() conn.commit()
def _save_to_cache(self, data: dict[str, Any]) -> None: def _save_to_cache(self, data: SyncedConfigPayload) -> None:
"""Save the synced configs to the local SQLite cache.""" """Save the synced configs to the local SQLite cache."""
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
with sqlite3.connect(str(self.cache_db_path)) as conn: with sqlite3.connect(str(self.cache_db_path)) as conn:

View File

@ -14,7 +14,7 @@ import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from typing import TYPE_CHECKING, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
@ -25,6 +25,52 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的第三方对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _RedisLike(Protocol):
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
async def get(self, key: str) -> bytes | str | None: ...
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
async def smembers(self, key: str) -> set[bytes | str]: ...
async def sadd(self, name: str, *values: str) -> int: ...
async def srem(self, name: str, *values: str) -> int: ...
async def scard(self, name: str) -> int: ...
async def delete(self, *names: str) -> int: ...
def pipeline(self) -> "_RedisPipelineLike": ...
class _RedisPipelineLike(Protocol):
"""Redis pipeline 最小契约。"""
def get(self, key: str) -> "_RedisPipelineLike": ...
def set(
self, key: str, value: bytes | str, ex: int | None = None
) -> "_RedisPipelineLike": ...
def delete(self, *names: str) -> "_RedisPipelineLike": ...
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
async def execute(self) -> list[object]: ...
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Data Classes # Data Classes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -328,7 +374,7 @@ class RedisLLMCache:
self._semantic_ttl = semantic_ttl self._semantic_ttl = semantic_ttl
self._similarity_threshold = similarity_threshold self._similarity_threshold = similarity_threshold
self._max_entries_to_scan = max_entries_to_scan self._max_entries_to_scan = max_entries_to_scan
self._redis: Any = None self._redis: _RedisLike | None = None
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
self._degraded = False # True if Redis is unreachable self._degraded = False # True if Redis is unreachable
@ -691,7 +737,7 @@ class LitellmCacheManager:
def __init__(self, config: LitellmCacheConfig): def __init__(self, config: LitellmCacheConfig):
self._config = config self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例 self._cache_instance: object | None = None # litellm.caching.Cache 实例
self._hits = 0 self._hits = 0
self._misses = 0 self._misses = 0
@ -709,7 +755,7 @@ class LitellmCacheManager:
litellm.cache = None litellm.cache = None
self._cache_instance = None self._cache_instance = None
def _create_cache_instance(self) -> Any: def _create_cache_instance(self) -> object:
"""根据 backend 配置创建 LiteLLM Cache 实例。 """根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache

View File

@ -2,14 +2,13 @@
import hashlib import hashlib
import json import json
from typing import Any
def generate_cache_key( def generate_cache_key(
model: str, model: str,
messages: list[dict[str, str]], messages: list[dict[str, str]],
temperature: float, temperature: float,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto", tool_choice: str = "auto",
max_tokens: int = 2000, max_tokens: int = 2000,
user_id: str | None = None, user_id: str | None = None,

View File

@ -3,12 +3,12 @@
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from agentkit.channels.secrets import SecretsStore from agentkit.channels.secrets import SecretEntry, SecretsStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class ProviderConfig:
api_key: str api_key: str
base_url: str base_url: str
models: dict[str, dict[str, Any]] = field(default_factory=dict) models: dict[str, dict[str, object]] = field(default_factory=dict)
type: str = "openai" # "openai" | "anthropic" | "gemini" type: str = "openai" # "openai" | "anthropic" | "gemini"
max_tokens: int = 4096 # Anthropic: default max_tokens max_tokens: int = 4096 # Anthropic: default max_tokens
timeout: float = 120.0 # Anthropic: request timeout timeout: float = 120.0 # Anthropic: request timeout
@ -168,18 +168,18 @@ class ProviderConfig:
return f"llm:provider:{self.type}:api_key" return f"llm:provider:{self.type}:api_key"
@staticmethod @staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str: def _encode_secret_entry(entry: object, key: str) -> str:
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。""" """把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
# entry 是 SecretEntry pydantic 模型,有 model_dump() # entry 是 SecretEntry pydantic 模型,有 model_dump()
if hasattr(entry, "model_dump"): if hasattr(entry, "model_dump"):
data = entry.model_dump() data = entry.model_dump() # type: ignore[attr-defined]
else: else:
data = dict(entry) data = dict(entry) # type: ignore[call-overload]
data["key"] = key data["key"] = key
return json.dumps(data) return json.dumps(data)
@staticmethod @staticmethod
def _decode_secret_entry(encoded: str) -> Any: def _decode_secret_entry(encoded: str) -> "SecretEntry":
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。""" """从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
from agentkit.channels.secrets import SecretEntry from agentkit.channels.secrets import SecretEntry

View File

@ -5,7 +5,7 @@ import logging
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Protocol
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig from agentkit.llm.config import LLMConfig
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import llm_token_histogram from agentkit.telemetry.metrics import llm_token_histogram
if TYPE_CHECKING:
from agentkit.llm.cache import LitellmCacheManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _QuotaServiceLike(Protocol):
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
async def is_model_allowed(
self, db: Path, department_id: str, model: str
) -> tuple[bool, str]: ...
async def check_quota(
self,
db: Path,
department_id: str,
quota_type: str,
period: str,
current: float,
) -> tuple[bool, str]: ...
async def get_quota(
self, db: Path, department_id: str, quota_type: str, period: str
) -> dict[str, object] | None: ...
class QuotaExceededError(Exception): class QuotaExceededError(Exception):
"""Raised when a department's LLM quota is exceeded. """Raised when a department's LLM quota is exceeded.
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
department_id: str, department_id: str,
quota_type: str, quota_type: str,
period: str, period: str,
limit: Any, limit: object,
current: Any, current: object,
) -> None: ) -> None:
self.department_id = department_id self.department_id = department_id
self.quota_type = quota_type self.quota_type = quota_type
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
class LLMGateway: class LLMGateway:
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache""" """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None): def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
self._providers: dict[str, LLMProvider] = {} self._providers: dict[str, LLMProvider] = {}
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
self._config = config or LLMConfig() self._config = config or LLMConfig()
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用) # Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
self._cache_manager: Any = None # LitellmCacheManager | None self._cache_manager: "LitellmCacheManager | None" = None
if self._config.cache and self._config.cache.enabled: if self._config.cache and self._config.cache.enabled:
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
@ -601,7 +632,7 @@ class LLMGateway:
async def _check_quota_value( async def _check_quota_value(
self, self,
quota_service: Any, quota_service: _QuotaServiceLike,
db: Path, db: Path,
dept_id: str, dept_id: str,
period: str, period: str,

View File

@ -27,12 +27,11 @@ from __future__ import annotations
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]: def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。 """把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
流程 流程
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
store = SecretsStore() # master key 从 env 加载 store = SecretsStore() # master key 从 env 加载
async def _run() -> dict[str, dict[str, Any]]: async def _run() -> dict[str, dict[str, object]]:
report: dict[str, dict[str, Any]] = {} report: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items(): for name, pconf in llm_config.providers.items():
if pconf.api_key_source == "secrets_store" and not pconf.api_key: if pconf.api_key_source == "secrets_store" and not pconf.api_key:
report[name] = {"status": "skipped", "source": pconf.api_key_source} report[name] = {"status": "skipped", "source": pconf.api_key_source}
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
report = asyncio.run(_run()) report = asyncio.run(_run())
# 写回 YAML:更新 llm.providers 段,保留其它段 # 写回 YAML:更新 llm.providers 段,保留其它段
providers_out: dict[str, dict[str, Any]] = {} providers_out: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items(): for name, pconf in llm_config.providers.items():
entry: dict[str, Any] = { entry: dict[str, object] = {
"type": pconf.type, "type": pconf.type,
"base_url": pconf.base_url, "base_url": pconf.base_url,
"models": pconf.models, "models": pconf.models,

View File

@ -2,7 +2,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
@dataclass @dataclass
@ -23,7 +22,7 @@ class ToolCall:
id: str id: str
name: str name: str
arguments: dict[str, Any] arguments: dict[str, object]
@dataclass @dataclass
@ -32,7 +31,7 @@ class LLMRequest:
messages: list[dict[str, str]] messages: list[dict[str, str]]
model: str model: str
tools: list[dict[str, Any]] | None = None tools: list[dict[str, object]] | None = None
tool_choice: str = "auto" tool_choice: str = "auto"
temperature: float = 0.7 temperature: float = 0.7
max_tokens: int = 2000 max_tokens: int = 2000
@ -42,13 +41,13 @@ class LLMRequest:
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto", tool_choice: str = "auto",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2000, max_tokens: int = 2000,
timeout: float | None = None, timeout: float | None = None,
cache: dict[str, Any] | None = None, cache: dict[str, object] | None = None,
**kwargs: Any, **kwargs: object,
): ):
self.messages = messages self.messages = messages
self.model = model self.model = model
@ -59,7 +58,7 @@ class LLMRequest:
self.timeout = timeout self.timeout = timeout
self._extra = kwargs self._extra = kwargs
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion # U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
self._cache: dict[str, Any] | None = cache self._cache: dict[str, object] | None = cache
@dataclass @dataclass

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
import time import time
from typing import Any
import httpx import httpx
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
"content-type": "application/json", "content-type": "application/json",
} }
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]: def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Anthropic 格式 """将 OpenAI 风格消息转换为 Anthropic 格式
Returns: Returns:
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2) - list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
- None: system 消息 - None: system 消息
""" """
system_prompt: str | list[dict[str, Any]] | None = None system_prompt: str | list[dict[str, object]] | None = None
anthropic_messages: list[dict[str, Any]] = [] anthropic_messages: list[dict[str, object]] = []
for msg in messages: for msg in messages:
role = msg.get("role", "") role = msg.get("role", "")
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
# 检查是否有 tool_calls (OpenAI 格式) # 检查是否有 tool_calls (OpenAI 格式)
tool_calls = msg.get("tool_calls") tool_calls = msg.get("tool_calls")
if tool_calls: if tool_calls:
blocks: list[dict[str, Any]] = [] blocks: list[dict[str, object]] = []
# 如果有文本内容,先添加文本块 # 如果有文本内容,先添加文本块
if content: if content:
blocks.append({"type": "text", "text": content}) blocks.append({"type": "text", "text": content})
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
arguments = json.loads(arguments) arguments = json.loads(arguments)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": arguments} arguments = {"raw": arguments}
blocks.append({ blocks.append(
"type": "tool_use", {
"id": tc.get("id", ""), "type": "tool_use",
"name": func.get("name", ""), "id": tc.get("id", ""),
"input": arguments, "name": func.get("name", ""),
}) "input": arguments,
}
)
anthropic_messages.append({"role": "assistant", "content": blocks}) anthropic_messages.append({"role": "assistant", "content": blocks})
else: else:
anthropic_messages.append({ anthropic_messages.append(
"role": "assistant", {
"content": [{"type": "text", "text": content}], "role": "assistant",
}) "content": [{"type": "text", "text": content}],
}
)
continue continue
if role == "user": if role == "user":
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果) # 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."} # OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
if msg.get("tool_call_id"): if msg.get("tool_call_id"):
tool_result_blocks: list[dict[str, Any]] = [] tool_result_blocks: list[dict[str, object]] = []
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
# tool_result 的 content 可以是字符串或内容块列表 # tool_result 的 content 可以是字符串或内容块列表
if isinstance(tool_content, str): if isinstance(tool_content, str):
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
else: else:
tool_result_blocks.append({"type": "text", "text": str(tool_content)}) tool_result_blocks.append({"type": "text", "text": str(tool_content)})
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{ "role": "user",
"type": "tool_result", "content": [
"tool_use_id": msg.get("tool_call_id", ""), {
"content": tool_result_blocks, "type": "tool_result",
}], "tool_use_id": msg.get("tool_call_id", ""),
}) "content": tool_result_blocks,
}
],
}
)
else: else:
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{"type": "text", "text": content}], "role": "user",
}) "content": [{"type": "text", "text": content}],
}
)
continue continue
if role == "tool": if role == "tool":
# OpenAI 格式中独立的 tool 消息 # OpenAI 格式中独立的 tool 消息
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
if isinstance(tool_content, str): if isinstance(tool_content, str):
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}] result_content: list[dict[str, object]] | str = [
{"type": "text", "text": tool_content}
]
elif isinstance(tool_content, list): elif isinstance(tool_content, list):
result_content = tool_content result_content = tool_content
else: else:
result_content = [{"type": "text", "text": str(tool_content)}] result_content = [{"type": "text", "text": str(tool_content)}]
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{ "role": "user",
"type": "tool_result", "content": [
"tool_use_id": msg.get("tool_call_id", ""), {
"content": result_content, "type": "tool_result",
}], "tool_use_id": msg.get("tool_call_id", ""),
}) "content": result_content,
}
],
}
)
return system_prompt, anthropic_messages return system_prompt, anthropic_messages
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Anthropic tool 格式""" """将 OpenAI function 格式转换为 Anthropic tool 格式"""
anthropic_tools = [] anthropic_tools = []
for tool in tools: for tool in tools:
if tool.get("type") == "function": if tool.get("type") == "function":
func = tool.get("function", {}) func = tool.get("function", {})
anthropic_tools.append({ anthropic_tools.append(
"name": func.get("name", ""), {
"description": func.get("description", ""), "name": func.get("name", ""),
"input_schema": func.get("parameters", {"type": "object", "properties": {}}), "description": func.get("description", ""),
}) "input_schema": func.get(
"parameters", {"type": "object", "properties": {}}
),
}
)
return anthropic_tools return anthropic_tools
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式""" """将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
if tool_choice == "auto": if tool_choice == "auto":
return {"type": "auto"} return {"type": "auto"}
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
return {"type": "tool", "name": tool_choice} return {"type": "tool", "name": tool_choice}
return None return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Anthropic 响应转换为 LLMResponse""" """将 Anthropic 响应转换为 LLMResponse"""
content_blocks = data.get("content", []) content_blocks = data.get("content", [])
text_parts: list[str] = [] text_parts: list[str] = []
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
if block_type == "text": if block_type == "text":
text_parts.append(block.get("text", "")) text_parts.append(block.get("text", ""))
elif block_type == "tool_use": elif block_type == "tool_use":
tool_calls.append(ToolCall( tool_calls.append(
id=block.get("id", ""), ToolCall(
name=block.get("name", ""), id=block.get("id", ""),
arguments=block.get("input", {}), name=block.get("name", ""),
)) arguments=block.get("input", {}),
)
)
usage_data = data.get("usage", {}) usage_data = data.get("usage", {})
usage = TokenUsage( usage = TokenUsage(
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages) system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"model": request.model, "model": request.model,
"max_tokens": request.max_tokens or self._max_tokens, "max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages, "messages": anthropic_messages,
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages) system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"model": request.model, "model": request.model,
"max_tokens": request.max_tokens or self._max_tokens, "max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages, "messages": anthropic_messages,
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest): async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks.""" """Iterate over an already-open SSE stream and yield StreamChunks."""
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str} # Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
accumulated_tool_calls: dict[str, dict[str, Any]] = {} accumulated_tool_calls: dict[str, dict[str, object]] = {}
current_tool_id: str | None = None current_tool_id: str | None = None
current_tool_name: str | None = None current_tool_name: str | None = None
current_tool_input_json: str = "" current_tool_input_json: str = ""
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
# Finalize current tool call if any # Finalize current tool call if any
if current_tool_id is not None: if current_tool_id is not None:
try: try:
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {} arguments = (
json.loads(current_tool_input_json) if current_tool_input_json else {}
)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": current_tool_input_json} arguments = {"raw": current_tool_input_json}
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
error_msg = error_info.get("message", "Stream error") error_msg = error_info.get("message", "Stream error")
raise LLMProviderError("anthropic", error_msg) raise LLMProviderError("anthropic", error_msg)
def get_model_info(self) -> dict[str, Any]: def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息""" """返回 Provider 和模型信息"""
return { return {
"provider": "anthropic", "provider": "anthropic",

View File

@ -8,7 +8,6 @@ API火山引擎 OpenAI 兼容接口
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
api_key: str, api_key: str,
base_url: str = DOUBAO_DEFAULT_BASE_URL, base_url: str = DOUBAO_DEFAULT_BASE_URL,
default_model: str = "doubao-pro-32k", default_model: str = "doubao-pro-32k",
**kwargs: Any, **kwargs: object,
): ):
super().__init__( super().__init__(
api_key=api_key, api_key=api_key,

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
import time import time
from typing import Any
import httpx import httpx
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
def _convert_messages( def _convert_messages(
self, messages: list[dict[str, str]] self, messages: list[dict[str, str]]
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]: ) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Gemini 格式 """将 OpenAI 风格消息转换为 Gemini 格式
Returns: Returns:
(system_instruction, contents) (system_instruction, contents)
""" """
system_instruction: dict[str, Any] | None = None system_instruction: dict[str, object] | None = None
contents: list[dict[str, Any]] = [] contents: list[dict[str, object]] = []
for msg in messages: for msg in messages:
role = msg.get("role", "") role = msg.get("role", "")
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
tool_name = parsed.get("name", "") tool_name = parsed.get("name", "")
except (json.JSONDecodeError, AttributeError): except (json.JSONDecodeError, AttributeError):
pass pass
contents.append({ contents.append(
"role": "user", {
"parts": [{ "role": "user",
"functionResponse": { "parts": [
"name": tool_name, {
"response": { "functionResponse": {
"content": content, "name": tool_name,
}, "response": {
}, "content": content,
}], },
}) },
}
],
}
)
else: else:
contents.append({ contents.append(
"role": "user", {
"parts": [{"text": content}], "role": "user",
}) "parts": [{"text": content}],
}
)
continue continue
if role == "assistant": if role == "assistant":
tool_calls = msg.get("tool_calls") tool_calls = msg.get("tool_calls")
if tool_calls: if tool_calls:
parts: list[dict[str, Any]] = [] parts: list[dict[str, object]] = []
if content: if content:
parts.append({"text": content}) parts.append({"text": content})
for tc in tool_calls: for tc in tool_calls:
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
arguments = json.loads(arguments) arguments = json.loads(arguments)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": arguments} arguments = {"raw": arguments}
parts.append({ parts.append(
"functionCall": { {
"name": func.get("name", ""), "functionCall": {
"args": arguments, "name": func.get("name", ""),
}, "args": arguments,
}) },
}
)
contents.append({"role": "model", "parts": parts}) contents.append({"role": "model", "parts": parts})
else: else:
contents.append({ contents.append(
"role": "model", {
"parts": [{"text": content}], "role": "model",
}) "parts": [{"text": content}],
}
)
continue continue
if role == "tool": if role == "tool":
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."} # OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
tool_name = msg.get("name", "") tool_name = msg.get("name", "")
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
contents.append({ contents.append(
"role": "user", {
"parts": [{ "role": "user",
"functionResponse": { "parts": [
"name": tool_name, {
"response": { "functionResponse": {
"content": tool_content, "name": tool_name,
}, "response": {
}, "content": tool_content,
}], },
}) },
}
],
}
)
return system_instruction, contents return system_instruction, contents
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Gemini functionDeclarations""" """将 OpenAI function 格式转换为 Gemini functionDeclarations"""
declarations = [] declarations = []
for tool in tools: for tool in tools:
if tool.get("type") == "function": if tool.get("type") == "function":
func = tool.get("function", {}) func = tool.get("function", {})
declarations.append({ declarations.append(
"name": func.get("name", ""), {
"description": func.get("description", ""), "name": func.get("name", ""),
"parameters": func.get("parameters", {"type": "object", "properties": {}}), "description": func.get("description", ""),
}) "parameters": func.get("parameters", {"type": "object", "properties": {}}),
}
)
if not declarations: if not declarations:
return [] return []
return [{"functionDeclarations": declarations}] return [{"functionDeclarations": declarations}]
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig""" """将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
if tool_choice == "auto": if tool_choice == "auto":
return {"functionCallingConfig": {"mode": "AUTO"}} return {"functionCallingConfig": {"mode": "AUTO"}}
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
return {"functionCallingConfig": {"mode": "NONE"}} return {"functionCallingConfig": {"mode": "NONE"}}
return None return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Gemini 响应转换为 LLMResponse""" """将 Gemini 响应转换为 LLMResponse"""
candidates = data.get("candidates", []) candidates = data.get("candidates", [])
text_parts: list[str] = [] text_parts: list[str] = []
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
text_parts.append(part["text"]) text_parts.append(part["text"])
elif "functionCall" in part: elif "functionCall" in part:
fc = part["functionCall"] fc = part["functionCall"]
tool_calls.append(ToolCall( tool_calls.append(
id=f"call_{tool_call_index}", ToolCall(
name=fc.get("name", ""), id=f"call_{tool_call_index}",
arguments=fc.get("args", {}), name=fc.get("name", ""),
)) arguments=fc.get("args", {}),
)
)
tool_call_index += 1 tool_call_index += 1
usage_metadata = data.get("usageMetadata", {}) usage_metadata = data.get("usageMetadata", {})
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages) system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"contents": contents, "contents": contents,
"generationConfig": { "generationConfig": {
"temperature": request.temperature, "temperature": request.temperature,
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages) system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"contents": contents, "contents": contents,
"generationConfig": { "generationConfig": {
"temperature": request.temperature, "temperature": request.temperature,
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest): async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks.""" """Iterate over an already-open SSE stream and yield StreamChunks."""
accumulated_tool_calls: list[dict[str, Any]] = [] accumulated_tool_calls: list[dict[str, object]] = []
model = request.model or self._model model = request.model or self._model
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
) )
elif "functionCall" in part: elif "functionCall" in part:
fc = part["functionCall"] fc = part["functionCall"]
accumulated_tool_calls.append({ accumulated_tool_calls.append(
"id": f"call_{len(accumulated_tool_calls)}", {
"name": fc.get("name", ""), "id": f"call_{len(accumulated_tool_calls)}",
"arguments": fc.get("args", {}), "name": fc.get("name", ""),
}) "arguments": fc.get("args", {}),
}
)
# Check for finish reason # Check for finish reason
finish_reason = candidates[0].get("finishReason", "") finish_reason = candidates[0].get("finishReason", "")
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
) )
accumulated_tool_calls = [] accumulated_tool_calls = []
def get_model_info(self) -> dict[str, Any]: def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息""" """返回 Provider 和模型信息"""
return { return {
"provider": "gemini", "provider": "gemini",

View File

@ -26,8 +26,7 @@ import inspect
import json import json
import logging import logging
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Iterable
from typing import Any
from agentkit.core.exceptions import LLMProviderError from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import ( from agentkit.llm.protocol import (
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
api_key: str, api_key: str,
base_url: str | None = None, base_url: str | None = None,
provider_type: str = "openai", provider_type: str = "openai",
**default_kwargs: Any, **default_kwargs: object,
) -> None: ) -> None:
self._model_prefix = model_prefix self._model_prefix = model_prefix
self._api_key = api_key self._api_key = api_key
self._base_url = base_url or None # 空字符串视作未设置 self._base_url = base_url or None # 空字符串视作未设置
self._provider_type = provider_type self._provider_type = provider_type
self._default_kwargs: dict[str, Any] = dict(default_kwargs) self._default_kwargs: dict[str, object] = dict(default_kwargs)
async def chat(self, request: LLMRequest) -> LLMResponse: async def chat(self, request: LLMRequest) -> LLMResponse:
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。""" """非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
kwargs = self._build_kwargs(request, stream=True) kwargs = self._build_kwargs(request, stream=True)
accumulated_tool_calls: dict[int, dict[str, Any]] = {} accumulated_tool_calls: dict[int, dict[str, object]] = {}
final_usage: TokenUsage | None = None final_usage: TokenUsage | None = None
final_model: str = request.model final_model: str = request.model
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
# 内部辅助 # 内部辅助
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]: def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
"""从 LLMRequest 构造 litellm.acompletion kwargs。""" """从 LLMRequest 构造 litellm.acompletion kwargs。"""
kwargs: dict[str, Any] = { kwargs: dict[str, object] = {
"model": f"{self._model_prefix}{request.model}", "model": f"{self._model_prefix}{request.model}",
"messages": request.messages, "messages": request.messages,
"temperature": request.temperature, "temperature": request.temperature,
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
def _parse_response( def _parse_response(
self, self,
response: Any, response: object,
request_model: str, request_model: str,
latency_ms: float, latency_ms: float,
) -> LLMResponse: ) -> LLMResponse:
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
def _parse_stream_chunk( def _parse_stream_chunk(
self, self,
chunk: Any, chunk: object,
request_model: str, request_model: str,
accumulated_tool_calls: dict[int, dict[str, Any]], accumulated_tool_calls: dict[int, dict[str, object]],
) -> StreamChunk: ) -> StreamChunk:
"""解析单个流式 chunk（非 final）。累计 tool_calls 到传入字典。""" """解析单个流式 chunk（非 final）。累计 tool_calls 到传入字典。"""
choices = getattr(chunk, "choices", None) or [] choices = getattr(chunk, "choices", None) or []
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
def _finalize_tool_calls( def _finalize_tool_calls(
self, self,
accumulated: dict[int, dict[str, Any]], accumulated: dict[int, dict[str, object]],
) -> list[ToolCall]: ) -> list[ToolCall]:
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。""" """把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
tool_calls: list[ToolCall] = [] tool_calls: list[ToolCall] = []
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
"""解析非流式响应的 tool_calls（OpenAI 格式 list[ChoiceMessageToolCall]）。""" """解析非流式响应的 tool_calls（OpenAI 格式 list[ChoiceMessageToolCall]）。"""
result: list[ToolCall] = [] result: list[ToolCall] = []
for tc in raw_tool_calls: for tc in raw_tool_calls:
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
return result return result
def _parse_usage(usage_obj: Any) -> TokenUsage: def _parse_usage(usage_obj: object) -> TokenUsage:
"""解析 usage 对象（OpenAI CompletionUsage 或 dict）""" """解析 usage 对象（OpenAI CompletionUsage 或 dict）"""
prompt = getattr(usage_obj, "prompt_tokens", None) prompt = getattr(usage_obj, "prompt_tokens", None)
completion = getattr(usage_obj, "completion_tokens", None) completion = getattr(usage_obj, "completion_tokens", None)
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
def _accumulate_stream_tool_calls( def _accumulate_stream_tool_calls(
raw_tool_calls: Any, raw_tool_calls: Iterable[object],
accumulated: dict[int, dict[str, Any]], accumulated: dict[int, dict[str, object]],
) -> None: ) -> None:
"""累计流式 chunk 里的 tool_calls 片段（OpenAI delta.tool_calls 格式）。 """累计流式 chunk 里的 tool_calls 片段（OpenAI delta.tool_calls 格式）。
@ -364,7 +363,7 @@ def create_litellm_provider(
provider_type: str, provider_type: str,
api_key: str, api_key: str,
base_url: str | None = None, base_url: str | None = None,
**kwargs: Any, **kwargs: object,
) -> LitellmProvider: ) -> LitellmProvider:
"""根据 provider_type 创建 LitellmProvider 实例。 """根据 provider_type 创建 LitellmProvider 实例。

View File

@ -18,7 +18,7 @@ import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from agentkit.llm.protocol import TokenUsage from agentkit.llm.protocol import TokenUsage
@ -294,8 +294,8 @@ class RedisUsageStore:
def __init__(self, redis_url: str = "redis://localhost:6379"): def __init__(self, redis_url: str = "redis://localhost:6379"):
self._redis_url = redis_url self._redis_url = redis_url
self._redis: Any = None self._redis: object | None = None
self._sync_redis: Any = None self._sync_redis: object | None = None
self._fallback: InMemoryUsageStore | None = None self._fallback: InMemoryUsageStore | None = None
self._degraded = False self._degraded = False
self._health_check_task: asyncio.Task[None] | None = None self._health_check_task: asyncio.Task[None] | None = None
@ -687,7 +687,7 @@ class RedisUsageStore:
@staticmethod @staticmethod
def _read_list( def _read_list(
r: Any, r: object,
list_key: str, list_key: str,
start: datetime, start: datetime,
end: datetime, end: datetime,

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging import logging
import time import time
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
secret_key: str | None = None, secret_key: str | None = None,
base_url: str = WENXIN_DEFAULT_BASE_URL, base_url: str = WENXIN_DEFAULT_BASE_URL,
default_model: str = "ernie-4.5-turbo-128k", default_model: str = "ernie-4.5-turbo-128k",
**kwargs: Any, **kwargs: object,
): ):
# If AK/SK provided, use token-based auth # If AK/SK provided, use token-based auth
self._access_key = access_key self._access_key = access_key

View File

@ -8,7 +8,6 @@ API腾讯云 OpenAI 兼容接口
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
base_url: str = YUANBAO_DEFAULT_BASE_URL, base_url: str = YUANBAO_DEFAULT_BASE_URL,
default_model: str = "hunyuan-turbos-latest", default_model: str = "hunyuan-turbos-latest",
enable_enhancement: bool = False, enable_enhancement: bool = False,
**kwargs: Any, **kwargs: object,
): ):
self._enable_enhancement = enable_enhancement self._enable_enhancement = enable_enhancement
super().__init__( super().__init__(

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any
import httpx import httpx
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
def _build_payload(self, request: LLMRequest) -> dict[str, Any]: def _build_payload(self, request: LLMRequest) -> dict[str, object]:
"""Convert LLMRequest to server API payload.""" """Convert LLMRequest to server API payload."""
return { return {
"messages": request.messages, "messages": request.messages,
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
return str(body["error"]) return str(body["error"])
return str(body) return str(body)
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse: def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
"""Parse server response JSON into an LLMResponse.""" """Parse server response JSON into an LLMResponse."""
usage_data = data.get("usage") or {} usage_data = data.get("usage") or {}
usage = TokenUsage( usage = TokenUsage(
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
latency_ms=data.get("latency_ms", 0.0), latency_ms=data.get("latency_ms", 0.0),
) )
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk: def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
"""Parse a single SSE data payload into a StreamChunk.""" """Parse a single SSE data payload into a StreamChunk."""
usage: TokenUsage | None = None usage: TokenUsage | None = None
usage_data = data.get("usage") usage_data = data.get("usage")
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
if response.status_code == 502: if response.status_code == 502:
await response.aread() await response.aread()
detail = self._extract_error_detail(response) detail = self._extract_error_detail(response)
raise LLMProviderError( raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
"remote", f"Server LLM gateway error: {detail}"
)
if response.status_code != 200: if response.status_code != 200:
await response.aread() await response.aread()
raise LLMProviderError( raise LLMProviderError(

View File

@ -5,7 +5,7 @@ import logging
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable from typing import Callable
from agentkit.core.exceptions import LLMProviderError from agentkit.core.exceptions import LLMProviderError
@ -20,9 +20,7 @@ class RetryConfig:
base_delay: float = 1.0 base_delay: float = 1.0
max_delay: float = 30.0 max_delay: float = 30.0
exponential_base: float = 2.0 exponential_base: float = 2.0
retryable_status_codes: set[int] = field( retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
default_factory=lambda: {429, 500, 502, 503, 529}
)
class CircuitState(Enum): class CircuitState(Enum):
@ -69,7 +67,7 @@ class RetryPolicy:
def __init__(self, config: RetryConfig | None = None): def __init__(self, config: RetryConfig | None = None):
self._config = config or RetryConfig() self._config = config or RetryConfig()
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn with retry on retryable errors.""" """Execute fn with retry on retryable errors."""
last_error: Exception | None = None last_error: Exception | None = None
@ -84,7 +82,7 @@ class RetryPolicy:
raise raise
delay = min( delay = min(
self._config.base_delay * (self._config.exponential_base ** attempt), self._config.base_delay * (self._config.exponential_base**attempt),
self._config.max_delay, self._config.max_delay,
) )
logger.warning( logger.warning(
@ -142,7 +140,7 @@ class CircuitBreaker:
f"after {self._failure_count} failures" f"after {self._failure_count} failures"
) )
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn through the circuit breaker.""" """Execute fn through the circuit breaker."""
current_state = self.state current_state = self.state
@ -158,6 +156,6 @@ class CircuitBreaker:
result = await fn(*args, **kwargs) result = await fn(*args, **kwargs)
self._on_success() self._on_success()
return result return result
except Exception as e: except Exception:
self._on_failure() self._on_failure()
raise raise

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
import httpx import httpx
@ -95,8 +94,7 @@ class KBAdapter(ABC):
async def delete_by_id(self, id: str) -> bool: async def delete_by_id(self, id: str) -> bool:
"""按文档 ID 删除(子类可覆盖)""" """按文档 ID 删除(子类可覆盖)"""
logger.warning( logger.warning(
f"{self.__class__.__name__} does not support delete_by_id; " f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
f"id '{id}' skipped"
) )
return False return False
@ -127,8 +125,7 @@ class KBAdapter(ABC):
async def get_document(self, doc_id: str) -> Document | None: async def get_document(self, doc_id: str) -> Document | None:
"""按 ID 获取单个文档(子类可覆盖)""" """按 ID 获取单个文档(子类可覆盖)"""
logger.warning( logger.warning(
f"{self.__class__.__name__} does not support get_document; " f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
f"doc_id '{doc_id}' not found"
) )
return None return None
@ -156,5 +153,5 @@ class KBAdapter(ABC):
async def __aenter__(self) -> KBAdapter: async def __aenter__(self) -> KBAdapter:
return self return self
async def __aexit__(self, *args: Any) -> None: async def __aexit__(self, *args: object) -> None:
await self.close() await self.close()

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
import httpx import httpx
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
) )
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url): if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._username = username self._username = username
self._api_token = api_token self._api_token = api_token
self._space_keys = space_keys or [] self._space_keys = space_keys or []
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
"""创建 Confluence API HTTP 客户端""" """创建 Confluence API HTTP 客户端"""
import base64 import base64
credentials = base64.b64encode( credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
f"{self._username}:{self._api_token}".encode()
).decode()
return httpx.AsyncClient( return httpx.AsyncClient(
base_url=self._base_url, base_url=self._base_url,
headers={ headers={
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
space_filter = " OR ".join( space_filter = " OR ".join(
f'space = "{_escape_cql(key)}"' for key in self._space_keys f'space = "{_escape_cql(key)}"' for key in self._space_keys
) )
cql = f'{cql} AND ({space_filter})' cql = f"{cql} AND ({space_filter})"
resp = await client.get( resp = await client.get(
"/rest/api/content/search", "/rest/api/content/search",
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "") body = page.get("body", {}).get("storage", {}).get("value", "")
# Strip HTML tags for plain text content # Strip HTML tags for plain text content
import re import re
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "") content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
results.append( results.append(
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"Confluence search HTTP error: {e.response.status_code}" f"Confluence search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "") body = page.get("body", {}).get("storage", {}).get("value", "")
import re import re
content = re.sub(r"<[^>]+>", "", body) if body else "" content = re.sub(r"<[^>]+>", "", body) if body else ""
return Document( return Document(
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
source_type="confluence", source_type="confluence",
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.error(f"Confluence list_sources error: {e}") logger.error(f"Confluence list_sources error: {e}")
return [ return [

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import logging import logging
import time import time
from typing import Any from typing import TypeAlias
import httpx import httpx
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 飞书搜索请求 payload（search_key/page_size/wiki_space_ids）— 值为 str|int|list[str]。
FeishuSearchPayload: TypeAlias = dict[str, object]
class FeishuKBAdapter(KBAdapter): class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器 """飞书知识库适配器
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
self._app_secret = app_secret self._app_secret = app_secret
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url): if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._space_ids = space_ids or [] self._space_ids = space_ids or []
self._access_token: str | None = None self._access_token: str | None = None
self._token_expiry: float = 0.0 self._token_expiry: float = 0.0
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
self._client = None self._client = None
return self._access_token return self._access_token
else: else:
logger.error( logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
f"Feishu auth failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Feishu auth error: {e}") logger.error(f"Feishu auth error: {e}")
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client() client = self._get_client()
try: try:
payload: dict[str, Any] = { payload: FeishuSearchPayload = {
"search_key": query, "search_key": query,
"page_size": top_k, "page_size": top_k,
} }
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
if data.get("code") != 0: if data.get("code") != 0:
logger.error( logger.error(
f"Feishu search failed: code={data.get('code')}, " f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
f"msg={data.get('msg')}"
) )
return [] return []
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"Feishu search HTTP error: {e.response.status_code}" f"Feishu search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client() client = self._get_client()
try: try:
resp = await client.get( resp = await client.get(
f"/wiki/v2/spaces/get_node", "/wiki/v2/spaces/get_node",
params={"token": doc_id}, params={"token": doc_id},
) )
resp.raise_for_status() resp.raise_for_status()
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
source_type="feishu", source_type="feishu",
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.error(f"Feishu list_sources error: {e}") logger.error(f"Feishu list_sources error: {e}")
return [ return [

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
import httpx import httpx
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
) )
self._endpoint_url = endpoint_url.rstrip("/") self._endpoint_url = endpoint_url.rstrip("/")
if not is_safe_url(self._endpoint_url): if not is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
)
self._auth_config = auth_config or {} self._auth_config = auth_config or {}
self._extra_headers = headers or {} self._extra_headers = headers or {}
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
headers["Authorization"] = f"Bearer {token}" headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic": elif auth_type == "basic":
import base64 import base64
username = self._auth_config.get("username", "") username = self._auth_config.get("username", "")
password = self._auth_config.get("password", "") password = self._auth_config.get("password", "")
if username and password: if username and password:
credentials = base64.b64encode( credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
f"{username}:{password}".encode()
).decode()
headers["Authorization"] = f"Basic {credentials}" headers["Authorization"] = f"Basic {credentials}"
elif auth_type == "api_key": elif auth_type == "api_key":
key_name = self._auth_config.get("header_name", "X-API-Key") key_name = self._auth_config.get("header_name", "X-API-Key")
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"GenericHTTP search HTTP error: {e.response.status_code}" f"GenericHTTP search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"GenericHTTP ingest HTTP error: {e.response.status_code}" f"GenericHTTP ingest HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
document_count=item.get("document_count", 0), document_count=item.get("document_count", 0),
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}") logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")

View File

@ -3,19 +3,29 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TypeAlias
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
MetadataValue: TypeAlias = str | int | float | bool | None
MetadataDict: TypeAlias = dict[str, MetadataValue]
MemoryValue: TypeAlias = (
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
)
@dataclass @dataclass
class MemoryItem: class MemoryItem:
"""记忆条目""" """记忆条目"""
key: str key: str
value: Any value: object
metadata: dict[str, Any] = field(default_factory=dict) metadata: MetadataDict = field(default_factory=dict)
score: float = 1.0 score: float = 1.0
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def to_dict(self) -> dict: def to_dict(self) -> dict[str, object]:
return { return {
"key": self.key, "key": self.key,
"value": self.value, "value": self.value,
@ -35,7 +45,7 @@ class Memory(ABC):
""" """
@abstractmethod @abstractmethod
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储记忆""" """存储记忆"""
... ...
@ -45,7 +55,9 @@ class Memory(ABC):
... ...
@abstractmethod @abstractmethod
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索""" """语义检索"""
... ...
@ -54,7 +66,7 @@ class Memory(ABC):
"""删除记忆""" """删除记忆"""
... ...
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None: async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
"""批量存储""" """批量存储"""
for key, value, metadata in items: for key, value, metadata in items:
await self.store(key, value, metadata) await self.store(key, value, metadata)

View File

@ -11,10 +11,18 @@ import logging
import re import re
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 分块元数据source_doc/position/char_count/chunking_strategy/heading/heading_level
# — 全部为原始标量str/int
ChunkMetadata: TypeAlias = MetadataDict
# _split_by_headings 返回的节段结构。
SectionInfo: TypeAlias = dict[str, str | int]
@dataclass @dataclass
class Chunk: class Chunk:
@ -22,7 +30,7 @@ class Chunk:
chunk_id: str chunk_id: str
content: str content: str
metadata: dict[str, Any] = field(default_factory=dict) metadata: ChunkMetadata = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if "source_doc" not in self.metadata: if "source_doc" not in self.metadata:
@ -30,7 +38,7 @@ class Chunk:
if "position" not in self.metadata: if "position" not in self.metadata:
self.metadata["position"] = 0 self.metadata["position"] = 0
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, object]:
return { return {
"chunk_id": self.chunk_id, "chunk_id": self.chunk_id,
"content": self.content, "content": self.content,
@ -57,7 +65,9 @@ class TextChunker:
separator: 优先分割符 separator: 优先分割符
""" """
if chunk_overlap >= chunk_size: if chunk_overlap >= chunk_size:
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})") raise ValueError(
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
)
self._chunk_size = chunk_size self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._separator = separator self._separator = separator
@ -66,7 +76,7 @@ class TextChunker:
self, self,
text: str, text: str,
source_doc_id: str = "", source_doc_id: str = "",
metadata: dict[str, Any] | None = None, metadata: ChunkMetadata | None = None,
) -> list[Chunk]: ) -> list[Chunk]:
"""将文本分块 """将文本分块
@ -96,11 +106,13 @@ class TextChunker:
chunk_meta = dict(base_meta) chunk_meta = dict(base_meta)
chunk_meta["position"] = i chunk_meta["position"] = i
chunk_meta["char_count"] = len(chunk_text) chunk_meta["char_count"] = len(chunk_text)
chunks.append(Chunk( chunks.append(
chunk_id=str(uuid.uuid4()), Chunk(
content=chunk_text, chunk_id=str(uuid.uuid4()),
metadata=chunk_meta, content=chunk_text,
)) metadata=chunk_meta,
)
)
return chunks return chunks
@ -142,7 +154,9 @@ class TextChunker:
overlap_text[overlap_start:], segments overlap_text[overlap_start:], segments
) )
current = overlap_segments current = overlap_segments
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1) current_len = sum(len(s) for s in current) + len(self._separator) * max(
0, len(current) - 1
)
current.append(segment) current.append(segment)
current_len += seg_len + len(self._separator) current_len += seg_len + len(self._separator)
@ -214,7 +228,7 @@ class StructuralChunker:
self, self,
text: str, text: str,
source_doc_id: str = "", source_doc_id: str = "",
metadata: dict[str, Any] | None = None, metadata: ChunkMetadata | None = None,
) -> list[Chunk]: ) -> list[Chunk]:
"""将文本按结构分块 """将文本按结构分块
@ -266,23 +280,25 @@ class StructuralChunker:
chunk_meta["heading"] = heading chunk_meta["heading"] = heading
chunk_meta["heading_level"] = level chunk_meta["heading_level"] = level
chunk_meta["char_count"] = len(content) chunk_meta["char_count"] = len(content)
chunks.append(Chunk( chunks.append(
chunk_id=str(uuid.uuid4()), Chunk(
content=content, chunk_id=str(uuid.uuid4()),
metadata=chunk_meta, content=content,
)) metadata=chunk_meta,
)
)
position += 1 position += 1
return chunks return chunks
def _split_by_headings(self, text: str) -> list[dict[str, Any]]: def _split_by_headings(self, text: str) -> list[SectionInfo]:
"""按标题分割 Markdown 文本 """按标题分割 Markdown 文本
Returns: Returns:
列表每项包含 heading, content, level 列表每项包含 heading, content, level
""" """
lines = text.split("\n") lines = text.split("\n")
sections: list[dict[str, Any]] = [] sections: list[SectionInfo] = []
current_heading = "" current_heading = ""
current_level = 0 current_level = 0
current_lines: list[str] = [] current_lines: list[str] = []
@ -296,11 +312,13 @@ class StructuralChunker:
if current_lines: if current_lines:
content = "\n".join(current_lines).strip() content = "\n".join(current_lines).strip()
if content: if content:
sections.append({ sections.append(
"heading": current_heading, {
"content": content, "heading": current_heading,
"level": current_level, "content": content,
}) "level": current_level,
}
)
# 开始新节 # 开始新节
current_heading = match.group(2).strip() current_heading = match.group(2).strip()
@ -313,18 +331,22 @@ class StructuralChunker:
if current_lines: if current_lines:
content = "\n".join(current_lines).strip() content = "\n".join(current_lines).strip()
if content: if content:
sections.append({ sections.append(
"heading": current_heading, {
"content": content, "heading": current_heading,
"level": current_level, "content": content,
}) "level": current_level,
}
)
# 如果没有标题结构,整体作为一个块 # 如果没有标题结构,整体作为一个块
if not sections: if not sections:
sections.append({ sections.append(
"heading": "", {
"content": text.strip(), "heading": "",
"level": 0, "content": text.strip(),
}) "level": 0,
}
)
return sections return sections

View File

@ -9,10 +9,14 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import TYPE_CHECKING
from agentkit.memory.base import MetadataDict
from agentkit.memory.embedder import EmbeddingCache from agentkit.memory.embedder import EmbeddingCache
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,7 +28,7 @@ class ContextualChunk:
context_prefix: str context_prefix: str
enhanced_content: str enhanced_content: str
chunk_index: int chunk_index: int
metadata: dict[str, Any] metadata: MetadataDict
@property @property
def content(self) -> str: def content(self) -> str:
@ -65,7 +69,7 @@ class ContextualChunker:
def __init__( def __init__(
self, self,
llm_gateway: Any = None, llm_gateway: LLMGateway | None = None,
cache: EmbeddingCache | None = None, cache: EmbeddingCache | None = None,
batch_size: int = 8, batch_size: int = 8,
max_context_length: int = 200, max_context_length: int = 200,
@ -90,7 +94,7 @@ class ContextualChunker:
self, self,
document: str, document: str,
chunks: list[str], chunks: list[str],
metadata: dict[str, Any] | None = None, metadata: MetadataDict | None = None,
) -> list[ContextualChunk]: ) -> list[ContextualChunk]:
"""为文档块添加上下文前缀 """为文档块添加上下文前缀
@ -134,7 +138,7 @@ class ContextualChunker:
document: str, document: str,
chunks: list[str], chunks: list[str],
start_index: int, start_index: int,
metadata: dict[str, Any] | None, metadata: MetadataDict | None,
) -> list[ContextualChunk]: ) -> list[ContextualChunk]:
"""处理一批文档块""" """处理一批文档块"""
results: list[ContextualChunk] = [] results: list[ContextualChunk] = []

View File

@ -12,7 +12,9 @@ import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
MAX_ROWS_PER_SHEET = 10_000 MAX_ROWS_PER_SHEET = 10_000
MAX_CELL_CHARS = 10_000 MAX_CELL_CHARS = 10_000
# 文档元数据source/format/parser/page_count/table_count/sheet_count/row_count/
# heading_count/created_at/title/truncated — 全部为原始标量。
DocumentMetadata: TypeAlias = MetadataDict
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
@dataclass @dataclass
class Document: class Document:
@ -31,7 +38,7 @@ class Document:
doc_id: str doc_id: str
title: str title: str
content: str content: str
metadata: dict[str, Any] = field(default_factory=dict) metadata: DocumentMetadata = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if "source" not in self.metadata: if "source" not in self.metadata:
@ -43,7 +50,7 @@ class Document:
if "created_at" not in self.metadata: if "created_at" not in self.metadata:
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat() self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, object]:
return { return {
"doc_id": self.doc_id, "doc_id": self.doc_id,
"title": self.title, "title": self.title,
@ -136,12 +143,14 @@ class DocumentLoader:
parser = parsers.get(doc_format) parser = parsers.get(doc_format)
if parser is None: if parser is None:
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text") logger.warning(
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
)
parser = self._parse_text parser = self._parse_text
text, extra_meta = parser(content, filename) text, extra_meta = parser(content, filename)
metadata: dict[str, Any] = { metadata: DocumentMetadata = {
"source": filename, "source": filename,
"format": doc_format, "format": doc_format,
"created_at": datetime.now(timezone.utc).isoformat(), "created_at": datetime.now(timezone.utc).isoformat(),
@ -159,7 +168,7 @@ class DocumentLoader:
metadata=metadata, metadata=metadata,
) )
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
"""解析 PDF 文件 """解析 PDF 文件
优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本 优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本
@ -215,7 +224,7 @@ class DocumentLoader:
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction") logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Word 文件 """解析 Word 文件
使用 python-docx回退到纯文本 使用 python-docx回退到纯文本
@ -259,7 +268,7 @@ class DocumentLoader:
logger.warning(f"python-docx parsing failed for {filename}: {e}") logger.warning(f"python-docx parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Excel 文件 """解析 Excel 文件
使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格 使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格
@ -313,7 +322,7 @@ class DocumentLoader:
finally: finally:
wb.close() wb.close()
text = "\n".join(sections).strip() text = "\n".join(sections).strip()
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "openpyxl", "parser": "openpyxl",
"sheet_count": sheet_count, "sheet_count": sheet_count,
"row_count": total_rows, "row_count": total_rows,
@ -328,7 +337,7 @@ class DocumentLoader:
logger.warning(f"openpyxl parsing failed for {filename}: {e}") logger.warning(f"openpyxl parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
"""解析 Markdown 文件 """解析 Markdown 文件
使用 mistune如果可用否则直接读取文本 使用 mistune如果可用否则直接读取文本
@ -347,7 +356,7 @@ class DocumentLoader:
title = line_stripped.lstrip("#").strip() title = line_stripped.lstrip("#").strip()
break break
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "markdown", "parser": "markdown",
} }
if title: if title:
@ -362,7 +371,7 @@ class DocumentLoader:
return text, meta return text, meta
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_html(self, content: bytes, filename: str) -> ParseResult:
"""解析 HTML 文件 """解析 HTML 文件
使用 BeautifulSoup 提取文本回退到纯文本 使用 BeautifulSoup 提取文本回退到纯文本
@ -388,7 +397,7 @@ class DocumentLoader:
if soup.title and soup.title.string: if soup.title and soup.title.string:
title = soup.title.string.strip() title = soup.title.string.strip()
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "beautifulsoup", "parser": "beautifulsoup",
} }
if title: if title:
@ -402,7 +411,7 @@ class DocumentLoader:
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}") logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_text(self, content: bytes, filename: str) -> ParseResult:
"""解析纯文本文件""" """解析纯文本文件"""
try: try:
text = content.decode("utf-8") text = content.decode("utf-8")

View File

@ -1,12 +1,17 @@
"""Embedder 接口与实现 - 文本向量化""" """Embedder 接口与实现 - 文本向量化"""
from __future__ import annotations
import hashlib import hashlib
import logging import logging
import os import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import OrderedDict from collections import OrderedDict
from typing import Any from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
self._model = model self._model = model
self._base_url = base_url self._base_url = base_url
self._dimension = 1536 # text-embedding-3-small 默认维度 self._dimension = 1536 # text-embedding-3-small 默认维度
self._client: Any = None self._client: httpx.AsyncClient | None = None
self._cache = cache self._cache = cache
def _get_client(self): def _get_client(self) -> httpx.AsyncClient:
"""Lazily create and reuse a single httpx.AsyncClient.""" """Lazily create and reuse a single httpx.AsyncClient."""
if self._client is None: if self._client is None:
import httpx import httpx
self._client = httpx.AsyncClient(timeout=30.0) self._client = httpx.AsyncClient(timeout=30.0)
return self._client return self._client

View File

@ -1,17 +1,22 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆""" """Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
from __future__ import annotations
import json import json
import logging import logging
import math import math
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING
from sqlalchemy import text from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import Memory, MemoryItem, MetadataDict
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
def __init__( def __init__(
self, self,
session_factory: Any, session_factory: object,
episodic_model: Any, episodic_model: object,
embedder: Embedder | None = None, embedder: Embedder | None = None,
decay_rate: float = 0.01, decay_rate: float = 0.01,
alpha: float = 0.7, alpha: float = 0.7,
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._table_name = table_name self._table_name = table_name
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储任务经验""" """存储任务经验"""
async with self._session_factory() as db: async with self._session_factory() as db:
try: try:
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
embedding = None embedding = None
if self._embedder: if self._embedder:
if isinstance(value, dict): if isinstance(value, dict):
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500] text = (
value.get("output_summary", "")
or value.get("input_summary", "")
or json.dumps(value, ensure_ascii=False)[:500]
)
else: else:
text = str(value) text = str(value)
embedding = await self._embedder.embed(text) embedding = await self._embedder.embed(text)
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
logger.error(f"Failed to retrieve episodic memory: {e}") logger.error(f"Failed to retrieve episodic memory: {e}")
return None return None
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: async def _retrieve_pgvector(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""使用 pgvector ``<=>`` 算符检索最相似条目""" """使用 pgvector ``<=>`` 算符检索最相似条目"""
sql = text( sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
f"SELECT * FROM {self._table_name} "
f"ORDER BY embedding <=> :query_vec "
f"LIMIT :lim"
)
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1}) result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
row = result.mappings().first() row = result.mappings().first()
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
created_at=row.get("created_at") or datetime.now(timezone.utc), created_at=row.get("created_at") or datetime.now(timezone.utc),
) )
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: async def _retrieve_client_side(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""客户端 O(N) cosine similarity 检索(回退路径)""" """客户端 O(N) cosine similarity 检索(回退路径)"""
Model = self._episodic_model Model = self._episodic_model
from sqlalchemy import select from sqlalchemy import select
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
created_at=best_item.created_at or datetime.now(timezone.utc), created_at=best_item.created_at or datetime.now(timezone.utc),
) )
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]: async def search(
self,
query: str,
top_k: int = 5,
filters: MetadataDict | None = None,
search_multiplier: int = 5,
) -> list[MemoryItem]:
"""语义检索相似历史案例 """语义检索相似历史案例
Args: Args:
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
async def _search_pgvector( async def _search_pgvector(
self, self,
db: Any, db: AsyncSession,
query: str, query: str,
top_k: int, top_k: int,
filters: dict[str, Any] | None, filters: MetadataDict | None,
search_multiplier: int, search_multiplier: int,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排""" """使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
fetch_limit = top_k * search_multiplier fetch_limit = top_k * search_multiplier
where_clauses = [] where_clauses = []
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit} params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
filters = filters or {} filters = filters or {}
if filters.get("agent_name"): if filters.get("agent_name"):
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
items = [] items = []
for row in rows: for row in rows:
row_embedding = row.get("embedding") row_embedding = row.get("embedding")
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0 age_hours = (
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
if row.get("created_at")
else 0
)
decay = math.exp(-self._decay_rate * age_hours) decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (row.get("quality_score") or 0.5) * decay time_decay_score = (row.get("quality_score") or 0.5) * decay
@ -266,33 +285,37 @@ class EpisodicMemory(Memory):
else: else:
score = time_decay_score score = time_decay_score
items.append(MemoryItem( items.append(
key=str(row.get("id", "")), MemoryItem(
value={ key=str(row.get("id", "")),
"input_summary": row.get("input_summary", ""), value={
"output_summary": row.get("output_summary", ""), "input_summary": row.get("input_summary", ""),
"outcome": row.get("outcome", "success"), "output_summary": row.get("output_summary", ""),
"quality_score": row.get("quality_score", 0.5), "outcome": row.get("outcome", "success"),
"reflection": row.get("reflection", ""), "quality_score": row.get("quality_score", 0.5),
}, "reflection": row.get("reflection", ""),
metadata={ },
"agent_name": row.get("agent_name", ""), metadata={
"task_type": row.get("task_type", ""), "agent_name": row.get("agent_name", ""),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None, "task_type": row.get("task_type", ""),
}, "created_at": row["created_at"].isoformat()
score=score, if row.get("created_at")
created_at=row.get("created_at") or datetime.now(timezone.utc), else None,
)) },
score=score,
created_at=row.get("created_at") or datetime.now(timezone.utc),
)
)
items.sort(key=lambda x: x.score, reverse=True) items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k] return items[:top_k]
async def _search_client_side( async def _search_client_side(
self, self,
db: Any, db: AsyncSession,
query: str, query: str,
top_k: int, top_k: int,
filters: dict[str, Any] | None, filters: MetadataDict | None,
search_multiplier: int, search_multiplier: int,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""客户端 O(N) cosine similarity 检索(回退路径)""" """客户端 O(N) cosine similarity 检索(回退路径)"""
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
filters = filters or {} filters = filters or {}
from sqlalchemy import select from sqlalchemy import select
stmt = select(Model) stmt = select(Model)
if filters.get("agent_name"): if filters.get("agent_name"):
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
# 计算得分并构建 MemoryItem # 计算得分并构建 MemoryItem
items = [] items = []
for entry in entries: for entry in entries:
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0 age_hours = (
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
if entry.created_at
else 0
)
decay = math.exp(-self._decay_rate * age_hours) decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (entry.quality_score or 0.5) * decay time_decay_score = (entry.quality_score or 0.5) * decay
@ -333,30 +361,34 @@ class EpisodicMemory(Memory):
else: else:
score = time_decay_score score = time_decay_score
items.append(MemoryItem( items.append(
key=str(entry.id), MemoryItem(
value={ key=str(entry.id),
"input_summary": entry.input_summary, value={
"output_summary": entry.output_summary, "input_summary": entry.input_summary,
"outcome": entry.outcome, "output_summary": entry.output_summary,
"quality_score": entry.quality_score, "outcome": entry.outcome,
"reflection": entry.reflection, "quality_score": entry.quality_score,
}, "reflection": entry.reflection,
metadata={ },
"agent_name": entry.agent_name, metadata={
"task_type": entry.task_type, "agent_name": entry.agent_name,
"created_at": entry.created_at.isoformat() if entry.created_at else None, "task_type": entry.task_type,
}, "created_at": entry.created_at.isoformat() if entry.created_at else None,
score=score, },
created_at=entry.created_at or datetime.now(timezone.utc), score=score,
)) created_at=entry.created_at or datetime.now(timezone.utc),
)
)
items.sort(key=lambda x: x.score, reverse=True) items.sort(key=lambda x: x.score, reverse=True)
if len(items) < top_k: if len(items) < top_k:
logger.warning( logger.warning(
"EpisodicMemory.search returned %d results after scoring (top_k=%d). " "EpisodicMemory.search returned %d results after scoring (top_k=%d). "
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.", "Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
len(items), top_k, search_multiplier, len(items),
top_k,
search_multiplier,
) )
return items[:top_k] return items[:top_k]
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
"""删除指定经验""" """删除指定经验"""
async with self._session_factory() as db: async with self._session_factory() as db:
try: try:
from sqlalchemy import select, delete as sql_delete from sqlalchemy import delete as sql_delete
import uuid import uuid
Model = self._episodic_model Model = self._episodic_model
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key)) stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))

View File

@ -3,13 +3,26 @@
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接 配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
""" """
from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, TypeAlias
import httpx import httpx
from agentkit.memory.base import MetadataDict
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 标准化检索结果id/content/score/source/document_id/document_title/
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
RAGSearchResult: TypeAlias = dict[str, object]
# ingest() 写入的文档负载title/content/source_type/metadata。
RAGIngestPayload: TypeAlias = dict[str, object]
class HttpRAGService: class HttpRAGService:
"""HTTP 客户端,调用业务系统的知识库检索 API """HTTP 客户端,调用业务系统的知识库检索 API
@ -39,7 +52,7 @@ class HttpRAGService:
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
timeout: int = 30, timeout: int = 30,
contextual_chunking: bool = False, contextual_chunking: bool = False,
llm_gateway: Any = None, llm_gateway: LLMGateway | None = None,
): ):
""" """
Args: Args:
@ -74,7 +87,7 @@ class HttpRAGService:
query: str, query: str,
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
top_k: int = 5, top_k: int = 5,
) -> list[dict[str, Any]]: ) -> list[RAGSearchResult]:
"""语义检索知识库 """语义检索知识库
Args: Args:
@ -113,19 +126,23 @@ class HttpRAGService:
normalized = [] normalized = []
for r in results: for r in results:
if isinstance(r, dict): if isinstance(r, dict):
normalized.append({ normalized.append(
"id": r.get("chunk_id", r.get("id", "")), {
"content": r.get("content", ""), "id": r.get("chunk_id", r.get("id", "")),
"score": float(r.get("score", 0.0)), "content": r.get("content", ""),
"source": r.get("source", "rag"), "score": float(r.get("score", 0.0)),
"document_id": r.get("document_id", ""), "source": r.get("source", "rag"),
"document_title": r.get("document_title", ""), "document_id": r.get("document_id", ""),
"metadata": r.get("metadata", {}), "document_title": r.get("document_title", ""),
}) "metadata": r.get("metadata", {}),
}
)
return normalized return normalized
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}") logger.error(
f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return [] return []
except httpx.RequestError as e: except httpx.RequestError as e:
logger.error(f"RAG search request error: {e}") logger.error(f"RAG search request error: {e}")
@ -141,7 +158,7 @@ class HttpRAGService:
top_k: int = 5, top_k: int = 5,
use_rerank: bool = True, use_rerank: bool = True,
use_compression: bool = False, use_compression: bool = False,
) -> list[dict[str, Any]]: ) -> list[RAGSearchResult]:
"""增强语义检索知识库(支持 rerank 和 compression """增强语义检索知识库(支持 rerank 和 compression
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口 对每个知识库分别调用 /bases/{kb_id}/retrieve 接口
@ -169,7 +186,7 @@ class HttpRAGService:
} }
client = self._get_client() client = self._get_client()
all_results: list[dict[str, Any]] = [] all_results: list[RAGSearchResult] = []
for kb_id in kb_ids: for kb_id in kb_ids:
try: try:
@ -189,28 +206,27 @@ class HttpRAGService:
# 标准化 # 标准化
for r in results: for r in results:
if isinstance(r, dict): if isinstance(r, dict):
all_results.append({ all_results.append(
"id": r.get("chunk_id", r.get("id", "")), {
"content": r.get("content", ""), "id": r.get("chunk_id", r.get("id", "")),
"score": float(r.get("score", 0.0)), "content": r.get("content", ""),
"source": r.get("source", "rag"), "score": float(r.get("score", 0.0)),
"document_id": r.get("document_id", ""), "source": r.get("source", "rag"),
"document_title": r.get("document_title", ""), "document_id": r.get("document_id", ""),
"knowledge_base_id": kb_id, "document_title": r.get("document_title", ""),
"metadata": r.get("metadata", {}), "knowledge_base_id": kb_id,
}) "metadata": r.get("metadata", {}),
}
)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if e.response.status_code == 404: if e.response.status_code == 404:
# This KB doesn't support enhanced search — fall back to # This KB doesn't support enhanced search — fall back to
# standard search for THIS KB only, not all KBs. # standard search for THIS KB only, not all KBs.
logger.info( logger.info(
f"Enhanced search not available for KB {kb_id}, " f"Enhanced search not available for KB {kb_id}, using standard search"
f"using standard search"
)
std_result = await self.search(
query, knowledge_base_ids=[kb_id], top_k=top_k
) )
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
all_results.extend(std_result) all_results.extend(std_result)
else: else:
logger.error( logger.error(
@ -232,9 +248,9 @@ class HttpRAGService:
async def ingest( async def ingest(
self, self,
key: str, key: str,
value: Any, value: object,
metadata: dict[str, Any] | None = None, metadata: MetadataDict | None = None,
) -> dict[str, Any] | None: ) -> dict[str, object] | None:
"""写入文档到知识库(可选操作) """写入文档到知识库(可选操作)
When contextual_chunking is enabled and llm_gateway is configured, When contextual_chunking is enabled and llm_gateway is configured,
@ -308,5 +324,5 @@ class HttpRAGService:
async def __aenter__(self) -> "HttpRAGService": async def __aenter__(self) -> "HttpRAGService":
return self return self
async def __aexit__(self, *args: Any) -> None: async def __aexit__(self, *args: object) -> None:
await self.close() await self.close()

View File

@ -11,7 +11,11 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Protocol, runtime_checkable from typing import Protocol, TypeAlias, runtime_checkable
from agentkit.memory.base import MetadataDict
KBMetadata: TypeAlias = MetadataDict
@dataclass @dataclass
@ -22,7 +26,7 @@ class Document:
content: str content: str
title: str = "" title: str = ""
source_id: str = "" source_id: str = ""
metadata: dict[str, Any] = field(default_factory=dict) metadata: KBMetadata = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@ -34,7 +38,7 @@ class QueryResult:
source_id: str source_id: str
source_name: str source_name: str
score: float score: float
metadata: dict[str, Any] = field(default_factory=dict) metadata: KBMetadata = field(default_factory=dict)
doc_id: str = "" doc_id: str = ""
title: str = "" title: str = ""

View File

@ -11,25 +11,32 @@ from __future__ import annotations
import json import json
import logging import logging
import re import re
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING, TypeAlias
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import ( from agentkit.memory.knowledge_base import (
Document, Document,
KnowledgeBase,
QueryResult, QueryResult,
SourceInfo, SourceInfo,
) )
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# InMemoryLocalRAGService 内部存储的文档元信息结构。
# 字段title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
InMemoryDocInfo: TypeAlias = dict[str, object]
# 内部 chunk 存储结构content/embedding/metadata/source_doc_id。
InMemoryChunkInfo: TypeAlias = dict[str, object]
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document: def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
"""将 document_loader.Document 转换为 knowledge_base.Document""" """将 document_loader.Document 转换为 knowledge_base.Document"""
@ -53,7 +60,7 @@ class LocalRAGService:
def __init__( def __init__(
self, self,
session_factory: Any, session_factory: object,
embedder: Embedder, embedder: Embedder,
chunk_size: int = 1000, chunk_size: int = 1000,
chunk_overlap: int = 200, chunk_overlap: int = 200,
@ -75,10 +82,14 @@ class LocalRAGService:
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._table_name = table_name self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") raise ValueError(
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
)
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
async def ingest(self, documents: list[Document]) -> list[str]: async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表 """摄取文档列表
@ -136,9 +147,7 @@ class LocalRAGService:
try: try:
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
sql = sql_text( sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
)
await db.execute(sql, {"doc_id": id}) await db.execute(sql, {"doc_id": id})
await db.commit() await db.commit()
return True return True
@ -171,20 +180,15 @@ class LocalRAGService:
sources = [] sources = []
for row in rows: for row in rows:
meta = {} sources.append(
if row.get("doc_metadata"): SourceInfo(
try: source_id=row["source_doc_id"],
meta = json.loads(row["doc_metadata"]) source_name=row.get("source_title", ""),
except (json.JSONDecodeError, TypeError): source_type=row.get("doc_format", "local"),
pass document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
sources.append(SourceInfo( )
source_id=row["source_doc_id"], )
source_name=row.get("source_title", ""),
source_type=row.get("doc_format", "local"),
document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
))
return sources return sources
except Exception as e: except Exception as e:
logger.error(f"Failed to list sources: {e}") logger.error(f"Failed to list sources: {e}")
@ -271,7 +275,7 @@ class LocalRAGService:
async def _query_pgvector( async def _query_pgvector(
self, self,
db: Any, db: AsyncSession,
query_embedding: list[float], query_embedding: list[float],
top_k: int, top_k: int,
) -> list[QueryResult]: ) -> list[QueryResult]:
@ -286,10 +290,13 @@ class LocalRAGService:
f"LIMIT :lim" f"LIMIT :lim"
) )
result = await db.execute(sql, { result = await db.execute(
"query_vec": str(query_embedding), sql,
"lim": top_k, {
}) "query_vec": str(query_embedding),
"lim": top_k,
},
)
rows = result.mappings().all() rows = result.mappings().all()
results = [] results = []
@ -306,21 +313,23 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
results.append(QueryResult( results.append(
content=row["content"], QueryResult(
source_id=row["source_doc_id"], content=row["content"],
source_name=row.get("source_title", ""), source_id=row["source_doc_id"],
score=cosine, source_name=row.get("source_title", ""),
metadata=chunk_meta, score=cosine,
doc_id=row["source_doc_id"], metadata=chunk_meta,
title=row.get("source_title", ""), doc_id=row["source_doc_id"],
)) title=row.get("source_title", ""),
)
)
return results return results
async def _query_client_side( async def _query_client_side(
self, self,
db: Any, db: AsyncSession,
query_embedding: list[float], query_embedding: list[float],
top_k: int, top_k: int,
) -> list[QueryResult]: ) -> list[QueryResult]:
@ -363,15 +372,17 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
candidates.append(QueryResult( candidates.append(
content=row["content"], QueryResult(
source_id=row["source_doc_id"], content=row["content"],
source_name=row.get("source_title", ""), source_id=row["source_doc_id"],
score=cosine, source_name=row.get("source_title", ""),
metadata=chunk_meta, score=cosine,
doc_id=row["source_doc_id"], metadata=chunk_meta,
title=row.get("source_title", ""), doc_id=row["source_doc_id"],
)) title=row.get("source_title", ""),
)
)
candidates.sort(key=lambda x: x.score, reverse=True) candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k] return candidates[:top_k]
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
""" """
self._embedder = embedder self._embedder = embedder
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
# 内存存储 # 内存存储
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata} self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at} self._documents: dict[
str, InMemoryDocInfo
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
async def ingest(self, documents: list[Document]) -> list[str]: async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表 """摄取文档列表
@ -459,15 +474,17 @@ class InMemoryLocalRAGService:
source_doc_id = chunk_data["source_doc_id"] source_doc_id = chunk_data["source_doc_id"]
doc_info = self._documents.get(source_doc_id, {}) doc_info = self._documents.get(source_doc_id, {})
candidates.append(QueryResult( candidates.append(
content=chunk_data["content"], QueryResult(
source_id=source_doc_id, content=chunk_data["content"],
source_name=doc_info.get("title", ""), source_id=source_doc_id,
score=cosine, source_name=doc_info.get("title", ""),
metadata=chunk_data.get("metadata", {}), score=cosine,
doc_id=source_doc_id, metadata=chunk_data.get("metadata", {}),
title=doc_info.get("title", ""), doc_id=source_doc_id,
)) title=doc_info.get("title", ""),
)
)
candidates.sort(key=lambda x: x.score, reverse=True) candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k] return candidates[:top_k]
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
"""列出已摄取的文档""" """列出已摄取的文档"""
sources = [] sources = []
for doc_id, doc_info in self._documents.items(): for doc_id, doc_info in self._documents.items():
sources.append(SourceInfo( sources.append(
source_id=doc_id, SourceInfo(
source_name=doc_info["title"], source_id=doc_id,
source_type=doc_info.get("format", "local"), source_name=doc_info["title"],
document_count=len(doc_info.get("chunk_ids", [])), source_type=doc_info.get("format", "local"),
last_updated=doc_info.get("created_at"), document_count=len(doc_info.get("chunk_ids", [])),
)) last_updated=doc_info.get("created_at"),
)
)
return sources return sources
async def health_check(self) -> bool: async def health_check(self) -> bool:

View File

@ -13,7 +13,6 @@ import asyncio
import hashlib import hashlib
import logging import logging
from dataclasses import replace from dataclasses import replace
from typing import Any
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
@ -186,15 +185,13 @@ class MultiSourceRetriever:
Returns: Returns:
所有源的检索结果列表已应用权重 所有源的检索结果列表已应用权重
""" """
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]: async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
try: try:
results = await kb.query(query, top_k=top_k) results = await kb.query(query, top_k=top_k)
# 应用权重 # 应用权重
weight = (weights or {}).get(name, 1.0) weight = (weights or {}).get(name, 1.0)
return [ return [replace(r, score=r.score * weight, source_name=name) for r in results]
replace(r, score=r.score * weight, source_name=name)
for r in results
]
except Exception as e: except Exception as e:
logger.error(f"Query failed for source '{name}': {e}") logger.error(f"Query failed for source '{name}': {e}")
return [] return []

View File

@ -7,10 +7,10 @@
from __future__ import annotations from __future__ import annotations
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Callable
class MemoryFile: class MemoryFile:
@ -26,8 +26,9 @@ class MemoryFile:
""" """
def __init__(self, path: Path, char_budget: int | None = None, def __init__(
protected_sections: set[str] | None = None): self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
):
self.path = Path(path) self.path = Path(path)
self.char_budget = char_budget self.char_budget = char_budget
self._protected_sections = protected_sections or set() self._protected_sections = protected_sections or set()
@ -138,7 +139,7 @@ class MemoryFile:
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE): for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
name = match.group(1).strip() name = match.group(1).strip()
start = match.start() start = match.start()
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE) next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
if next_match: if next_match:
end = match.end() + next_match.start() end = match.end() + next_match.start()
else: else:
@ -146,7 +147,7 @@ class MemoryFile:
sections.append((name, start, end)) sections.append((name, start, end))
if not sections: if not sections:
return content[:self.char_budget] return content[: self.char_budget]
# 保持原始顺序,标记每个 section 是否受保护 # 保持原始顺序,标记每个 section 是否受保护
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected) ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
@ -222,8 +223,9 @@ class MemoryStore:
""" """
def __init__(self, base_dir: Path | str | None = None, def __init__(
on_change: Callable[[str], None] | None = None): self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
):
if base_dir is None: if base_dir is None:
base_dir = Path.home() / ".agentkit" base_dir = Path.home() / ".agentkit"
self.base_dir = Path(base_dir) self.base_dir = Path(base_dir)
@ -238,7 +240,9 @@ class MemoryStore:
protected_sections={"版本", "更新历史"}, protected_sections={"版本", "更新历史"},
) )
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET) self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET) self._memory = MemoryFile(
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
)
self._daily_dir = self.base_dir / "memories" / "daily" self._daily_dir = self.base_dir / "memories" / "daily"
self._daily_dir.mkdir(parents=True, exist_ok=True) self._daily_dir.mkdir(parents=True, exist_ok=True)
@ -376,4 +380,5 @@ class MemoryStore:
self._on_change(new_prompt) self._on_change(new_prompt)
except Exception: except Exception:
import logging import logging
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True) logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)

View File

@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
""" """
_FILLER_WORDS_CN: list[str] = [ _FILLER_WORDS_CN: list[str] = [
"帮我", "", "一下", "分析", "看看", "告诉我", "想知道", "请问", "帮我",
"",
"一下",
"分析",
"看看",
"告诉我",
"想知道",
"请问",
] ]
_FILLER_WORDS_EN: list[str] = [ _FILLER_WORDS_EN: list[str] = [
"please", "can you", "help me", "could you", "i want to", "i need to", "please",
"can you",
"help me",
"could you",
"i want to",
"i need to",
] ]
def __init__( def __init__(
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
self._synonyms = synonyms or {} self._synonyms = synonyms or {}
self._max_sub_queries = max_sub_queries self._max_sub_queries = max_sub_queries
# Pre-compile filler patterns # Pre-compile filler patterns
self._filler_patterns_cn = [ self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
]
self._filler_patterns_en = [ self._filler_patterns_en = [
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
] ]
@ -166,7 +176,9 @@ def create_query_transformer(
"""工厂函数:根据策略创建查询改写器""" """工厂函数:根据策略创建查询改写器"""
if strategy == "llm": if strategy == "llm":
if llm_gateway is None: if llm_gateway is None:
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp") logger.warning(
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
)
return NoOpQueryTransformer() return NoOpQueryTransformer()
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries) return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
elif strategy == "rule": elif strategy == "rule":

View File

@ -7,11 +7,11 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any from typing import TYPE_CHECKING
from agentkit.memory.base import MemoryItem from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import ( from agentkit.memory.relevance_scorer import (
RelevanceScorer, RelevanceScorer,
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
RetrievalEvaluation, RetrievalEvaluation,
) )
if TYPE_CHECKING:
# 避免与 retriever.py 形成运行时循环导入。
from agentkit.memory.retriever import MemoryRetriever
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
def __init__( def __init__(
self, self,
retriever: Any, # MemoryRetriever retriever: MemoryRetriever,
scorer: RelevanceScorer | None = None, scorer: RelevanceScorer | None = None,
query_transformer: QueryTransformerBase | None = None, query_transformer: QueryTransformerBase | None = None,
max_retries: int = 3, max_retries: int = 3,
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
query: str, query: str,
top_k: int = 5, top_k: int = 5,
token_budget: int = 3000, token_budget: int = 3000,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
) -> RAGLoopResult: ) -> RAGLoopResult:
"""执行带自纠正的检索 """执行带自纠正的检索
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
while retry_count <= self._max_retries: while retry_count <= self._max_retries:
# RETRIEVE # RETRIEVE
items = await self._retriever.retrieve( items = await self._retriever.retrieve(
current_query, top_k=top_k, token_budget=token_budget, current_query,
filters=filters, _skip_correction=True, top_k=top_k,
token_budget=token_budget,
filters=filters,
_skip_correction=True,
) )
# EVALUATE # EVALUATE
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
# CORRECT — rewrite query and retry # CORRECT — rewrite query and retry
retry_count += 1 retry_count += 1
if retry_count <= self._max_retries: if retry_count <= self._max_retries:
current_query = await self._rewrite_query( current_query = await self._rewrite_query(query, current_query, evaluation)
query, current_query, evaluation
)
continue continue
# DEGRADE — exceeded max retries # DEGRADE — exceeded max retries
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
# Degraded result: filter to relevant items and mark low confidence # Degraded result: filter to relevant items and mark low confidence
relevant_items = [ relevant_items = [
s.item s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
for s in evaluation.scores
if s.verdict != RelevanceVerdict.INCORRECT
] ]
result_items = relevant_items if relevant_items else items result_items = relevant_items if relevant_items else items

View File

@ -6,11 +6,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import math
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem from agentkit.memory.base import MemoryItem
@ -120,9 +118,7 @@ class RelevanceScorer:
reason=reason, reason=reason,
) )
def evaluate( def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
self, query: str, items: list[MemoryItem]
) -> RetrievalEvaluation:
"""评估一次检索的整体质量""" """评估一次检索的整体质量"""
if not items: if not items:
return RetrievalEvaluation( return RetrievalEvaluation(
@ -134,9 +130,7 @@ class RelevanceScorer:
) )
scores = [self.score_item(query, item) for item in items] scores = [self.score_item(query, item) for item in items]
relevant_count = sum( relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
)
avg_score = sum(s.score for s in scores) / len(scores) avg_score = sum(s.score for s in scores) / len(scores)
# Overall verdict based on average score and relevant ratio # Overall verdict based on average score and relevant ratio

View File

@ -7,19 +7,16 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import math
from dataclasses import replace from dataclasses import replace
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.working import WorkingMemory from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.query_transformer import QueryTransformerBase from agentkit.memory.query_transformer import QueryTransformerBase
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
from agentkit.memory.relevance_scorer import RelevanceScorer from agentkit.memory.relevance_scorer import RelevanceScorer
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult from agentkit.memory.knowledge_base import KnowledgeBase
from agentkit.memory.multi_source_retriever import MultiSourceRetriever from agentkit.memory.multi_source_retriever import MultiSourceRetriever
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
Chinese characters typically use 1-2 tokens each. Chinese characters typically use 1-2 tokens each.
English words typically use 1 token each. English words typically use 1 token each.
""" """
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
non_cjk = text non_cjk = text
for c in text: for c in text:
if '\u4e00' <= c <= '\u9fff': if "\u4e00" <= c <= "\u9fff":
non_cjk = non_cjk.replace(c, ' ') non_cjk = non_cjk.replace(c, " ")
word_count = len(non_cjk.split()) word_count = len(non_cjk.split())
return cjk_count * 2 + word_count return cjk_count * 2 + word_count
@ -89,7 +86,7 @@ class MemoryRetriever:
query: str, query: str,
top_k: int = 5, top_k: int = 5,
token_budget: int = 3000, token_budget: int = 3000,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
_skip_correction: bool = False, _skip_correction: bool = False,
sources: list[str] | None = None, sources: list[str] | None = None,
source_weights: dict[str, float] | None = None, source_weights: dict[str, float] | None = None,
@ -121,9 +118,7 @@ class MemoryRetriever:
query, top_k=top_k, token_budget=token_budget, filters=filters query, top_k=top_k, token_budget=token_budget, filters=filters
) )
if result.degraded: if result.degraded:
logger.warning( logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
f"RAG self-correction degraded after {result.total_retries} retries"
)
return result.items return result.items
# Query transformation # Query transformation
if self._query_transformer is not None: if self._query_transformer is not None:
@ -139,9 +134,7 @@ class MemoryRetriever:
# Sub-query search in parallel # Sub-query search in parallel
if sub_queries: if sub_queries:
sub_tasks = [ sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
self._search_layers(sq, top_k, filters) for sq in sub_queries
]
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True) sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
for result in sub_results: for result in sub_results:
if isinstance(result, Exception): if isinstance(result, Exception):
@ -178,7 +171,7 @@ class MemoryRetriever:
self, self,
query: str, query: str,
top_k: int = 5, top_k: int = 5,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""Search all configured memory layers with a single query""" """Search all configured memory layers with a single query"""
tasks = [] tasks = []
@ -237,18 +230,20 @@ class MemoryRetriever:
# QueryResult → MemoryItem # QueryResult → MemoryItem
items = [] items = []
for r in kb_results: for r in kb_results:
items.append(MemoryItem( items.append(
key=r.source_id, MemoryItem(
value=r.content, key=r.source_id,
metadata={ value=r.content,
**r.metadata, metadata={
"source": "rag", **r.metadata,
"source_name": r.source_name, "source": "rag",
"doc_id": r.doc_id, "source_name": r.source_name,
"document_title": r.title, "doc_id": r.doc_id,
}, "document_title": r.title,
score=r.score, },
)) score=r.score,
)
)
# Token 预算管理 # Token 预算管理
selected = [] selected = []
@ -318,7 +313,9 @@ class MemoryRetriever:
if source == "rag": if source == "rag":
kb_type = item.metadata.get("kb_type", "知识库") kb_type = item.metadata.get("kb_type", "知识库")
document_title = item.metadata.get("document_title", "未知文档") document_title = item.metadata.get("document_title", "未知文档")
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]" return (
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
)
elif source == "graph": elif source == "graph":
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]" return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
elif source == "episodic": elif source == "episodic":
@ -330,7 +327,7 @@ class MemoryRetriever:
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]" return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
async def store_episode( async def store_episode(
self, key: str, value: Any, metadata: dict[str, Any] | None = None self, key: str, value: object, metadata: MetadataDict | None = None
) -> None: ) -> None:
"""Store an episode into episodic memory if available. """Store an episode into episodic memory if available.
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
items = await self._retriever.retrieve(query, top_k=5) items = await self._retriever.retrieve(query, top_k=5)
results = [] results = []
for item in items: for item in items:
results.append({ results.append(
"content": item.value, {
"score": item.score, "content": item.value,
"source": item.metadata.get("source", "unknown"), "score": item.score,
"document_title": item.metadata.get("document_title", ""), "source": item.metadata.get("source", "unknown"),
}) "document_title": item.metadata.get("document_title", ""),
}
)
return {"query": query, "results": results, "call_count": self._call_count} return {"query": query, "results": results, "call_count": self._call_count}
except Exception as e: except Exception as e:
return {"error": str(e), "results": []} return {"error": str(e), "results": []}

View File

@ -3,14 +3,45 @@
适配器模式对接外部 RAG 服务和知识图谱 适配器模式对接外部 RAG 服务和知识图谱
""" """
import logging from __future__ import annotations
from typing import Any
from agentkit.memory.base import Memory, MemoryItem import logging
from typing import TYPE_CHECKING, Protocol
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
if TYPE_CHECKING:
from agentkit.memory.http_rag import RAGSearchResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _RAGServiceLike(Protocol):
"""RAG 检索服务最小接口契约duck-typed"""
async def search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
) -> list[RAGSearchResult]: ...
async def enhanced_search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
use_rerank: bool = ...,
use_compression: bool = ...,
) -> list[RAGSearchResult]: ...
class _GraphServiceLike(Protocol):
"""知识图谱服务最小接口契约duck-typed"""
async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
class SemanticMemory(Memory): class SemanticMemory(Memory):
"""Semantic Memory - 知识库检索 """Semantic Memory - 知识库检索
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
def __init__( def __init__(
self, self,
rag_service: Any = None, rag_service: _RAGServiceLike | None = None,
graph_service: Any = None, graph_service: _GraphServiceLike | None = None,
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
search_mode: str = "standard", search_mode: str = "standard",
use_rerank: bool = True, use_rerank: bool = True,
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
self._use_compression = use_compression self._use_compression = use_compression
self._kb_weights = kb_weights self._kb_weights = kb_weights
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法""" """Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
if self._rag_service and hasattr(self._rag_service, 'ingest'): if self._rag_service and hasattr(self._rag_service, "ingest"):
await self._rag_service.ingest(key, value, metadata) await self._rag_service.ingest(key, value, metadata)
else: else:
logger.warning("SemanticMemory.store: no RAG service configured for writing") logger.warning("SemanticMemory.store: no RAG service configured for writing")
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
"""按 key 精确检索Semantic Memory 通常不按 key 检索)""" """按 key 精确检索Semantic Memory 通常不按 key 检索)"""
return None return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索知识库""" """语义检索知识库"""
items = [] items = []
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
if self._rag_service: if self._rag_service:
try: try:
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids) kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"): if self._search_mode == "enhanced" and hasattr(
self._rag_service, "enhanced_search"
):
results = await self._rag_service.enhanced_search( results = await self._rag_service.enhanced_search(
query, query,
knowledge_base_ids=kb_ids, knowledge_base_ids=kb_ids,
@ -73,24 +108,28 @@ class SemanticMemory(Memory):
use_compression=self._use_compression, use_compression=self._use_compression,
) )
else: else:
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k) results = await self._rag_service.search(
query, knowledge_base_ids=kb_ids, top_k=top_k
)
for r in results: for r in results:
kb_id = r.get("knowledge_base_id", "") kb_id = r.get("knowledge_base_id", "")
score = r.get("score", 0.0) score = r.get("score", 0.0)
# Apply per-KB weights # Apply per-KB weights
if self._kb_weights and kb_id in self._kb_weights: if self._kb_weights and kb_id in self._kb_weights:
score *= self._kb_weights[kb_id] score *= self._kb_weights[kb_id]
items.append(MemoryItem( items.append(
key=r.get("id", ""), MemoryItem(
value=r.get("content", ""), key=r.get("id", ""),
metadata={ value=r.get("content", ""),
"source": r.get("source", "rag"), metadata={
"score": score, "source": r.get("source", "rag"),
"document_id": r.get("document_id"), "score": score,
"knowledge_base_id": kb_id, "document_id": r.get("document_id"),
}, "knowledge_base_id": kb_id,
score=score, },
)) score=score,
)
)
except Exception as e: except Exception as e:
logger.error(f"RAG search failed: {e}") logger.error(f"RAG search failed: {e}")
@ -99,16 +138,18 @@ class SemanticMemory(Memory):
try: try:
graph_results = await self._graph_service.query(query, depth=2) graph_results = await self._graph_service.query(query, depth=2)
for r in graph_results[:top_k]: for r in graph_results[:top_k]:
items.append(MemoryItem( items.append(
key=r.get("id", ""), MemoryItem(
value=r.get("content", ""), key=r.get("id", ""),
metadata={ value=r.get("content", ""),
"source": "graph", metadata={
"entities": r.get("entities", []), "source": "graph",
"relations": r.get("relations", []), "entities": r.get("entities", []),
}, "relations": r.get("relations", []),
score=r.get("score", 0.0), },
)) score=r.get("score", 0.0),
)
)
except Exception as e: except Exception as e:
logger.error(f"Graph search failed: {e}") logger.error(f"Graph search failed: {e}")

View File

@ -3,11 +3,10 @@
import json import json
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis import redis.asyncio as aioredis
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import Memory, MemoryItem, MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
def _make_key(self, key: str) -> str: def _make_key(self, key: str) -> str:
return f"{self._key_prefix}:{key}" return f"{self._key_prefix}:{key}"
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
redis_key = self._make_key(key) redis_key = self._make_key(key)
item = MemoryItem( item = MemoryItem(
key=key, key=key,
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
value=item_dict["value"], value=item_dict["value"],
metadata=item_dict.get("metadata", {}), metadata=item_dict.get("metadata", {}),
score=item_dict.get("score", 1.0), score=item_dict.get("score", 1.0),
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc), created_at=datetime.fromisoformat(item_dict["created_at"])
if item_dict.get("created_at")
else datetime.now(timezone.utc),
) )
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""Working Memory 不支持语义检索,按 key 前缀匹配""" """Working Memory 不支持语义检索,按 key 前缀匹配"""
pattern = self._make_key(f"{query}*") pattern = self._make_key(f"{query}*")
keys = [] keys = []
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
data = await self._redis.get(key) data = await self._redis.get(key)
if data: if data:
item_dict = json.loads(data) item_dict = json.loads(data)
items.append(MemoryItem( items.append(
key=item_dict["key"], MemoryItem(
value=item_dict["value"], key=item_dict["key"],
metadata=item_dict.get("metadata", {}), value=item_dict["value"],
score=1.0, metadata=item_dict.get("metadata", {}),
created_at=datetime.now(timezone.utc), score=1.0,
)) created_at=datetime.now(timezone.utc),
)
)
return items return items
async def delete(self, key: str) -> bool: async def delete(self, key: str) -> bool: