126 lines
3.9 KiB
Python
126 lines
3.9 KiB
Python
"""MemoryRetriever - 混合检索器 (hybrid memory retriever).

并行查询三层记忆,按权重融合排序。
Queries the three memory layers concurrently and fuses the results by weight.
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import math
|
||
from dataclasses import replace
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
from agentkit.memory.base import Memory, MemoryItem
|
||
from agentkit.memory.working import WorkingMemory
|
||
from agentkit.memory.episodic import EpisodicMemory
|
||
from agentkit.memory.semantic import SemanticMemory
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class MemoryRetriever:
    """Hybrid retriever: query the three memory layers concurrently and fuse results.

    Retrieval strategy:

    1. Query the Working / Episodic / Semantic layers concurrently via
       ``asyncio.gather``; a layer whose search raises an ordinary exception
       is logged and skipped.
    2. Multiply every item's score by its layer weight (defaults: Working 0.2,
       Episodic 0.4, Semantic 0.4) and sort all items by the weighted score,
       highest first.
    3. Greedily keep items, in score order, whose estimated token cost still
       fits into ``token_budget``, stopping after ``top_k`` items.

    NOTE(review): the original design notes also list time decay (older
    memories weigh less), but no decay is applied in this class — TODO
    confirm whether the individual memory layers already account for it.
    """

    # Default per-layer fusion weights. User-supplied weights override these
    # key by key; this mapping itself is never mutated.
    _DEFAULT_WEIGHTS: dict[str, float] = {
        "working": 0.2,
        "episodic": 0.4,
        "semantic": 0.4,
    }

    def __init__(
        self,
        working_memory: WorkingMemory | None = None,
        episodic_memory: EpisodicMemory | None = None,
        semantic_memory: SemanticMemory | None = None,
        weights: dict[str, float] | None = None,
    ):
        """Create a retriever over any subset of the three memory layers.

        Args:
            working_memory: Working-memory layer, or ``None`` to skip it.
            episodic_memory: Episodic-memory layer, or ``None`` to skip it.
            semantic_memory: Semantic-memory layer, or ``None`` to skip it.
            weights: Per-layer score multipliers keyed by ``"working"``,
                ``"episodic"`` and ``"semantic"``. Omitted keys keep their
                default weight. The mapping is copied, so later changes made
                by the caller do not affect this retriever.
        """
        self._working = working_memory
        self._episodic = episodic_memory
        self._semantic = semantic_memory
        # Merge over the defaults instead of replacing them. Previously a
        # partial mapping such as {"working": 0.5} silently gave the other
        # layers an undocumented fallback weight of 0.3.
        self._weights = {**self._DEFAULT_WEIGHTS, **(weights or {})}

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> list[MemoryItem]:
        """Search every configured layer concurrently and return fused results.

        Args:
            query: Search query forwarded to each layer's ``search``.
            top_k: Maximum number of items requested from each layer and
                returned overall. ``top_k <= 0`` returns ``[]`` without
                querying any layer.
            token_budget: Upper bound on the summed token estimate
                (``len(str(item.value)) // 4``) of the returned items.
            filters: Optional filters forwarded unchanged to each layer.

        Returns:
            Copies of the retrieved items (via ``dataclasses.replace``) whose
            ``score`` is multiplied by their layer's weight, sorted by that
            weighted score, highest first.

        Raises:
            BaseException: a non-``Exception`` error from a layer search
                (e.g. ``asyncio.CancelledError``) is re-raised. Ordinary
                exceptions are logged and that layer is skipped.
        """
        if top_k <= 0:
            # The selection step appends before it checks the limit, so
            # without this guard top_k=0 would still return one item.
            return []

        candidates = (
            ("working", self._working),
            ("episodic", self._episodic),
            ("semantic", self._semantic),
        )
        # ``is not None`` rather than truthiness: a memory object defining
        # __len__/__bool__ would otherwise be skipped whenever it is empty.
        layers = [(name, memory) for name, memory in candidates if memory is not None]
        if not layers:
            return []

        # Query all layers concurrently; failures come back as values.
        results = await asyncio.gather(
            *(
                memory.search(query, top_k=top_k, filters=filters)
                for _, memory in layers
            ),
            return_exceptions=True,
        )

        weighted_items = []
        for (layer_name, _), result in zip(layers, results):
            if isinstance(result, Exception):
                logger.error(
                    "Memory search failed for %s: %s",
                    layer_name,
                    result,
                    exc_info=result,
                )
                continue
            if isinstance(result, BaseException):
                # e.g. asyncio.CancelledError, which has not subclassed
                # Exception since Python 3.8. Iterating it would raise
                # TypeError, so propagate it instead.
                raise result
            weight = self._weights.get(layer_name, self._DEFAULT_WEIGHTS[layer_name])
            weighted_items.extend(
                replace(item, score=item.score * weight) for item in result
            )

        weighted_items.sort(key=lambda item: item.score, reverse=True)
        return self._select_within_budget(weighted_items, top_k, token_budget)

    @staticmethod
    def _select_within_budget(
        items: list[MemoryItem], top_k: int, token_budget: int
    ) -> list[MemoryItem]:
        """Greedily keep up to ``top_k`` items, in order, that fit ``token_budget``."""
        selected = []
        used_tokens = 0
        for item in items:
            # Rough heuristic of ~4 characters per token.
            # NOTE(review): this likely underestimates CJK text, where one
            # character is often a full token — confirm against the tokenizer.
            cost = len(str(item.value)) // 4
            if used_tokens + cost > token_budget:
                # Skip rather than stop: a later, shorter item may still fit.
                continue
            selected.append(item)
            used_tokens += cost
            if len(selected) >= top_k:
                break
        return selected

    async def get_context_string(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> str:
        """Return the retrieved memory values joined into one context string.

        Args:
            query: Search query passed to ``retrieve``.
            top_k: Maximum number of items, passed to ``retrieve``.
            token_budget: Token budget, passed to ``retrieve``.
            filters: Optional layer filters, passed to ``retrieve``. This is
                a new trailing parameter; existing calls are unaffected.

        Returns:
            ``str(item.value)`` for each retrieved item, in ranked order,
            separated by blank lines; ``""`` when nothing is retrieved.
        """
        items = await self.retrieve(
            query, top_k=top_k, token_budget=token_budget, filters=filters
        )
        return "\n\n".join(str(item.value) for item in items)

    async def store_episode(
        self, key: str, value: Any, metadata: dict[str, Any] | None = None
    ) -> None:
        """Store an episode into episodic memory if one is configured.

        Public API that delegates to the underlying episodic memory's
        ``store``, so callers need not reach into the private ``_episodic``
        attribute. Does nothing when no episodic memory was provided.
        """
        if self._episodic is not None:
            await self._episodic.store(key, value, metadata)
|