# Listing metadata: 163 lines, 6.0 KiB, Python
"""Semantic Memory - knowledge base adapter.

Adapter pattern that connects to external RAG services and knowledge graphs.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import TYPE_CHECKING, Protocol
|
||
|
||
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||
|
||
if TYPE_CHECKING:
|
||
from agentkit.memory.http_rag import RAGSearchResult
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class _RAGServiceLike(Protocol):
    """Minimal duck-typed contract for a RAG retrieval service.

    Any object exposing these coroutine methods can be used as a RAG backend;
    no inheritance from this class is required.
    """

    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
    ) -> list[RAGSearchResult]:
        """Run a standard retrieval for ``query``.

        Args:
            query: Query text.
            knowledge_base_ids: Knowledge base IDs to search, or ``None``.
            top_k: Requested number of results.

        Returns:
            A list of search results.
        """
        ...

    async def enhanced_search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = ...,
        top_k: int = ...,
        use_rerank: bool = ...,
        use_compression: bool = ...,
    ) -> list[RAGSearchResult]:
        """Run an enhanced retrieval for ``query``.

        Args:
            query: Query text.
            knowledge_base_ids: Knowledge base IDs to search, or ``None``.
            top_k: Requested number of results.
            use_rerank: Whether rerank is requested.
            use_compression: Whether context compression is requested.

        Returns:
            A list of search results.
        """
        ...
|
||
|
||
|
||
class _GraphServiceLike(Protocol):
    """Minimal duck-typed contract for a knowledge-graph service.

    Any object exposing a compatible ``query`` coroutine can be used as the
    graph backend; no inheritance from this class is required.
    """

    async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]:
        """Query the knowledge graph.

        Args:
            query: Query text.
            depth: Depth argument forwarded to the graph backend.

        Returns:
            A list of result dictionaries.
        """
        ...
|
||
|
||
|
||
class SemanticMemory(Memory):
    """Semantic memory backed by external knowledge services.

    Acts as an adapter: retrieval is delegated to an injected RAG service and,
    optionally, a knowledge-graph service, so this class does not depend on
    any concrete implementation.
    """

    def __init__(
        self,
        rag_service: _RAGServiceLike | None = None,
        graph_service: _GraphServiceLike | None = None,
        knowledge_base_ids: list[str] | None = None,
        search_mode: str = "standard",
        use_rerank: bool = True,
        use_compression: bool = False,
        kb_weights: dict[str, float] | None = None,
    ) -> None:
        """Configure the adapter.

        Args:
            rag_service: RAG retrieval service (must provide ``search``).
            graph_service: Knowledge-graph service (must provide ``query``).
            knowledge_base_ids: Knowledge base IDs searched by default.
            search_mode: Retrieval mode, ``"standard"`` or ``"enhanced"``.
            use_rerank: Enable rerank (only used in enhanced mode).
            use_compression: Enable context compression (only used in
                enhanced mode).
            kb_weights: Mapping of knowledge base ID to a score multiplier.
        """
        self._rag_service = rag_service
        self._graph_service = graph_service
        self._knowledge_base_ids = knowledge_base_ids or []
        self._search_mode = search_mode
        self._use_rerank = use_rerank
        self._use_compression = use_compression
        self._kb_weights = kb_weights

    async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
        """Delegate a write to the RAG service's ``ingest`` method.

        Semantic memory is normally read-only. If no RAG service is configured,
        or the service has no ``ingest`` attribute, a warning is logged and
        nothing is written.
        """
        # ``ingest`` is not part of the minimal protocol, so look it up dynamically.
        ingest = getattr(self._rag_service, "ingest", None) if self._rag_service else None
        if ingest is None:
            logger.warning("SemanticMemory.store: no RAG service configured for writing")
            return
        await ingest(key, value, metadata)

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Exact lookup by key; unsupported here, so this always returns ``None``."""
        return None

    async def search(
        self, query: str, top_k: int = 5, filters: MetadataDict | None = None
    ) -> list[MemoryItem]:
        """Search the knowledge sources and return the best-scoring items.

        Queries the RAG service and then the graph service (each only if
        configured), merges both result sets, and sorts them by descending
        score. A failure in either backend is logged and that backend simply
        contributes no further items.

        Args:
            query: Query text.
            top_k: Maximum number of items to return. Non-positive values
                yield an empty list without contacting any backend.
            filters: Optional filters; a ``"knowledge_base_ids"`` entry
                overrides the default knowledge base IDs for the RAG call.

        Returns:
            At most ``top_k`` items, highest score first.
        """
        if top_k <= 0:
            # A negative slice bound would silently drop items from the tail.
            return []

        items: list[MemoryItem] = []
        if self._rag_service:
            items.extend(await self._search_rag(query, top_k, filters))
        if self._graph_service:
            items.extend(await self._search_graph(query, top_k))

        # list.sort is stable, so RAG items stay ahead of graph items on ties.
        items.sort(key=lambda item: item.score, reverse=True)
        return items[:top_k]

    async def delete(self, key: str) -> bool:
        """Deletion is unsupported; logs a warning and returns ``False``."""
        logger.warning("SemanticMemory.delete: read-only memory")
        return False

    @staticmethod
    def _coerce_score(raw: object) -> float:
        """Convert a backend-supplied score to float, using 0.0 if unusable."""
        try:
            return float(raw)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return 0.0

    async def _search_rag(
        self, query: str, top_k: int, filters: MetadataDict | None
    ) -> list[MemoryItem]:
        """Query the RAG service and convert its results into memory items."""
        collected: list[MemoryItem] = []
        try:
            kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
            # Fall back to the standard search when the service has no enhanced one.
            use_enhanced = self._search_mode == "enhanced" and hasattr(
                self._rag_service, "enhanced_search"
            )
            if use_enhanced:
                results = await self._rag_service.enhanced_search(
                    query,
                    knowledge_base_ids=kb_ids,
                    top_k=top_k,
                    use_rerank=self._use_rerank,
                    use_compression=self._use_compression,
                )
            else:
                results = await self._rag_service.search(
                    query, knowledge_base_ids=kb_ids, top_k=top_k
                )

            for r in results:
                kb_id = r.get("knowledge_base_id", "")
                score = self._coerce_score(r.get("score", 0.0))
                # Apply the per-knowledge-base weight, if one is configured.
                if self._kb_weights and kb_id in self._kb_weights:
                    score *= self._kb_weights[kb_id]
                collected.append(
                    MemoryItem(
                        key=r.get("id", ""),
                        value=r.get("content", ""),
                        metadata={
                            "source": r.get("source", "rag"),
                            "score": score,
                            "document_id": r.get("document_id"),
                            "knowledge_base_id": kb_id,
                        },
                        score=score,
                    )
                )
        except Exception as exc:
            # Best-effort: keep whatever was converted before the failure.
            logger.exception("RAG search failed: %s", exc)
        return collected

    async def _search_graph(self, query: str, top_k: int) -> list[MemoryItem]:
        """Query the graph service and convert its first ``top_k`` results."""
        collected: list[MemoryItem] = []
        try:
            graph_results = await self._graph_service.query(query, depth=2)
            for r in graph_results[:top_k]:
                collected.append(
                    MemoryItem(
                        key=r.get("id", ""),
                        value=r.get("content", ""),
                        metadata={
                            "source": "graph",
                            "entities": r.get("entities", []),
                            "relations": r.get("relations", []),
                        },
                        score=self._coerce_score(r.get("score", 0.0)),
                    )
                )
        except Exception as exc:
            # Best-effort: keep whatever was converted before the failure.
            logger.exception("Graph search failed: %s", exc)
        return collected
|