refactor(llm+memory+client): remove Any from type signatures
Eliminate 172 `Any` usages across llm/, memory/, and client/ via:
- TypeAlias definitions (MetadataValue, MetadataDict, RAGSearchResult, etc.)
- `object` for arbitrary dict/value types
- TYPE_CHECKING-only Protocols for the Redis/Quota/RAG/Graph services
- TYPE_CHECKING imports plus string annotations for forward references
- Removal of unused `Any` imports (18 F401 violations fixed)

Tests: 253 passed (the 21 llm failures are a pre-existing litellm environment issue).
ruff: all checks passed.
This commit is contained in:
parent
aa6367ff9f
commit
34a89c4873
|
|
@ -34,12 +34,19 @@ import os
|
|||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Callable, TypeAlias
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
|
||||
# 服务器返回的 skills/workflows 列表元素是 dict(model_dump/to_dict 输出),
|
||||
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
|
||||
SkillConfigDict: TypeAlias = dict[str, object]
|
||||
WorkflowConfigDict: TypeAlias = dict[str, object]
|
||||
SyncedConfigPayload: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
# ── Defaults ──────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -100,8 +107,8 @@ class ConfigSync:
|
|||
|
||||
# In-memory cache (mirrors the SQLite cache for fast access)
|
||||
self._version: str | None = None
|
||||
self._skills: list[dict[str, Any]] = []
|
||||
self._workflows: list[dict[str, Any]] = []
|
||||
self._skills: list[SkillConfigDict] = []
|
||||
self._workflows: list[WorkflowConfigDict] = []
|
||||
self._last_synced_at: str | None = None
|
||||
|
||||
# ── Lifecycle ─────────────────────────────────────────────────
|
||||
|
|
@ -232,15 +239,15 @@ class ConfigSync:
|
|||
"""Return the current cached config version hash."""
|
||||
return self._version
|
||||
|
||||
def get_skills(self) -> list[dict[str, Any]]:
|
||||
def get_skills(self) -> list[SkillConfigDict]:
|
||||
"""Return the cached skill configs."""
|
||||
return list(self._skills)
|
||||
|
||||
def get_workflows(self) -> list[dict[str, Any]]:
|
||||
def get_workflows(self) -> list[WorkflowConfigDict]:
|
||||
"""Return the cached workflow configs."""
|
||||
return list(self._workflows)
|
||||
|
||||
def get_all(self) -> dict[str, Any]:
|
||||
def get_all(self) -> SyncedConfigPayload:
|
||||
"""Return all cached configs as a single dict."""
|
||||
return {
|
||||
"version": self._version,
|
||||
|
|
@ -249,14 +256,14 @@ class ConfigSync:
|
|||
"synced_at": self._last_synced_at,
|
||||
}
|
||||
|
||||
def get_skill(self, name: str) -> dict[str, Any] | None:
|
||||
def get_skill(self, name: str) -> SkillConfigDict | None:
|
||||
"""Return a single skill config by name, or ``None``."""
|
||||
for skill in self._skills:
|
||||
if skill.get("name") == name:
|
||||
return skill
|
||||
return None
|
||||
|
||||
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None:
|
||||
def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
|
||||
"""Return a single workflow config by ID, or ``None``."""
|
||||
for wf in self._workflows:
|
||||
if wf.get("workflow_id") == workflow_id:
|
||||
|
|
@ -281,7 +288,7 @@ class ConfigSync:
|
|||
conn.executescript(_CACHE_SCHEMA)
|
||||
conn.commit()
|
||||
|
||||
def _save_to_cache(self, data: dict[str, Any]) -> None:
|
||||
def _save_to_cache(self, data: SyncedConfigPayload) -> None:
|
||||
"""Save the synced configs to the local SQLite cache."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with sqlite3.connect(str(self.cache_db_path)) as conn:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import logging
|
|||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
|
@ -25,6 +25,52 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的第三方对象
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _RedisLike(Protocol):
|
||||
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
|
||||
|
||||
async def get(self, key: str) -> bytes | str | None: ...
|
||||
|
||||
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
|
||||
|
||||
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
|
||||
|
||||
async def smembers(self, key: str) -> set[bytes | str]: ...
|
||||
|
||||
async def sadd(self, name: str, *values: str) -> int: ...
|
||||
|
||||
async def srem(self, name: str, *values: str) -> int: ...
|
||||
|
||||
async def scard(self, name: str) -> int: ...
|
||||
|
||||
async def delete(self, *names: str) -> int: ...
|
||||
|
||||
def pipeline(self) -> "_RedisPipelineLike": ...
|
||||
|
||||
class _RedisPipelineLike(Protocol):
|
||||
"""Redis pipeline 最小契约。"""
|
||||
|
||||
def get(self, key: str) -> "_RedisPipelineLike": ...
|
||||
|
||||
def set(
|
||||
self, key: str, value: bytes | str, ex: int | None = None
|
||||
) -> "_RedisPipelineLike": ...
|
||||
|
||||
def delete(self, *names: str) -> "_RedisPipelineLike": ...
|
||||
|
||||
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
|
||||
|
||||
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
|
||||
|
||||
async def execute(self) -> list[object]: ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -328,7 +374,7 @@ class RedisLLMCache:
|
|||
self._semantic_ttl = semantic_ttl
|
||||
self._similarity_threshold = similarity_threshold
|
||||
self._max_entries_to_scan = max_entries_to_scan
|
||||
self._redis: Any = None
|
||||
self._redis: _RedisLike | None = None
|
||||
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
|
||||
self._degraded = False # True if Redis is unreachable
|
||||
|
||||
|
|
@ -691,7 +737,7 @@ class LitellmCacheManager:
|
|||
|
||||
def __init__(self, config: LitellmCacheConfig):
|
||||
self._config = config
|
||||
self._cache_instance: Any = None # litellm.caching.Cache 实例
|
||||
self._cache_instance: object | None = None # litellm.caching.Cache 实例
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
|
||||
|
|
@ -709,7 +755,7 @@ class LitellmCacheManager:
|
|||
litellm.cache = None
|
||||
self._cache_instance = None
|
||||
|
||||
def _create_cache_instance(self) -> Any:
|
||||
def _create_cache_instance(self) -> object:
|
||||
"""根据 backend 配置创建 LiteLLM Cache 实例。
|
||||
|
||||
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。
|
||||
|
|
|
|||
|
|
@ -2,14 +2,13 @@
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def generate_cache_key(
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
temperature: float,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
tool_choice: str = "auto",
|
||||
max_tokens: int = 2000,
|
||||
user_id: str | None = None,
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@
|
|||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.channels.secrets import SecretsStore
|
||||
from agentkit.channels.secrets import SecretEntry, SecretsStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ class ProviderConfig:
|
|||
|
||||
api_key: str
|
||||
base_url: str
|
||||
models: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
models: dict[str, dict[str, object]] = field(default_factory=dict)
|
||||
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
||||
max_tokens: int = 4096 # Anthropic: default max_tokens
|
||||
timeout: float = 120.0 # Anthropic: request timeout
|
||||
|
|
@ -168,18 +168,18 @@ class ProviderConfig:
|
|||
return f"llm:provider:{self.type}:api_key"
|
||||
|
||||
@staticmethod
|
||||
def _encode_secret_entry(entry: Any, key: str) -> str:
|
||||
def _encode_secret_entry(entry: object, key: str) -> str:
|
||||
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
|
||||
# entry 是 SecretEntry pydantic 模型,有 model_dump()
|
||||
if hasattr(entry, "model_dump"):
|
||||
data = entry.model_dump()
|
||||
data = entry.model_dump() # type: ignore[attr-defined]
|
||||
else:
|
||||
data = dict(entry)
|
||||
data = dict(entry) # type: ignore[call-overload]
|
||||
data["key"] = key
|
||||
return json.dumps(data)
|
||||
|
||||
@staticmethod
|
||||
def _decode_secret_entry(encoded: str) -> Any:
|
||||
def _decode_secret_entry(encoded: str) -> "SecretEntry":
|
||||
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
|
||||
from agentkit.channels.secrets import SecretEntry
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||
from agentkit.llm.config import LLMConfig
|
||||
|
|
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
|||
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
|
||||
from agentkit.telemetry.metrics import llm_token_histogram
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.llm.cache import LitellmCacheManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的对象
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _QuotaServiceLike(Protocol):
|
||||
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
|
||||
|
||||
async def is_model_allowed(
|
||||
self, db: Path, department_id: str, model: str
|
||||
) -> tuple[bool, str]: ...
|
||||
|
||||
async def check_quota(
|
||||
self,
|
||||
db: Path,
|
||||
department_id: str,
|
||||
quota_type: str,
|
||||
period: str,
|
||||
current: float,
|
||||
) -> tuple[bool, str]: ...
|
||||
|
||||
async def get_quota(
|
||||
self, db: Path, department_id: str, quota_type: str, period: str
|
||||
) -> dict[str, object] | None: ...
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
|
||||
"""Raised when a department's LLM quota is exceeded.
|
||||
|
||||
|
|
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
|
|||
department_id: str,
|
||||
quota_type: str,
|
||||
period: str,
|
||||
limit: Any,
|
||||
current: Any,
|
||||
limit: object,
|
||||
current: object,
|
||||
) -> None:
|
||||
self.department_id = department_id
|
||||
self.quota_type = quota_type
|
||||
|
|
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
|
|||
class LLMGateway:
|
||||
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
||||
|
||||
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
|
||||
def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
|
||||
self._providers: dict[str, LLMProvider] = {}
|
||||
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
||||
self._config = config or LLMConfig()
|
||||
|
||||
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
|
||||
self._cache_manager: Any = None # LitellmCacheManager | None
|
||||
self._cache_manager: "LitellmCacheManager | None" = None
|
||||
if self._config.cache and self._config.cache.enabled:
|
||||
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||
|
||||
|
|
@ -601,7 +632,7 @@ class LLMGateway:
|
|||
|
||||
async def _check_quota_value(
|
||||
self,
|
||||
quota_service: Any,
|
||||
quota_service: _QuotaServiceLike,
|
||||
db: Path,
|
||||
dept_id: str,
|
||||
period: str,
|
||||
|
|
|
|||
|
|
@ -27,12 +27,11 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
|
||||
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
|
||||
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
|
||||
|
||||
流程:
|
||||
|
|
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
|
|||
|
||||
store = SecretsStore() # master key 从 env 加载
|
||||
|
||||
async def _run() -> dict[str, dict[str, Any]]:
|
||||
report: dict[str, dict[str, Any]] = {}
|
||||
async def _run() -> dict[str, dict[str, object]]:
|
||||
report: dict[str, dict[str, object]] = {}
|
||||
for name, pconf in llm_config.providers.items():
|
||||
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
|
||||
report[name] = {"status": "skipped", "source": pconf.api_key_source}
|
||||
|
|
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
|
|||
report = asyncio.run(_run())
|
||||
|
||||
# 写回 YAML:更新 llm.providers 段,保留其它段
|
||||
providers_out: dict[str, dict[str, Any]] = {}
|
||||
providers_out: dict[str, dict[str, object]] = {}
|
||||
for name, pconf in llm_config.providers.items():
|
||||
entry: dict[str, Any] = {
|
||||
entry: dict[str, object] = {
|
||||
"type": pconf.type,
|
||||
"base_url": pconf.base_url,
|
||||
"models": pconf.models,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -23,7 +22,7 @@ class ToolCall:
|
|||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
arguments: dict[str, object]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -32,7 +31,7 @@ class LLMRequest:
|
|||
|
||||
messages: list[dict[str, str]]
|
||||
model: str
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
tools: list[dict[str, object]] | None = None
|
||||
tool_choice: str = "auto"
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 2000
|
||||
|
|
@ -42,13 +41,13 @@ class LLMRequest:
|
|||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
tool_choice: str = "auto",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
timeout: float | None = None,
|
||||
cache: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
cache: dict[str, object] | None = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
self.messages = messages
|
||||
self.model = model
|
||||
|
|
@ -59,7 +58,7 @@ class LLMRequest:
|
|||
self.timeout = timeout
|
||||
self._extra = kwargs
|
||||
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
|
||||
self._cache: dict[str, Any] | None = cache
|
||||
self._cache: dict[str, object] | None = cache
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
|
|||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]:
|
||||
def _convert_messages(
|
||||
self, messages: list[dict[str, str]]
|
||||
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
|
||||
"""将 OpenAI 风格消息转换为 Anthropic 格式
|
||||
|
||||
Returns:
|
||||
|
|
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
|
|||
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
|
||||
- None: 无 system 消息
|
||||
"""
|
||||
system_prompt: str | list[dict[str, Any]] | None = None
|
||||
anthropic_messages: list[dict[str, Any]] = []
|
||||
system_prompt: str | list[dict[str, object]] | None = None
|
||||
anthropic_messages: list[dict[str, object]] = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
|
|
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
|
|||
# 检查是否有 tool_calls (OpenAI 格式)
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls:
|
||||
blocks: list[dict[str, Any]] = []
|
||||
blocks: list[dict[str, object]] = []
|
||||
# 如果有文本内容,先添加文本块
|
||||
if content:
|
||||
blocks.append({"type": "text", "text": content})
|
||||
|
|
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
|
|||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"raw": arguments}
|
||||
blocks.append({
|
||||
blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", ""),
|
||||
"name": func.get("name", ""),
|
||||
"input": arguments,
|
||||
})
|
||||
}
|
||||
)
|
||||
anthropic_messages.append({"role": "assistant", "content": blocks})
|
||||
else:
|
||||
anthropic_messages.append({
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
})
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
|
||||
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
||||
if msg.get("tool_call_id"):
|
||||
tool_result_blocks: list[dict[str, Any]] = []
|
||||
tool_result_blocks: list[dict[str, object]] = []
|
||||
tool_content = msg.get("content", "")
|
||||
# tool_result 的 content 可以是字符串或内容块列表
|
||||
if isinstance(tool_content, str):
|
||||
|
|
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
|
|||
else:
|
||||
tool_result_blocks.append({"type": "text", "text": str(tool_content)})
|
||||
|
||||
anthropic_messages.append({
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
"content": tool_result_blocks,
|
||||
}],
|
||||
})
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
anthropic_messages.append({
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
})
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
# OpenAI 格式中独立的 tool 消息
|
||||
tool_content = msg.get("content", "")
|
||||
if isinstance(tool_content, str):
|
||||
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}]
|
||||
result_content: list[dict[str, object]] | str = [
|
||||
{"type": "text", "text": tool_content}
|
||||
]
|
||||
elif isinstance(tool_content, list):
|
||||
result_content = tool_content
|
||||
else:
|
||||
result_content = [{"type": "text", "text": str(tool_content)}]
|
||||
|
||||
anthropic_messages.append({
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
"content": result_content,
|
||||
}],
|
||||
})
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return system_prompt, anthropic_messages
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""将 OpenAI function 格式转换为 Anthropic tool 格式"""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
anthropic_tools.append(
|
||||
{
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
})
|
||||
"input_schema": func.get(
|
||||
"parameters", {"type": "object", "properties": {}}
|
||||
),
|
||||
}
|
||||
)
|
||||
return anthropic_tools
|
||||
|
||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
|
||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
|
||||
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
|
||||
if tool_choice == "auto":
|
||||
return {"type": "auto"}
|
||||
|
|
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
|
|||
return {"type": "tool", "name": tool_choice}
|
||||
return None
|
||||
|
||||
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
|
||||
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
|
||||
"""将 Anthropic 响应转换为 LLMResponse"""
|
||||
content_blocks = data.get("content", [])
|
||||
text_parts: list[str] = []
|
||||
|
|
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
|
|||
if block_type == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif block_type == "tool_use":
|
||||
tool_calls.append(ToolCall(
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=block.get("id", ""),
|
||||
name=block.get("name", ""),
|
||||
arguments=block.get("input", {}),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
usage_data = data.get("usage", {})
|
||||
usage = TokenUsage(
|
||||
|
|
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
|
|||
|
||||
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload: dict[str, object] = {
|
||||
"model": request.model,
|
||||
"max_tokens": request.max_tokens or self._max_tokens,
|
||||
"messages": anthropic_messages,
|
||||
|
|
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
|
|||
|
||||
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload: dict[str, object] = {
|
||||
"model": request.model,
|
||||
"max_tokens": request.max_tokens or self._max_tokens,
|
||||
"messages": anthropic_messages,
|
||||
|
|
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
|
|||
async def _iterate_stream(self, response, request: LLMRequest):
|
||||
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
||||
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
|
||||
accumulated_tool_calls: dict[str, dict[str, Any]] = {}
|
||||
accumulated_tool_calls: dict[str, dict[str, object]] = {}
|
||||
current_tool_id: str | None = None
|
||||
current_tool_name: str | None = None
|
||||
current_tool_input_json: str = ""
|
||||
|
|
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
|
|||
# Finalize current tool call if any
|
||||
if current_tool_id is not None:
|
||||
try:
|
||||
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {}
|
||||
arguments = (
|
||||
json.loads(current_tool_input_json) if current_tool_input_json else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"raw": current_tool_input_json}
|
||||
|
||||
|
|
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
|
|||
error_msg = error_info.get("message", "Stream error")
|
||||
raise LLMProviderError("anthropic", error_msg)
|
||||
|
||||
def get_model_info(self) -> dict[str, Any]:
|
||||
def get_model_info(self) -> dict[str, object]:
|
||||
"""返回 Provider 和模型信息"""
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ API:火山引擎 OpenAI 兼容接口
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
|
||||
|
|
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
|
|||
api_key: str,
|
||||
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
||||
default_model: str = "doubao-pro-32k",
|
||||
**kwargs: Any,
|
||||
**kwargs: object,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
|
|||
|
||||
def _convert_messages(
|
||||
self, messages: list[dict[str, str]]
|
||||
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
|
||||
) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
|
||||
"""将 OpenAI 风格消息转换为 Gemini 格式
|
||||
|
||||
Returns:
|
||||
(system_instruction, contents)
|
||||
"""
|
||||
system_instruction: dict[str, Any] | None = None
|
||||
contents: list[dict[str, Any]] = []
|
||||
system_instruction: dict[str, object] | None = None
|
||||
contents: list[dict[str, object]] = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
|
|
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
|
|||
tool_name = parsed.get("name", "")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
contents.append({
|
||||
contents.append(
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": tool_name,
|
||||
"response": {
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
}],
|
||||
})
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
contents.append({
|
||||
contents.append(
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": content}],
|
||||
})
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls:
|
||||
parts: list[dict[str, Any]] = []
|
||||
parts: list[dict[str, object]] = []
|
||||
if content:
|
||||
parts.append({"text": content})
|
||||
for tc in tool_calls:
|
||||
|
|
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
|
|||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {"raw": arguments}
|
||||
parts.append({
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": func.get("name", ""),
|
||||
"args": arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
contents.append({"role": "model", "parts": parts})
|
||||
else:
|
||||
contents.append({
|
||||
contents.append(
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [{"text": content}],
|
||||
})
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
||||
tool_name = msg.get("name", "")
|
||||
tool_content = msg.get("content", "")
|
||||
contents.append({
|
||||
contents.append(
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": tool_name,
|
||||
"response": {
|
||||
"content": tool_content,
|
||||
},
|
||||
},
|
||||
}],
|
||||
})
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return system_instruction, contents
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""将 OpenAI function 格式转换为 Gemini functionDeclarations"""
|
||||
declarations = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool.get("function", {})
|
||||
declarations.append({
|
||||
declarations.append(
|
||||
{
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
})
|
||||
}
|
||||
)
|
||||
if not declarations:
|
||||
return []
|
||||
return [{"functionDeclarations": declarations}]
|
||||
|
||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
|
||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
|
||||
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
|
||||
if tool_choice == "auto":
|
||||
return {"functionCallingConfig": {"mode": "AUTO"}}
|
||||
|
|
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
|
|||
return {"functionCallingConfig": {"mode": "NONE"}}
|
||||
return None
|
||||
|
||||
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
|
||||
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
|
||||
"""将 Gemini 响应转换为 LLMResponse"""
|
||||
candidates = data.get("candidates", [])
|
||||
text_parts: list[str] = []
|
||||
|
|
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
|
|||
text_parts.append(part["text"])
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append(ToolCall(
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=f"call_{tool_call_index}",
|
||||
name=fc.get("name", ""),
|
||||
arguments=fc.get("args", {}),
|
||||
))
|
||||
)
|
||||
)
|
||||
tool_call_index += 1
|
||||
|
||||
usage_metadata = data.get("usageMetadata", {})
|
||||
|
|
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
|
|||
|
||||
system_instruction, contents = self._convert_messages(request.messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload: dict[str, object] = {
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
|
|
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
|
|||
|
||||
system_instruction, contents = self._convert_messages(request.messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload: dict[str, object] = {
|
||||
"contents": contents,
|
||||
"generationConfig": {
|
||||
"temperature": request.temperature,
|
||||
|
|
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
|
|||
|
||||
async def _iterate_stream(self, response, request: LLMRequest):
|
||||
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
accumulated_tool_calls: list[dict[str, object]] = []
|
||||
model = request.model or self._model
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
|
|
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
|
|||
)
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
accumulated_tool_calls.append({
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": f"call_{len(accumulated_tool_calls)}",
|
||||
"name": fc.get("name", ""),
|
||||
"arguments": fc.get("args", {}),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Check for finish reason
|
||||
finish_reason = candidates[0].get("finishReason", "")
|
||||
|
|
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
|
|||
)
|
||||
accumulated_tool_calls = []
|
||||
|
||||
def get_model_info(self) -> dict[str, Any]:
|
||||
def get_model_info(self) -> dict[str, object]:
|
||||
"""返回 Provider 和模型信息"""
|
||||
return {
|
||||
"provider": "gemini",
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
|
||||
from agentkit.core.exceptions import LLMProviderError
|
||||
from agentkit.llm.protocol import (
|
||||
|
|
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
|
|||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
provider_type: str = "openai",
|
||||
**default_kwargs: Any,
|
||||
**default_kwargs: object,
|
||||
) -> None:
|
||||
self._model_prefix = model_prefix
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url or None # 空字符串视作未设置
|
||||
self._provider_type = provider_type
|
||||
self._default_kwargs: dict[str, Any] = dict(default_kwargs)
|
||||
self._default_kwargs: dict[str, object] = dict(default_kwargs)
|
||||
|
||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
|
||||
|
|
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
|
|||
|
||||
kwargs = self._build_kwargs(request, stream=True)
|
||||
|
||||
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
accumulated_tool_calls: dict[int, dict[str, object]] = {}
|
||||
final_usage: TokenUsage | None = None
|
||||
final_model: str = request.model
|
||||
|
||||
|
|
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
|
|||
# 内部辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
|
||||
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
|
||||
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
|
||||
kwargs: dict[str, Any] = {
|
||||
kwargs: dict[str, object] = {
|
||||
"model": f"{self._model_prefix}{request.model}",
|
||||
"messages": request.messages,
|
||||
"temperature": request.temperature,
|
||||
|
|
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
|
|||
|
||||
def _parse_response(
|
||||
self,
|
||||
response: Any,
|
||||
response: object,
|
||||
request_model: str,
|
||||
latency_ms: float,
|
||||
) -> LLMResponse:
|
||||
|
|
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
|
|||
|
||||
def _parse_stream_chunk(
|
||||
self,
|
||||
chunk: Any,
|
||||
chunk: object,
|
||||
request_model: str,
|
||||
accumulated_tool_calls: dict[int, dict[str, Any]],
|
||||
accumulated_tool_calls: dict[int, dict[str, object]],
|
||||
) -> StreamChunk:
|
||||
"""解析单个流式 chunk(非 final)。累计 tool_calls 到传入字典。"""
|
||||
choices = getattr(chunk, "choices", None) or []
|
||||
|
|
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
|
|||
|
||||
def _finalize_tool_calls(
|
||||
self,
|
||||
accumulated: dict[int, dict[str, Any]],
|
||||
accumulated: dict[int, dict[str, object]],
|
||||
) -> list[ToolCall]:
|
||||
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
|
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
|
|||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
|
||||
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
|
||||
"""解析非流式响应的 tool_calls(OpenAI 格式 list[ChoiceMessageToolCall])。"""
|
||||
result: list[ToolCall] = []
|
||||
for tc in raw_tool_calls:
|
||||
|
|
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
|
|||
return result
|
||||
|
||||
|
||||
def _parse_usage(usage_obj: Any) -> TokenUsage:
|
||||
def _parse_usage(usage_obj: object) -> TokenUsage:
|
||||
"""解析 usage 对象(OpenAI CompletionUsage 或 dict)。"""
|
||||
prompt = getattr(usage_obj, "prompt_tokens", None)
|
||||
completion = getattr(usage_obj, "completion_tokens", None)
|
||||
|
|
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
|
|||
|
||||
|
||||
def _accumulate_stream_tool_calls(
|
||||
raw_tool_calls: Any,
|
||||
accumulated: dict[int, dict[str, Any]],
|
||||
raw_tool_calls: Iterable[object],
|
||||
accumulated: dict[int, dict[str, object]],
|
||||
) -> None:
|
||||
"""累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。
|
||||
|
||||
|
|
@ -364,7 +363,7 @@ def create_litellm_provider(
|
|||
provider_type: str,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
**kwargs: Any,
|
||||
**kwargs: object,
|
||||
) -> LitellmProvider:
|
||||
"""根据 provider_type 创建 LitellmProvider 实例。
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import json
|
|||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from agentkit.llm.protocol import TokenUsage
|
||||
|
||||
|
|
@ -294,8 +294,8 @@ class RedisUsageStore:
|
|||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379"):
|
||||
self._redis_url = redis_url
|
||||
self._redis: Any = None
|
||||
self._sync_redis: Any = None
|
||||
self._redis: object | None = None
|
||||
self._sync_redis: object | None = None
|
||||
self._fallback: InMemoryUsageStore | None = None
|
||||
self._degraded = False
|
||||
self._health_check_task: asyncio.Task[None] | None = None
|
||||
|
|
@ -687,7 +687,7 @@ class RedisUsageStore:
|
|||
|
||||
@staticmethod
|
||||
def _read_list(
|
||||
r: Any,
|
||||
r: object,
|
||||
list_key: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||
|
|
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
|
|||
secret_key: str | None = None,
|
||||
base_url: str = WENXIN_DEFAULT_BASE_URL,
|
||||
default_model: str = "ernie-4.5-turbo-128k",
|
||||
**kwargs: Any,
|
||||
**kwargs: object,
|
||||
):
|
||||
# If AK/SK provided, use token-based auth
|
||||
self._access_key = access_key
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ API:腾讯云 OpenAI 兼容接口
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||
|
|
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
|
|||
base_url: str = YUANBAO_DEFAULT_BASE_URL,
|
||||
default_model: str = "hunyuan-turbos-latest",
|
||||
enable_enhancement: bool = False,
|
||||
**kwargs: Any,
|
||||
**kwargs: object,
|
||||
):
|
||||
self._enable_enhancement = enable_enhancement
|
||||
super().__init__(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
|
|||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _build_payload(self, request: LLMRequest) -> dict[str, Any]:
|
||||
def _build_payload(self, request: LLMRequest) -> dict[str, object]:
|
||||
"""Convert LLMRequest to server API payload."""
|
||||
return {
|
||||
"messages": request.messages,
|
||||
|
|
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
|
|||
return str(body["error"])
|
||||
return str(body)
|
||||
|
||||
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse:
|
||||
def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
|
||||
"""Parse server response JSON into an LLMResponse."""
|
||||
usage_data = data.get("usage") or {}
|
||||
usage = TokenUsage(
|
||||
|
|
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
|
|||
latency_ms=data.get("latency_ms", 0.0),
|
||||
)
|
||||
|
||||
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk:
|
||||
def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
|
||||
"""Parse a single SSE data payload into a StreamChunk."""
|
||||
usage: TokenUsage | None = None
|
||||
usage_data = data.get("usage")
|
||||
|
|
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
|
|||
if response.status_code == 502:
|
||||
await response.aread()
|
||||
detail = self._extract_error_detail(response)
|
||||
raise LLMProviderError(
|
||||
"remote", f"Server LLM gateway error: {detail}"
|
||||
)
|
||||
raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
|
||||
if response.status_code != 200:
|
||||
await response.aread()
|
||||
raise LLMProviderError(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
from typing import Callable
|
||||
|
||||
from agentkit.core.exceptions import LLMProviderError
|
||||
|
||||
|
|
@ -20,9 +20,7 @@ class RetryConfig:
|
|||
base_delay: float = 1.0
|
||||
max_delay: float = 30.0
|
||||
exponential_base: float = 2.0
|
||||
retryable_status_codes: set[int] = field(
|
||||
default_factory=lambda: {429, 500, 502, 503, 529}
|
||||
)
|
||||
retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
|
|
@ -69,7 +67,7 @@ class RetryPolicy:
|
|||
def __init__(self, config: RetryConfig | None = None):
|
||||
self._config = config or RetryConfig()
|
||||
|
||||
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
|
||||
"""Execute fn with retry on retryable errors."""
|
||||
last_error: Exception | None = None
|
||||
|
||||
|
|
@ -84,7 +82,7 @@ class RetryPolicy:
|
|||
raise
|
||||
|
||||
delay = min(
|
||||
self._config.base_delay * (self._config.exponential_base ** attempt),
|
||||
self._config.base_delay * (self._config.exponential_base**attempt),
|
||||
self._config.max_delay,
|
||||
)
|
||||
logger.warning(
|
||||
|
|
@ -142,7 +140,7 @@ class CircuitBreaker:
|
|||
f"after {self._failure_count} failures"
|
||||
)
|
||||
|
||||
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
|
||||
"""Execute fn through the circuit breaker."""
|
||||
current_state = self.state
|
||||
|
||||
|
|
@ -158,6 +156,6 @@ class CircuitBreaker:
|
|||
result = await fn(*args, **kwargs)
|
||||
self._on_success()
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
self._on_failure()
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -95,8 +94,7 @@ class KBAdapter(ABC):
|
|||
async def delete_by_id(self, id: str) -> bool:
|
||||
"""按文档 ID 删除(子类可覆盖)"""
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} does not support delete_by_id; "
|
||||
f"id '{id}' skipped"
|
||||
f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
|
||||
)
|
||||
return False
|
||||
|
||||
|
|
@ -127,8 +125,7 @@ class KBAdapter(ABC):
|
|||
async def get_document(self, doc_id: str) -> Document | None:
|
||||
"""按 ID 获取单个文档(子类可覆盖)"""
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} does not support get_document; "
|
||||
f"doc_id '{doc_id}' not found"
|
||||
f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -156,5 +153,5 @@ class KBAdapter(ABC):
|
|||
async def __aenter__(self) -> KBAdapter:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
await self.close()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
|
|||
)
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
raise ValueError(
|
||||
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
|
||||
)
|
||||
self._username = username
|
||||
self._api_token = api_token
|
||||
self._space_keys = space_keys or []
|
||||
|
|
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
"""创建 Confluence API HTTP 客户端"""
|
||||
import base64
|
||||
|
||||
credentials = base64.b64encode(
|
||||
f"{self._username}:{self._api_token}".encode()
|
||||
).decode()
|
||||
credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
|
||||
return httpx.AsyncClient(
|
||||
base_url=self._base_url,
|
||||
headers={
|
||||
|
|
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
space_filter = " OR ".join(
|
||||
f'space = "{_escape_cql(key)}"' for key in self._space_keys
|
||||
)
|
||||
cql = f'{cql} AND ({space_filter})'
|
||||
cql = f"{cql} AND ({space_filter})"
|
||||
|
||||
resp = await client.get(
|
||||
"/rest/api/content/search",
|
||||
|
|
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
# Strip HTML tags for plain text content
|
||||
import re
|
||||
|
||||
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
|
||||
|
||||
results.append(
|
||||
|
|
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"Confluence search HTTP error: {e.response.status_code} — "
|
||||
f"{e.response.text[:200]}"
|
||||
f"Confluence search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
|
|
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
|
||||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
import re
|
||||
|
||||
content = re.sub(r"<[^>]+>", "", body) if body else ""
|
||||
|
||||
return Document(
|
||||
|
|
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
|
|||
source_type="confluence",
|
||||
)
|
||||
)
|
||||
return sources if sources else [
|
||||
return (
|
||||
sources
|
||||
if sources
|
||||
else [
|
||||
SourceInfo(
|
||||
source_id=self._source_id,
|
||||
source_name=self._source_name,
|
||||
source_type=self._source_type,
|
||||
)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Confluence list_sources error: {e}")
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 飞书搜索请求 payload:search_key/page_size/wiki_space_ids — 值为 str|int|list[str]。
|
||||
FeishuSearchPayload: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
class FeishuKBAdapter(KBAdapter):
|
||||
"""飞书知识库适配器
|
||||
|
|
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
|
|||
self._app_secret = app_secret
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
raise ValueError(
|
||||
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
|
||||
)
|
||||
self._space_ids = space_ids or []
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0.0
|
||||
|
|
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
self._client = None
|
||||
return self._access_token
|
||||
else:
|
||||
logger.error(
|
||||
f"Feishu auth failed: code={data.get('code')}, "
|
||||
f"msg={data.get('msg')}"
|
||||
)
|
||||
logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Feishu auth error: {e}")
|
||||
|
|
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
|
||||
client = self._get_client()
|
||||
try:
|
||||
payload: dict[str, Any] = {
|
||||
payload: FeishuSearchPayload = {
|
||||
"search_key": query,
|
||||
"page_size": top_k,
|
||||
}
|
||||
|
|
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
|
||||
if data.get("code") != 0:
|
||||
logger.error(
|
||||
f"Feishu search failed: code={data.get('code')}, "
|
||||
f"msg={data.get('msg')}"
|
||||
f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
|
||||
)
|
||||
return []
|
||||
|
||||
|
|
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"Feishu search HTTP error: {e.response.status_code} — "
|
||||
f"{e.response.text[:200]}"
|
||||
f"Feishu search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
|
|
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
client = self._get_client()
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"/wiki/v2/spaces/get_node",
|
||||
"/wiki/v2/spaces/get_node",
|
||||
params={"token": doc_id},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
|
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
|
|||
source_type="feishu",
|
||||
)
|
||||
)
|
||||
return sources if sources else [
|
||||
return (
|
||||
sources
|
||||
if sources
|
||||
else [
|
||||
SourceInfo(
|
||||
source_id=self._source_id,
|
||||
source_name=self._source_name,
|
||||
source_type=self._source_type,
|
||||
)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Feishu list_sources error: {e}")
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
)
|
||||
self._endpoint_url = endpoint_url.rstrip("/")
|
||||
if not is_safe_url(self._endpoint_url):
|
||||
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
|
||||
raise ValueError(
|
||||
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
|
||||
)
|
||||
self._auth_config = auth_config or {}
|
||||
self._extra_headers = headers or {}
|
||||
|
||||
|
|
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
headers["Authorization"] = f"Bearer {token}"
|
||||
elif auth_type == "basic":
|
||||
import base64
|
||||
|
||||
username = self._auth_config.get("username", "")
|
||||
password = self._auth_config.get("password", "")
|
||||
if username and password:
|
||||
credentials = base64.b64encode(
|
||||
f"{username}:{password}".encode()
|
||||
).decode()
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
elif auth_type == "api_key":
|
||||
key_name = self._auth_config.get("header_name", "X-API-Key")
|
||||
|
|
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"GenericHTTP search HTTP error: {e.response.status_code} — "
|
||||
f"{e.response.text[:200]}"
|
||||
f"GenericHTTP search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
|
|
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"GenericHTTP ingest HTTP error: {e.response.status_code} — "
|
||||
f"{e.response.text[:200]}"
|
||||
f"GenericHTTP ingest HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
|
|
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
document_count=item.get("document_count", 0),
|
||||
)
|
||||
)
|
||||
return sources if sources else [
|
||||
return (
|
||||
sources
|
||||
if sources
|
||||
else [
|
||||
SourceInfo(
|
||||
source_id=self._source_id,
|
||||
source_name=self._source_name,
|
||||
source_type=self._source_type,
|
||||
)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,19 +3,29 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
|
||||
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
|
||||
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
|
||||
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
|
||||
MetadataValue: TypeAlias = str | int | float | bool | None
|
||||
MetadataDict: TypeAlias = dict[str, MetadataValue]
|
||||
MemoryValue: TypeAlias = (
|
||||
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryItem:
|
||||
"""记忆条目"""
|
||||
|
||||
key: str
|
||||
value: Any
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
value: object
|
||||
metadata: MetadataDict = field(default_factory=dict)
|
||||
score: float = 1.0
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"key": self.key,
|
||||
"value": self.value,
|
||||
|
|
@ -35,7 +45,7 @@ class Memory(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||
"""存储记忆"""
|
||||
...
|
||||
|
||||
|
|
@ -45,7 +55,9 @@ class Memory(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||
async def search(
|
||||
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||
) -> list[MemoryItem]:
|
||||
"""语义检索"""
|
||||
...
|
||||
|
||||
|
|
@ -54,7 +66,7 @@ class Memory(ABC):
|
|||
"""删除记忆"""
|
||||
...
|
||||
|
||||
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
|
||||
async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
|
||||
"""批量存储"""
|
||||
for key, value, metadata in items:
|
||||
await self.store(key, value, metadata)
|
||||
|
|
|
|||
|
|
@ -11,10 +11,18 @@ import logging
|
|||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
|
||||
from agentkit.memory.base import MetadataDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 分块元数据:source_doc/position/char_count/chunking_strategy/heading/heading_level
|
||||
# — 全部为原始标量(str/int)。
|
||||
ChunkMetadata: TypeAlias = MetadataDict
|
||||
# _split_by_headings 返回的节段结构。
|
||||
SectionInfo: TypeAlias = dict[str, str | int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
|
|
@ -22,7 +30,7 @@ class Chunk:
|
|||
|
||||
chunk_id: str
|
||||
content: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: ChunkMetadata = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if "source_doc" not in self.metadata:
|
||||
|
|
@ -30,7 +38,7 @@ class Chunk:
|
|||
if "position" not in self.metadata:
|
||||
self.metadata["position"] = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"chunk_id": self.chunk_id,
|
||||
"content": self.content,
|
||||
|
|
@ -57,7 +65,9 @@ class TextChunker:
|
|||
separator: 优先分割符
|
||||
"""
|
||||
if chunk_overlap >= chunk_size:
|
||||
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})")
|
||||
raise ValueError(
|
||||
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
|
||||
)
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._separator = separator
|
||||
|
|
@ -66,7 +76,7 @@ class TextChunker:
|
|||
self,
|
||||
text: str,
|
||||
source_doc_id: str = "",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: ChunkMetadata | None = None,
|
||||
) -> list[Chunk]:
|
||||
"""将文本分块
|
||||
|
||||
|
|
@ -96,11 +106,13 @@ class TextChunker:
|
|||
chunk_meta = dict(base_meta)
|
||||
chunk_meta["position"] = i
|
||||
chunk_meta["char_count"] = len(chunk_text)
|
||||
chunks.append(Chunk(
|
||||
chunks.append(
|
||||
Chunk(
|
||||
chunk_id=str(uuid.uuid4()),
|
||||
content=chunk_text,
|
||||
metadata=chunk_meta,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
@ -142,7 +154,9 @@ class TextChunker:
|
|||
overlap_text[overlap_start:], segments
|
||||
)
|
||||
current = overlap_segments
|
||||
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1)
|
||||
current_len = sum(len(s) for s in current) + len(self._separator) * max(
|
||||
0, len(current) - 1
|
||||
)
|
||||
|
||||
current.append(segment)
|
||||
current_len += seg_len + len(self._separator)
|
||||
|
|
@ -214,7 +228,7 @@ class StructuralChunker:
|
|||
self,
|
||||
text: str,
|
||||
source_doc_id: str = "",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: ChunkMetadata | None = None,
|
||||
) -> list[Chunk]:
|
||||
"""将文本按结构分块
|
||||
|
||||
|
|
@ -266,23 +280,25 @@ class StructuralChunker:
|
|||
chunk_meta["heading"] = heading
|
||||
chunk_meta["heading_level"] = level
|
||||
chunk_meta["char_count"] = len(content)
|
||||
chunks.append(Chunk(
|
||||
chunks.append(
|
||||
Chunk(
|
||||
chunk_id=str(uuid.uuid4()),
|
||||
content=content,
|
||||
metadata=chunk_meta,
|
||||
))
|
||||
)
|
||||
)
|
||||
position += 1
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_by_headings(self, text: str) -> list[dict[str, Any]]:
|
||||
def _split_by_headings(self, text: str) -> list[SectionInfo]:
|
||||
"""按标题分割 Markdown 文本
|
||||
|
||||
Returns:
|
||||
列表,每项包含 heading, content, level
|
||||
"""
|
||||
lines = text.split("\n")
|
||||
sections: list[dict[str, Any]] = []
|
||||
sections: list[SectionInfo] = []
|
||||
current_heading = ""
|
||||
current_level = 0
|
||||
current_lines: list[str] = []
|
||||
|
|
@ -296,11 +312,13 @@ class StructuralChunker:
|
|||
if current_lines:
|
||||
content = "\n".join(current_lines).strip()
|
||||
if content:
|
||||
sections.append({
|
||||
sections.append(
|
||||
{
|
||||
"heading": current_heading,
|
||||
"content": content,
|
||||
"level": current_level,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# 开始新节
|
||||
current_heading = match.group(2).strip()
|
||||
|
|
@ -313,18 +331,22 @@ class StructuralChunker:
|
|||
if current_lines:
|
||||
content = "\n".join(current_lines).strip()
|
||||
if content:
|
||||
sections.append({
|
||||
sections.append(
|
||||
{
|
||||
"heading": current_heading,
|
||||
"content": content,
|
||||
"level": current_level,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# 如果没有标题结构,整体作为一个块
|
||||
if not sections:
|
||||
sections.append({
|
||||
sections.append(
|
||||
{
|
||||
"heading": "",
|
||||
"content": text.strip(),
|
||||
"level": 0,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return sections
|
||||
|
|
|
|||
|
|
@ -9,10 +9,14 @@ from __future__ import annotations
|
|||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from agentkit.memory.base import MetadataDict
|
||||
from agentkit.memory.embedder import EmbeddingCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -24,7 +28,7 @@ class ContextualChunk:
|
|||
context_prefix: str
|
||||
enhanced_content: str
|
||||
chunk_index: int
|
||||
metadata: dict[str, Any]
|
||||
metadata: MetadataDict
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
|
|
@ -65,7 +69,7 @@ class ContextualChunker:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
llm_gateway: Any = None,
|
||||
llm_gateway: LLMGateway | None = None,
|
||||
cache: EmbeddingCache | None = None,
|
||||
batch_size: int = 8,
|
||||
max_context_length: int = 200,
|
||||
|
|
@ -90,7 +94,7 @@ class ContextualChunker:
|
|||
self,
|
||||
document: str,
|
||||
chunks: list[str],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: MetadataDict | None = None,
|
||||
) -> list[ContextualChunk]:
|
||||
"""为文档块添加上下文前缀
|
||||
|
||||
|
|
@ -134,7 +138,7 @@ class ContextualChunker:
|
|||
document: str,
|
||||
chunks: list[str],
|
||||
start_index: int,
|
||||
metadata: dict[str, Any] | None,
|
||||
metadata: MetadataDict | None,
|
||||
) -> list[ContextualChunk]:
|
||||
"""处理一批文档块"""
|
||||
results: list[ContextualChunk] = []
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ import uuid
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
|
||||
from agentkit.memory.base import MetadataDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
|
|||
MAX_ROWS_PER_SHEET = 10_000
|
||||
MAX_CELL_CHARS = 10_000
|
||||
|
||||
# 文档元数据:source/format/parser/page_count/table_count/sheet_count/row_count/
|
||||
# heading_count/created_at/title/truncated — 全部为原始标量。
|
||||
DocumentMetadata: TypeAlias = MetadataDict
|
||||
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
|
|
@ -31,7 +38,7 @@ class Document:
|
|||
doc_id: str
|
||||
title: str
|
||||
content: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: DocumentMetadata = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if "source" not in self.metadata:
|
||||
|
|
@ -43,7 +50,7 @@ class Document:
|
|||
if "created_at" not in self.metadata:
|
||||
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"doc_id": self.doc_id,
|
||||
"title": self.title,
|
||||
|
|
@ -136,12 +143,14 @@ class DocumentLoader:
|
|||
|
||||
parser = parsers.get(doc_format)
|
||||
if parser is None:
|
||||
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text")
|
||||
logger.warning(
|
||||
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
|
||||
)
|
||||
parser = self._parse_text
|
||||
|
||||
text, extra_meta = parser(content, filename)
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
metadata: DocumentMetadata = {
|
||||
"source": filename,
|
||||
"format": doc_format,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
|
|
@ -159,7 +168,7 @@ class DocumentLoader:
|
|||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析 PDF 文件
|
||||
|
||||
优先使用 PyMuPDF (fitz),回退到 pdfplumber,最终回退到纯文本。
|
||||
|
|
@ -215,7 +224,7 @@ class DocumentLoader:
|
|||
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
|
||||
return self._parse_text(content, filename)
|
||||
|
||||
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析 Word 文件
|
||||
|
||||
使用 python-docx,回退到纯文本。
|
||||
|
|
@ -259,7 +268,7 @@ class DocumentLoader:
|
|||
logger.warning(f"python-docx parsing failed for {filename}: {e}")
|
||||
return self._parse_text(content, filename)
|
||||
|
||||
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析 Excel 文件
|
||||
|
||||
使用 openpyxl,回退到纯文本。每个 sheet 转为 Markdown 表格,
|
||||
|
|
@ -313,7 +322,7 @@ class DocumentLoader:
|
|||
finally:
|
||||
wb.close()
|
||||
text = "\n".join(sections).strip()
|
||||
meta: dict[str, Any] = {
|
||||
meta: DocumentMetadata = {
|
||||
"parser": "openpyxl",
|
||||
"sheet_count": sheet_count,
|
||||
"row_count": total_rows,
|
||||
|
|
@ -328,7 +337,7 @@ class DocumentLoader:
|
|||
logger.warning(f"openpyxl parsing failed for {filename}: {e}")
|
||||
return self._parse_text(content, filename)
|
||||
|
||||
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析 Markdown 文件
|
||||
|
||||
使用 mistune(如果可用),否则直接读取文本。
|
||||
|
|
@ -347,7 +356,7 @@ class DocumentLoader:
|
|||
title = line_stripped.lstrip("#").strip()
|
||||
break
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
meta: DocumentMetadata = {
|
||||
"parser": "markdown",
|
||||
}
|
||||
if title:
|
||||
|
|
@ -362,7 +371,7 @@ class DocumentLoader:
|
|||
|
||||
return text, meta
|
||||
|
||||
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_html(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析 HTML 文件
|
||||
|
||||
使用 BeautifulSoup 提取文本,回退到纯文本。
|
||||
|
|
@ -388,7 +397,7 @@ class DocumentLoader:
|
|||
if soup.title and soup.title.string:
|
||||
title = soup.title.string.strip()
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
meta: DocumentMetadata = {
|
||||
"parser": "beautifulsoup",
|
||||
}
|
||||
if title:
|
||||
|
|
@ -402,7 +411,7 @@ class DocumentLoader:
|
|||
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
|
||||
return self._parse_text(content, filename)
|
||||
|
||||
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
||||
def _parse_text(self, content: bytes, filename: str) -> ParseResult:
|
||||
"""解析纯文本文件"""
|
||||
try:
|
||||
text = content.decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
"""Embedder 接口与实现 - 文本向量化"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
|
|||
self._model = model
|
||||
self._base_url = base_url
|
||||
self._dimension = 1536 # text-embedding-3-small 默认维度
|
||||
self._client: Any = None
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._cache = cache
|
||||
|
||||
def _get_client(self):
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Lazily create and reuse a single httpx.AsyncClient."""
|
||||
if self._client is None:
|
||||
import httpx
|
||||
|
||||
self._client = httpx.AsyncClient(timeout=30.0)
|
||||
return self._client
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,22 @@
|
|||
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||
from agentkit.memory.embedder import Embedder
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Any,
|
||||
episodic_model: Any,
|
||||
session_factory: object,
|
||||
episodic_model: object,
|
||||
embedder: Embedder | None = None,
|
||||
decay_rate: float = 0.01,
|
||||
alpha: float = 0.7,
|
||||
|
|
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
|
|||
self._pgvector_enabled = pgvector_enabled
|
||||
self._table_name = table_name
|
||||
|
||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||
"""存储任务经验"""
|
||||
async with self._session_factory() as db:
|
||||
try:
|
||||
|
|
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
|
|||
embedding = None
|
||||
if self._embedder:
|
||||
if isinstance(value, dict):
|
||||
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
|
||||
text = (
|
||||
value.get("output_summary", "")
|
||||
or value.get("input_summary", "")
|
||||
or json.dumps(value, ensure_ascii=False)[:500]
|
||||
)
|
||||
else:
|
||||
text = str(value)
|
||||
embedding = await self._embedder.embed(text)
|
||||
|
|
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
|
|||
logger.error(f"Failed to retrieve episodic memory: {e}")
|
||||
return None
|
||||
|
||||
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
|
||||
async def _retrieve_pgvector(
|
||||
self, db: AsyncSession, query_embedding: list[float]
|
||||
) -> MemoryItem | None:
|
||||
"""使用 pgvector ``<=>`` 算符检索最相似条目"""
|
||||
sql = text(
|
||||
f"SELECT * FROM {self._table_name} "
|
||||
f"ORDER BY embedding <=> :query_vec "
|
||||
f"LIMIT :lim"
|
||||
)
|
||||
sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
|
||||
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
|
||||
row = result.mappings().first()
|
||||
|
||||
|
|
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
|
|||
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
|
||||
async def _retrieve_client_side(
|
||||
self, db: AsyncSession, query_embedding: list[float]
|
||||
) -> MemoryItem | None:
|
||||
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
||||
Model = self._episodic_model
|
||||
from sqlalchemy import select
|
||||
|
|
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
|
|||
created_at=best_item.created_at or datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filters: MetadataDict | None = None,
|
||||
search_multiplier: int = 5,
|
||||
) -> list[MemoryItem]:
|
||||
"""语义检索相似历史案例
|
||||
|
||||
Args:
|
||||
|
|
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
|
|||
|
||||
async def _search_pgvector(
|
||||
self,
|
||||
db: Any,
|
||||
db: AsyncSession,
|
||||
query: str,
|
||||
top_k: int,
|
||||
filters: dict[str, Any] | None,
|
||||
filters: MetadataDict | None,
|
||||
search_multiplier: int,
|
||||
) -> list[MemoryItem]:
|
||||
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
|
||||
|
|
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
|
|||
fetch_limit = top_k * search_multiplier
|
||||
|
||||
where_clauses = []
|
||||
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit}
|
||||
params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
|
||||
|
||||
filters = filters or {}
|
||||
if filters.get("agent_name"):
|
||||
|
|
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
|
|||
items = []
|
||||
for row in rows:
|
||||
row_embedding = row.get("embedding")
|
||||
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0
|
||||
age_hours = (
|
||||
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
|
||||
if row.get("created_at")
|
||||
else 0
|
||||
)
|
||||
decay = math.exp(-self._decay_rate * age_hours)
|
||||
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
||||
|
||||
|
|
@ -266,7 +285,8 @@ class EpisodicMemory(Memory):
|
|||
else:
|
||||
score = time_decay_score
|
||||
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=str(row.get("id", "")),
|
||||
value={
|
||||
"input_summary": row.get("input_summary", ""),
|
||||
|
|
@ -278,21 +298,24 @@ class EpisodicMemory(Memory):
|
|||
metadata={
|
||||
"agent_name": row.get("agent_name", ""),
|
||||
"task_type": row.get("task_type", ""),
|
||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
||||
"created_at": row["created_at"].isoformat()
|
||||
if row.get("created_at")
|
||||
else None,
|
||||
},
|
||||
score=score,
|
||||
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
items.sort(key=lambda x: x.score, reverse=True)
|
||||
return items[:top_k]
|
||||
|
||||
async def _search_client_side(
|
||||
self,
|
||||
db: Any,
|
||||
db: AsyncSession,
|
||||
query: str,
|
||||
top_k: int,
|
||||
filters: dict[str, Any] | None,
|
||||
filters: MetadataDict | None,
|
||||
search_multiplier: int,
|
||||
) -> list[MemoryItem]:
|
||||
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
||||
|
|
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
|
|||
filters = filters or {}
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(Model)
|
||||
|
||||
if filters.get("agent_name"):
|
||||
|
|
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
|
|||
# 计算得分并构建 MemoryItem
|
||||
items = []
|
||||
for entry in entries:
|
||||
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
|
||||
age_hours = (
|
||||
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
|
||||
if entry.created_at
|
||||
else 0
|
||||
)
|
||||
decay = math.exp(-self._decay_rate * age_hours)
|
||||
time_decay_score = (entry.quality_score or 0.5) * decay
|
||||
|
||||
|
|
@ -333,7 +361,8 @@ class EpisodicMemory(Memory):
|
|||
else:
|
||||
score = time_decay_score
|
||||
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=str(entry.id),
|
||||
value={
|
||||
"input_summary": entry.input_summary,
|
||||
|
|
@ -349,14 +378,17 @@ class EpisodicMemory(Memory):
|
|||
},
|
||||
score=score,
|
||||
created_at=entry.created_at or datetime.now(timezone.utc),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
items.sort(key=lambda x: x.score, reverse=True)
|
||||
if len(items) < top_k:
|
||||
logger.warning(
|
||||
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
|
||||
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
|
||||
len(items), top_k, search_multiplier,
|
||||
len(items),
|
||||
top_k,
|
||||
search_multiplier,
|
||||
)
|
||||
return items[:top_k]
|
||||
|
||||
|
|
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
|
|||
"""删除指定经验"""
|
||||
async with self._session_factory() as db:
|
||||
try:
|
||||
from sqlalchemy import select, delete as sql_delete
|
||||
from sqlalchemy import delete as sql_delete
|
||||
import uuid
|
||||
|
||||
Model = self._episodic_model
|
||||
|
||||
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))
|
||||
|
|
|
|||
|
|
@ -3,13 +3,26 @@
|
|||
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.memory.base import MetadataDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 标准化检索结果:id/content/score/source/document_id/document_title/
|
||||
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
|
||||
RAGSearchResult: TypeAlias = dict[str, object]
|
||||
# ingest() 写入的文档负载:title/content/source_type/metadata。
|
||||
RAGIngestPayload: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
class HttpRAGService:
|
||||
"""HTTP 客户端,调用业务系统的知识库检索 API
|
||||
|
|
@ -39,7 +52,7 @@ class HttpRAGService:
|
|||
knowledge_base_ids: list[str] | None = None,
|
||||
timeout: int = 30,
|
||||
contextual_chunking: bool = False,
|
||||
llm_gateway: Any = None,
|
||||
llm_gateway: LLMGateway | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -74,7 +87,7 @@ class HttpRAGService:
|
|||
query: str,
|
||||
knowledge_base_ids: list[str] | None = None,
|
||||
top_k: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[RAGSearchResult]:
|
||||
"""语义检索知识库
|
||||
|
||||
Args:
|
||||
|
|
@ -113,7 +126,8 @@ class HttpRAGService:
|
|||
normalized = []
|
||||
for r in results:
|
||||
if isinstance(r, dict):
|
||||
normalized.append({
|
||||
normalized.append(
|
||||
{
|
||||
"id": r.get("chunk_id", r.get("id", "")),
|
||||
"content": r.get("content", ""),
|
||||
"score": float(r.get("score", 0.0)),
|
||||
|
|
@ -121,11 +135,14 @@ class HttpRAGService:
|
|||
"document_id": r.get("document_id", ""),
|
||||
"document_title": r.get("document_title", ""),
|
||||
"metadata": r.get("metadata", {}),
|
||||
})
|
||||
}
|
||||
)
|
||||
return normalized
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}")
|
||||
logger.error(
|
||||
f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
)
|
||||
return []
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"RAG search request error: {e}")
|
||||
|
|
@ -141,7 +158,7 @@ class HttpRAGService:
|
|||
top_k: int = 5,
|
||||
use_rerank: bool = True,
|
||||
use_compression: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[RAGSearchResult]:
|
||||
"""增强语义检索知识库(支持 rerank 和 compression)
|
||||
|
||||
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口,
|
||||
|
|
@ -169,7 +186,7 @@ class HttpRAGService:
|
|||
}
|
||||
|
||||
client = self._get_client()
|
||||
all_results: list[dict[str, Any]] = []
|
||||
all_results: list[RAGSearchResult] = []
|
||||
|
||||
for kb_id in kb_ids:
|
||||
try:
|
||||
|
|
@ -189,7 +206,8 @@ class HttpRAGService:
|
|||
# 标准化
|
||||
for r in results:
|
||||
if isinstance(r, dict):
|
||||
all_results.append({
|
||||
all_results.append(
|
||||
{
|
||||
"id": r.get("chunk_id", r.get("id", "")),
|
||||
"content": r.get("content", ""),
|
||||
"score": float(r.get("score", 0.0)),
|
||||
|
|
@ -198,19 +216,17 @@ class HttpRAGService:
|
|||
"document_title": r.get("document_title", ""),
|
||||
"knowledge_base_id": kb_id,
|
||||
"metadata": r.get("metadata", {}),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 404:
|
||||
# This KB doesn't support enhanced search — fall back to
|
||||
# standard search for THIS KB only, not all KBs.
|
||||
logger.info(
|
||||
f"Enhanced search not available for KB {kb_id}, "
|
||||
f"using standard search"
|
||||
)
|
||||
std_result = await self.search(
|
||||
query, knowledge_base_ids=[kb_id], top_k=top_k
|
||||
f"Enhanced search not available for KB {kb_id}, using standard search"
|
||||
)
|
||||
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
|
||||
all_results.extend(std_result)
|
||||
else:
|
||||
logger.error(
|
||||
|
|
@ -232,9 +248,9 @@ class HttpRAGService:
|
|||
async def ingest(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
value: object,
|
||||
metadata: MetadataDict | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
"""写入文档到知识库(可选操作)
|
||||
|
||||
When contextual_chunking is enabled and llm_gateway is configured,
|
||||
|
|
@ -308,5 +324,5 @@ class HttpRAGService:
|
|||
async def __aenter__(self) -> "HttpRAGService":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
await self.close()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,11 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from typing import Protocol, TypeAlias, runtime_checkable
|
||||
|
||||
from agentkit.memory.base import MetadataDict
|
||||
|
||||
KBMetadata: TypeAlias = MetadataDict
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -22,7 +26,7 @@ class Document:
|
|||
content: str
|
||||
title: str = ""
|
||||
source_id: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: KBMetadata = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
|
|
@ -34,7 +38,7 @@ class QueryResult:
|
|||
source_id: str
|
||||
source_name: str
|
||||
score: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: KBMetadata = field(default_factory=dict)
|
||||
doc_id: str = ""
|
||||
title: str = ""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,25 +11,32 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
|
||||
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
||||
from agentkit.memory.document_loader import Document as LoaderDocument
|
||||
from agentkit.memory.embedder import Embedder
|
||||
from agentkit.memory.knowledge_base import (
|
||||
Document,
|
||||
KnowledgeBase,
|
||||
QueryResult,
|
||||
SourceInfo,
|
||||
)
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# InMemoryLocalRAGService 内部存储的文档元信息结构。
|
||||
# 字段:title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
|
||||
InMemoryDocInfo: TypeAlias = dict[str, object]
|
||||
# 内部 chunk 存储结构:content/embedding/metadata/source_doc_id。
|
||||
InMemoryChunkInfo: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
|
||||
"""将 document_loader.Document 转换为 knowledge_base.Document"""
|
||||
|
|
@ -53,7 +60,7 @@ class LocalRAGService:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Any,
|
||||
session_factory: object,
|
||||
embedder: Embedder,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
|
|
@ -75,10 +82,14 @@ class LocalRAGService:
|
|||
self._chunk_overlap = chunk_overlap
|
||||
self._table_name = table_name
|
||||
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
||||
raise ValueError(
|
||||
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
|
||||
)
|
||||
self._pgvector_enabled = pgvector_enabled
|
||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
self._structural_chunker = StructuralChunker(
|
||||
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
async def ingest(self, documents: list[Document]) -> list[str]:
|
||||
"""摄取文档列表
|
||||
|
|
@ -136,9 +147,7 @@ class LocalRAGService:
|
|||
try:
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
sql = sql_text(
|
||||
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
|
||||
)
|
||||
sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
|
||||
await db.execute(sql, {"doc_id": id})
|
||||
await db.commit()
|
||||
return True
|
||||
|
|
@ -171,20 +180,15 @@ class LocalRAGService:
|
|||
|
||||
sources = []
|
||||
for row in rows:
|
||||
meta = {}
|
||||
if row.get("doc_metadata"):
|
||||
try:
|
||||
meta = json.loads(row["doc_metadata"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
sources.append(SourceInfo(
|
||||
sources.append(
|
||||
SourceInfo(
|
||||
source_id=row["source_doc_id"],
|
||||
source_name=row.get("source_title", ""),
|
||||
source_type=row.get("doc_format", "local"),
|
||||
document_count=row.get("chunk_count", 0),
|
||||
last_updated=row["created_at"] if row.get("created_at") else None,
|
||||
))
|
||||
)
|
||||
)
|
||||
return sources
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list sources: {e}")
|
||||
|
|
@ -271,7 +275,7 @@ class LocalRAGService:
|
|||
|
||||
async def _query_pgvector(
|
||||
self,
|
||||
db: Any,
|
||||
db: AsyncSession,
|
||||
query_embedding: list[float],
|
||||
top_k: int,
|
||||
) -> list[QueryResult]:
|
||||
|
|
@ -286,10 +290,13 @@ class LocalRAGService:
|
|||
f"LIMIT :lim"
|
||||
)
|
||||
|
||||
result = await db.execute(sql, {
|
||||
result = await db.execute(
|
||||
sql,
|
||||
{
|
||||
"query_vec": str(query_embedding),
|
||||
"lim": top_k,
|
||||
})
|
||||
},
|
||||
)
|
||||
rows = result.mappings().all()
|
||||
|
||||
results = []
|
||||
|
|
@ -306,7 +313,8 @@ class LocalRAGService:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
results.append(QueryResult(
|
||||
results.append(
|
||||
QueryResult(
|
||||
content=row["content"],
|
||||
source_id=row["source_doc_id"],
|
||||
source_name=row.get("source_title", ""),
|
||||
|
|
@ -314,13 +322,14 @@ class LocalRAGService:
|
|||
metadata=chunk_meta,
|
||||
doc_id=row["source_doc_id"],
|
||||
title=row.get("source_title", ""),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def _query_client_side(
|
||||
self,
|
||||
db: Any,
|
||||
db: AsyncSession,
|
||||
query_embedding: list[float],
|
||||
top_k: int,
|
||||
) -> list[QueryResult]:
|
||||
|
|
@ -363,7 +372,8 @@ class LocalRAGService:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
candidates.append(QueryResult(
|
||||
candidates.append(
|
||||
QueryResult(
|
||||
content=row["content"],
|
||||
source_id=row["source_doc_id"],
|
||||
source_name=row.get("source_title", ""),
|
||||
|
|
@ -371,7 +381,8 @@ class LocalRAGService:
|
|||
metadata=chunk_meta,
|
||||
doc_id=row["source_doc_id"],
|
||||
title=row.get("source_title", ""),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||
return candidates[:top_k]
|
||||
|
|
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
|
|||
"""
|
||||
self._embedder = embedder
|
||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
self._structural_chunker = StructuralChunker(
|
||||
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
# 内存存储
|
||||
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata}
|
||||
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
|
||||
self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
|
||||
self._documents: dict[
|
||||
str, InMemoryDocInfo
|
||||
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
|
||||
|
||||
async def ingest(self, documents: list[Document]) -> list[str]:
|
||||
"""摄取文档列表
|
||||
|
|
@ -459,7 +474,8 @@ class InMemoryLocalRAGService:
|
|||
source_doc_id = chunk_data["source_doc_id"]
|
||||
doc_info = self._documents.get(source_doc_id, {})
|
||||
|
||||
candidates.append(QueryResult(
|
||||
candidates.append(
|
||||
QueryResult(
|
||||
content=chunk_data["content"],
|
||||
source_id=source_doc_id,
|
||||
source_name=doc_info.get("title", ""),
|
||||
|
|
@ -467,7 +483,8 @@ class InMemoryLocalRAGService:
|
|||
metadata=chunk_data.get("metadata", {}),
|
||||
doc_id=source_doc_id,
|
||||
title=doc_info.get("title", ""),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||
return candidates[:top_k]
|
||||
|
|
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
|
|||
"""列出已摄取的文档"""
|
||||
sources = []
|
||||
for doc_id, doc_info in self._documents.items():
|
||||
sources.append(SourceInfo(
|
||||
sources.append(
|
||||
SourceInfo(
|
||||
source_id=doc_id,
|
||||
source_name=doc_info["title"],
|
||||
source_type=doc_info.get("format", "local"),
|
||||
document_count=len(doc_info.get("chunk_ids", [])),
|
||||
last_updated=doc_info.get("created_at"),
|
||||
))
|
||||
)
|
||||
)
|
||||
return sources
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import asyncio
|
|||
import hashlib
|
||||
import logging
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
|
||||
|
||||
|
|
@ -186,15 +185,13 @@ class MultiSourceRetriever:
|
|||
Returns:
|
||||
所有源的检索结果列表(已应用权重)
|
||||
"""
|
||||
|
||||
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
|
||||
try:
|
||||
results = await kb.query(query, top_k=top_k)
|
||||
# 应用权重
|
||||
weight = (weights or {}).get(name, 1.0)
|
||||
return [
|
||||
replace(r, score=r.score * weight, source_name=name)
|
||||
for r in results
|
||||
]
|
||||
return [replace(r, score=r.score * weight, source_name=name) for r in results]
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed for source '{name}': {e}")
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class MemoryFile:
|
||||
|
|
@ -26,8 +26,9 @@ class MemoryFile:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, char_budget: int | None = None,
|
||||
protected_sections: set[str] | None = None):
|
||||
def __init__(
|
||||
self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
|
||||
):
|
||||
self.path = Path(path)
|
||||
self.char_budget = char_budget
|
||||
self._protected_sections = protected_sections or set()
|
||||
|
|
@ -138,7 +139,7 @@ class MemoryFile:
|
|||
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
|
||||
name = match.group(1).strip()
|
||||
start = match.start()
|
||||
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE)
|
||||
next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
|
||||
if next_match:
|
||||
end = match.end() + next_match.start()
|
||||
else:
|
||||
|
|
@ -146,7 +147,7 @@ class MemoryFile:
|
|||
sections.append((name, start, end))
|
||||
|
||||
if not sections:
|
||||
return content[:self.char_budget]
|
||||
return content[: self.char_budget]
|
||||
|
||||
# 保持原始顺序,标记每个 section 是否受保护
|
||||
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
|
||||
|
|
@ -222,8 +223,9 @@ class MemoryStore:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir: Path | str | None = None,
|
||||
on_change: Callable[[str], None] | None = None):
|
||||
def __init__(
|
||||
self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
|
||||
):
|
||||
if base_dir is None:
|
||||
base_dir = Path.home() / ".agentkit"
|
||||
self.base_dir = Path(base_dir)
|
||||
|
|
@ -238,7 +240,9 @@ class MemoryStore:
|
|||
protected_sections={"版本", "更新历史"},
|
||||
)
|
||||
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
|
||||
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
|
||||
self._memory = MemoryFile(
|
||||
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
|
||||
)
|
||||
self._daily_dir = self.base_dir / "memories" / "daily"
|
||||
self._daily_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -376,4 +380,5 @@ class MemoryStore:
|
|||
self._on_change(new_prompt)
|
||||
except Exception:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
|
|||
"""
|
||||
|
||||
_FILLER_WORDS_CN: list[str] = [
|
||||
"帮我", "请", "一下", "分析", "看看", "告诉我", "想知道", "请问",
|
||||
"帮我",
|
||||
"请",
|
||||
"一下",
|
||||
"分析",
|
||||
"看看",
|
||||
"告诉我",
|
||||
"想知道",
|
||||
"请问",
|
||||
]
|
||||
_FILLER_WORDS_EN: list[str] = [
|
||||
"please", "can you", "help me", "could you", "i want to", "i need to",
|
||||
"please",
|
||||
"can you",
|
||||
"help me",
|
||||
"could you",
|
||||
"i want to",
|
||||
"i need to",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
|
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
|
|||
self._synonyms = synonyms or {}
|
||||
self._max_sub_queries = max_sub_queries
|
||||
# Pre-compile filler patterns
|
||||
self._filler_patterns_cn = [
|
||||
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
|
||||
]
|
||||
self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
|
||||
self._filler_patterns_en = [
|
||||
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
|
||||
]
|
||||
|
|
@ -166,7 +176,9 @@ def create_query_transformer(
|
|||
"""工厂函数:根据策略创建查询改写器"""
|
||||
if strategy == "llm":
|
||||
if llm_gateway is None:
|
||||
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp")
|
||||
logger.warning(
|
||||
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
|
||||
)
|
||||
return NoOpQueryTransformer()
|
||||
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
|
||||
elif strategy == "rule":
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from agentkit.memory.base import MemoryItem
|
||||
from agentkit.memory.base import MemoryItem, MetadataDict
|
||||
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
||||
from agentkit.memory.relevance_scorer import (
|
||||
RelevanceScorer,
|
||||
|
|
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
|
|||
RetrievalEvaluation,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# 避免与 retriever.py 形成运行时循环导入。
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: Any, # MemoryRetriever
|
||||
retriever: MemoryRetriever,
|
||||
scorer: RelevanceScorer | None = None,
|
||||
query_transformer: QueryTransformerBase | None = None,
|
||||
max_retries: int = 3,
|
||||
|
|
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
|
|||
query: str,
|
||||
top_k: int = 5,
|
||||
token_budget: int = 3000,
|
||||
filters: dict[str, Any] | None = None,
|
||||
filters: MetadataDict | None = None,
|
||||
) -> RAGLoopResult:
|
||||
"""执行带自纠正的检索
|
||||
|
||||
|
|
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
|
|||
while retry_count <= self._max_retries:
|
||||
# RETRIEVE
|
||||
items = await self._retriever.retrieve(
|
||||
current_query, top_k=top_k, token_budget=token_budget,
|
||||
filters=filters, _skip_correction=True,
|
||||
current_query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
filters=filters,
|
||||
_skip_correction=True,
|
||||
)
|
||||
|
||||
# EVALUATE
|
||||
|
|
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
|
|||
# CORRECT — rewrite query and retry
|
||||
retry_count += 1
|
||||
if retry_count <= self._max_retries:
|
||||
current_query = await self._rewrite_query(
|
||||
query, current_query, evaluation
|
||||
)
|
||||
current_query = await self._rewrite_query(query, current_query, evaluation)
|
||||
continue
|
||||
|
||||
# DEGRADE — exceeded max retries
|
||||
|
|
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
|
|||
|
||||
# Degraded result: filter to relevant items and mark low confidence
|
||||
relevant_items = [
|
||||
s.item
|
||||
for s in evaluation.scores
|
||||
if s.verdict != RelevanceVerdict.INCORRECT
|
||||
s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
|
||||
]
|
||||
result_items = relevant_items if relevant_items else items
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.base import MemoryItem
|
||||
|
||||
|
|
@ -120,9 +118,7 @@ class RelevanceScorer:
|
|||
reason=reason,
|
||||
)
|
||||
|
||||
def evaluate(
|
||||
self, query: str, items: list[MemoryItem]
|
||||
) -> RetrievalEvaluation:
|
||||
def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
|
||||
"""评估一次检索的整体质量"""
|
||||
if not items:
|
||||
return RetrievalEvaluation(
|
||||
|
|
@ -134,9 +130,7 @@ class RelevanceScorer:
|
|||
)
|
||||
|
||||
scores = [self.score_item(query, item) for item in items]
|
||||
relevant_count = sum(
|
||||
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
|
||||
)
|
||||
relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
|
||||
avg_score = sum(s.score for s in scores) / len(scores)
|
||||
|
||||
# Overall verdict based on average score and relevant ratio
|
||||
|
|
|
|||
|
|
@ -7,19 +7,16 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import replace
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
from agentkit.memory.base import MemoryItem, MetadataDict
|
||||
from agentkit.memory.working import WorkingMemory
|
||||
from agentkit.memory.episodic import EpisodicMemory
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.query_transformer import QueryTransformerBase
|
||||
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
|
||||
from agentkit.memory.relevance_scorer import RelevanceScorer
|
||||
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult
|
||||
from agentkit.memory.knowledge_base import KnowledgeBase
|
||||
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
|
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
|
|||
Chinese characters typically use 1-2 tokens each.
|
||||
English words typically use 1 token each.
|
||||
"""
|
||||
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
|
||||
non_cjk = text
|
||||
for c in text:
|
||||
if '\u4e00' <= c <= '\u9fff':
|
||||
non_cjk = non_cjk.replace(c, ' ')
|
||||
if "\u4e00" <= c <= "\u9fff":
|
||||
non_cjk = non_cjk.replace(c, " ")
|
||||
word_count = len(non_cjk.split())
|
||||
return cjk_count * 2 + word_count
|
||||
|
||||
|
|
@ -89,7 +86,7 @@ class MemoryRetriever:
|
|||
query: str,
|
||||
top_k: int = 5,
|
||||
token_budget: int = 3000,
|
||||
filters: dict[str, Any] | None = None,
|
||||
filters: MetadataDict | None = None,
|
||||
_skip_correction: bool = False,
|
||||
sources: list[str] | None = None,
|
||||
source_weights: dict[str, float] | None = None,
|
||||
|
|
@ -121,9 +118,7 @@ class MemoryRetriever:
|
|||
query, top_k=top_k, token_budget=token_budget, filters=filters
|
||||
)
|
||||
if result.degraded:
|
||||
logger.warning(
|
||||
f"RAG self-correction degraded after {result.total_retries} retries"
|
||||
)
|
||||
logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
|
||||
return result.items
|
||||
# Query transformation
|
||||
if self._query_transformer is not None:
|
||||
|
|
@ -139,9 +134,7 @@ class MemoryRetriever:
|
|||
|
||||
# Sub-query search in parallel
|
||||
if sub_queries:
|
||||
sub_tasks = [
|
||||
self._search_layers(sq, top_k, filters) for sq in sub_queries
|
||||
]
|
||||
sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
|
||||
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
|
||||
for result in sub_results:
|
||||
if isinstance(result, Exception):
|
||||
|
|
@ -178,7 +171,7 @@ class MemoryRetriever:
|
|||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filters: dict[str, Any] | None = None,
|
||||
filters: MetadataDict | None = None,
|
||||
) -> list[MemoryItem]:
|
||||
"""Search all configured memory layers with a single query"""
|
||||
tasks = []
|
||||
|
|
@ -237,7 +230,8 @@ class MemoryRetriever:
|
|||
# QueryResult → MemoryItem
|
||||
items = []
|
||||
for r in kb_results:
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=r.source_id,
|
||||
value=r.content,
|
||||
metadata={
|
||||
|
|
@ -248,7 +242,8 @@ class MemoryRetriever:
|
|||
"document_title": r.title,
|
||||
},
|
||||
score=r.score,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Token 预算管理
|
||||
selected = []
|
||||
|
|
@ -318,7 +313,9 @@ class MemoryRetriever:
|
|||
if source == "rag":
|
||||
kb_type = item.metadata.get("kb_type", "知识库")
|
||||
document_title = item.metadata.get("document_title", "未知文档")
|
||||
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
|
||||
return (
|
||||
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
|
||||
)
|
||||
elif source == "graph":
|
||||
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
|
||||
elif source == "episodic":
|
||||
|
|
@ -330,7 +327,7 @@ class MemoryRetriever:
|
|||
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
|
||||
|
||||
async def store_episode(
|
||||
self, key: str, value: Any, metadata: dict[str, Any] | None = None
|
||||
self, key: str, value: object, metadata: MetadataDict | None = None
|
||||
) -> None:
|
||||
"""Store an episode into episodic memory if available.
|
||||
|
||||
|
|
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
|
|||
items = await self._retriever.retrieve(query, top_k=5)
|
||||
results = []
|
||||
for item in items:
|
||||
results.append({
|
||||
results.append(
|
||||
{
|
||||
"content": item.value,
|
||||
"score": item.score,
|
||||
"source": item.metadata.get("source", "unknown"),
|
||||
"document_title": item.metadata.get("document_title", ""),
|
||||
})
|
||||
}
|
||||
)
|
||||
return {"query": query, "results": results, "call_count": self._call_count}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "results": []}
|
||||
|
|
|
|||
|
|
@ -3,14 +3,45 @@
|
|||
适配器模式,对接外部 RAG 服务和知识图谱。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from __future__ import annotations
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.memory.http_rag import RAGSearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RAGServiceLike(Protocol):
    """Minimal interface contract for a RAG retrieval service (duck-typed).

    Structural type for the ``rag_service`` argument of ``SemanticMemory``:
    any object exposing these two coroutine methods satisfies it, no
    inheritance required.  Parameter defaults are written as ``...`` because
    the protocol only fixes names and types; concrete services choose their
    own default values.
    """

    # Standard search: called with the query text, an optional list of
    # knowledge-base ids, and the number of results requested.
    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
    ) -> list[RAGSearchResult]: ...

    # Enhanced search: same arguments as ``search`` plus two boolean switches
    # (``use_rerank`` / ``use_compression``).
    # NOTE(review): what the switches do is defined by the concrete service,
    # not by this protocol — confirm against the implementation.
    async def enhanced_search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
        use_rerank: bool = ...,
        use_compression: bool = ...,
    ) -> list[RAGSearchResult]: ...
|
||||
|
||||
|
||||
class _GraphServiceLike(Protocol):
    """Minimal interface contract for a knowledge-graph service (duck-typed).

    Structural type for the ``graph_service`` argument of ``SemanticMemory``:
    any object with a matching ``query`` coroutine satisfies it.
    """

    # Returns a list of plain dicts.  The ``depth`` default is left as ``...``
    # because the protocol fixes only the parameter's name and type.
    async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
|
||||
|
||||
|
||||
class SemanticMemory(Memory):
|
||||
"""Semantic Memory - 知识库检索
|
||||
|
||||
|
|
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
rag_service: Any = None,
|
||||
graph_service: Any = None,
|
||||
rag_service: _RAGServiceLike | None = None,
|
||||
graph_service: _GraphServiceLike | None = None,
|
||||
knowledge_base_ids: list[str] | None = None,
|
||||
search_mode: str = "standard",
|
||||
use_rerank: bool = True,
|
||||
|
|
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
|
|||
self._use_compression = use_compression
|
||||
self._kb_weights = kb_weights
|
||||
|
||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
|
||||
if self._rag_service and hasattr(self._rag_service, 'ingest'):
|
||||
if self._rag_service and hasattr(self._rag_service, "ingest"):
|
||||
await self._rag_service.ingest(key, value, metadata)
|
||||
else:
|
||||
logger.warning("SemanticMemory.store: no RAG service configured for writing")
|
||||
|
|
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
|
|||
"""按 key 精确检索(Semantic Memory 通常不按 key 检索)"""
|
||||
return None
|
||||
|
||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||
async def search(
|
||||
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||
) -> list[MemoryItem]:
|
||||
"""语义检索知识库"""
|
||||
items = []
|
||||
|
||||
|
|
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
|
|||
if self._rag_service:
|
||||
try:
|
||||
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
||||
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
|
||||
if self._search_mode == "enhanced" and hasattr(
|
||||
self._rag_service, "enhanced_search"
|
||||
):
|
||||
results = await self._rag_service.enhanced_search(
|
||||
query,
|
||||
knowledge_base_ids=kb_ids,
|
||||
|
|
@ -73,14 +108,17 @@ class SemanticMemory(Memory):
|
|||
use_compression=self._use_compression,
|
||||
)
|
||||
else:
|
||||
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
||||
results = await self._rag_service.search(
|
||||
query, knowledge_base_ids=kb_ids, top_k=top_k
|
||||
)
|
||||
for r in results:
|
||||
kb_id = r.get("knowledge_base_id", "")
|
||||
score = r.get("score", 0.0)
|
||||
# Apply per-KB weights
|
||||
if self._kb_weights and kb_id in self._kb_weights:
|
||||
score *= self._kb_weights[kb_id]
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=r.get("id", ""),
|
||||
value=r.get("content", ""),
|
||||
metadata={
|
||||
|
|
@ -90,7 +128,8 @@ class SemanticMemory(Memory):
|
|||
"knowledge_base_id": kb_id,
|
||||
},
|
||||
score=score,
|
||||
))
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"RAG search failed: {e}")
|
||||
|
||||
|
|
@ -99,7 +138,8 @@ class SemanticMemory(Memory):
|
|||
try:
|
||||
graph_results = await self._graph_service.query(query, depth=2)
|
||||
for r in graph_results[:top_k]:
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=r.get("id", ""),
|
||||
value=r.get("content", ""),
|
||||
metadata={
|
||||
|
|
@ -108,7 +148,8 @@ class SemanticMemory(Memory):
|
|||
"relations": r.get("relations", []),
|
||||
},
|
||||
score=r.get("score", 0.0),
|
||||
))
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Graph search failed: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
|
|||
def _make_key(self, key: str) -> str:
|
||||
return f"{self._key_prefix}:{key}"
|
||||
|
||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||
redis_key = self._make_key(key)
|
||||
item = MemoryItem(
|
||||
key=key,
|
||||
|
|
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
|
|||
value=item_dict["value"],
|
||||
metadata=item_dict.get("metadata", {}),
|
||||
score=item_dict.get("score", 1.0),
|
||||
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc),
|
||||
created_at=datetime.fromisoformat(item_dict["created_at"])
|
||||
if item_dict.get("created_at")
|
||||
else datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||
async def search(
|
||||
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||
) -> list[MemoryItem]:
|
||||
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
|
||||
pattern = self._make_key(f"{query}*")
|
||||
keys = []
|
||||
|
|
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
|
|||
data = await self._redis.get(key)
|
||||
if data:
|
||||
item_dict = json.loads(data)
|
||||
items.append(MemoryItem(
|
||||
items.append(
|
||||
MemoryItem(
|
||||
key=item_dict["key"],
|
||||
value=item_dict["value"],
|
||||
metadata=item_dict.get("metadata", {}),
|
||||
score=1.0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
))
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue