fischer-agentkit/src/agentkit/memory/semantic.py

163 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Semantic Memory - 知识库适配器
适配器模式,对接外部 RAG 服务和知识图谱。
"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Protocol

from agentkit.memory.base import Memory, MemoryItem, MetadataDict

if TYPE_CHECKING:
    # Only needed for annotations: TYPE_CHECKING is False at runtime and, with
    # the __future__ import above, annotations stay unevaluated strings, so
    # this module is never imported at runtime from here.
    from agentkit.memory.http_rag import RAGSearchResult

# Module-level logger used by the classes in this file.
logger = logging.getLogger(__name__)
class _RAGServiceLike(Protocol):
    """Minimal structural (duck-typed) contract for a RAG retrieval service.

    Any object exposing coroutine methods with compatible signatures matches
    this protocol; no inheritance is required. Both methods resolve to a list
    of ``RAGSearchResult`` hits.
    """

    async def enhanced_search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
        use_rerank: bool = ...,
        use_compression: bool = ...,
    ) -> list[RAGSearchResult]:
        """Retrieval variant that also accepts rerank / compression switches."""
        ...

    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
    ) -> list[RAGSearchResult]:
        """Plain retrieval for *query*, optionally scoped to knowledge bases."""
        ...
class _GraphServiceLike(Protocol):
    """Minimal structural (duck-typed) contract for a knowledge-graph service."""

    async def query(
        self,
        query: str,
        depth: int = ...,
    ) -> list[dict[str, object]]:
        """Look up *query* in the graph; *depth* is passed through to the backend."""
        ...
class SemanticMemory(Memory):
    """Semantic Memory - knowledge-base retrieval.

    Reaches external knowledge sources through duck-typed adapters instead of
    depending on concrete implementations:

    * a RAG service (``search``; ``enhanced_search`` and ``ingest`` are used
      only when present);
    * a knowledge-graph service (``query``).

    The memory is effectively read-only: ``retrieve`` always returns ``None``,
    ``delete`` always returns ``False``, and ``store`` only works when the RAG
    service exposes an awaitable ``ingest`` method.
    """

    def __init__(
        self,
        rag_service: _RAGServiceLike | None = None,
        graph_service: _GraphServiceLike | None = None,
        knowledge_base_ids: list[str] | None = None,
        search_mode: str = "standard",
        use_rerank: bool = True,
        use_compression: bool = False,
        kb_weights: dict[str, float] | None = None,
    ):
        """
        Args:
            rag_service: RAG retrieval service (must provide ``search``).
            graph_service: Knowledge-graph service (must provide ``query``).
            knowledge_base_ids: Default knowledge-base IDs to search;
                ``None`` is stored as an empty list.
            search_mode: ``"standard"`` or ``"enhanced"``. Any other value is
                treated as ``"standard"`` and logs a warning.
            use_rerank: Enable reranking (only used in ``"enhanced"`` mode).
            use_compression: Enable context compression (only used in
                ``"enhanced"`` mode).
            kb_weights: Maps knowledge-base ID -> score multiplier applied to
                RAG hits from that knowledge base.
        """
        if search_mode not in ("standard", "enhanced"):
            # search() only special-cases "enhanced"; anything else silently
            # runs a standard search, so surface likely typos early.
            logger.warning(
                "SemanticMemory: unknown search_mode %r; using standard search",
                search_mode,
            )
        self._rag_service = rag_service
        self._graph_service = graph_service
        self._knowledge_base_ids = knowledge_base_ids or []
        self._search_mode = search_mode
        self._use_rerank = use_rerank
        self._use_compression = use_compression
        self._kb_weights = kb_weights

    async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
        """Delegate a write to the RAG service's ``ingest`` method.

        Semantic memory is normally read-only. If no RAG service is configured,
        or it has no ``ingest`` attribute, the write is dropped and a warning
        is logged.
        """
        # getattr on None yields the default, so this covers both "no service"
        # and "service without ingest" without relying on the service's truthiness.
        ingest = getattr(self._rag_service, "ingest", None)
        if ingest is None:
            logger.warning("SemanticMemory.store: no RAG service configured for writing")
            return
        await ingest(key, value, metadata)

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Exact key lookup is not supported by semantic memory; always returns ``None``."""
        return None

    async def search(
        self, query: str, top_k: int = 5, filters: MetadataDict | None = None
    ) -> list[MemoryItem]:
        """Semantic search over the configured knowledge sources.

        Queries the RAG service and the knowledge-graph service (each only if
        configured), merges the hits, sorts them by descending score and
        returns at most ``top_k`` items. A failure in one source is logged and
        does not prevent results from the other.

        Args:
            query: Free-text query.
            top_k: Maximum number of items to return; values <= 0 yield ``[]``.
            filters: Optional; a ``"knowledge_base_ids"`` entry overrides the
                default knowledge-base IDs for the RAG query.

        Returns:
            Merged list of ``MemoryItem``, highest score first.
        """
        if top_k <= 0:
            # A non-positive limit can only produce an empty result; negative
            # values would otherwise slice from the end (items[:-1]).
            return []

        items: list[MemoryItem] = []
        if self._rag_service is not None:
            items.extend(await self._search_rag(query, top_k, filters))
        if self._graph_service is not None:
            items.extend(await self._search_graph(query, top_k))

        # Scores are coerced to float by the helpers, so this sort cannot fail
        # on None / non-numeric scores coming from a backend.
        items.sort(key=lambda item: item.score, reverse=True)
        return items[:top_k]

    async def _search_rag(
        self, query: str, top_k: int, filters: MetadataDict | None
    ) -> list[MemoryItem]:
        """Query the RAG service and convert its hits; returns ``[]`` on any failure."""
        try:
            # An explicit filters["knowledge_base_ids"] overrides the defaults.
            kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
            if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
                results = await self._rag_service.enhanced_search(
                    query,
                    knowledge_base_ids=kb_ids,
                    top_k=top_k,
                    use_rerank=self._use_rerank,
                    use_compression=self._use_compression,
                )
            else:
                # Standard mode, or "enhanced" against a service that lacks
                # enhanced_search: fall back to the plain search.
                results = await self._rag_service.search(
                    query, knowledge_base_ids=kb_ids, top_k=top_k
                )
            return [self._rag_hit_to_item(hit) for hit in results]
        except Exception as e:
            # Best-effort boundary: a failing RAG backend must not prevent
            # graph results; log with traceback instead of just the message.
            logger.exception("RAG search failed: %s", e)
            return []

    def _rag_hit_to_item(self, hit: RAGSearchResult) -> MemoryItem:
        """Convert one RAG hit into a ``MemoryItem``, applying the per-KB score weight."""
        kb_id = hit.get("knowledge_base_id", "")
        score = self._as_score(hit.get("score", 0.0))
        if self._kb_weights and kb_id in self._kb_weights:
            score *= self._kb_weights[kb_id]
        return MemoryItem(
            key=hit.get("id", ""),
            value=hit.get("content", ""),
            metadata={
                "source": hit.get("source", "rag"),
                "score": score,
                "document_id": hit.get("document_id"),
                "knowledge_base_id": kb_id,
            },
            score=score,
        )

    async def _search_graph(self, query: str, top_k: int) -> list[MemoryItem]:
        """Query the knowledge graph (depth 2) and convert up to ``top_k`` hits; ``[]`` on failure."""
        try:
            graph_results = await self._graph_service.query(query, depth=2)
            return [
                MemoryItem(
                    key=r.get("id", ""),
                    value=r.get("content", ""),
                    metadata={
                        "source": "graph",
                        "entities": r.get("entities", []),
                        "relations": r.get("relations", []),
                    },
                    score=self._as_score(r.get("score", 0.0)),
                )
                for r in graph_results[:top_k]
            ]
        except Exception as e:
            # Best-effort boundary, mirroring _search_rag.
            logger.exception("Graph search failed: %s", e)
            return []

    @staticmethod
    def _as_score(raw: object) -> float:
        """Coerce a hit's score to ``float``.

        A score that is ``None`` or not convertible to a number becomes 0.0,
        so that weighting and the final sort never raise on a mixed list.
        """
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    async def delete(self, key: str) -> bool:
        """Deletion is not supported (read-only memory); logs a warning and returns ``False``."""
        logger.warning("SemanticMemory.delete: read-only memory")
        return False