fischer-agentkit/src/agentkit/memory/retriever.py

126 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""MemoryRetriever - 混合检索器
并行查询三层记忆,按权重融合排序。
"""
import asyncio
import logging
import math
from dataclasses import replace
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
# Module-level logger, named after this module so handlers can target it.
logger: logging.Logger = logging.getLogger(name=__name__)
class MemoryRetriever:
    """Hybrid retriever: query the memory layers concurrently and fuse results.

    Retrieval strategy:
    1. Query the configured Working / Episodic / Semantic layers concurrently;
       layers passed as ``None`` are skipped.
    2. Multiply each item's score by its layer weight (default Working 0.2,
       Episodic 0.4, Semantic 0.4) and sort all items by that weighted score,
       highest first.
    3. Greedily keep items while the estimated token total stays within
       ``token_budget``, returning at most ``top_k`` items.

    NOTE(review): time decay (down-weighting older memories) is not
    implemented; an item's final score depends only on the score its layer
    returned and the layer weight.
    """

    # Weight applied to a layer's scores when ``weights`` has no entry for it.
    _FALLBACK_WEIGHT = 0.3

    def __init__(
        self,
        working_memory: WorkingMemory | None = None,
        episodic_memory: EpisodicMemory | None = None,
        semantic_memory: SemanticMemory | None = None,
        weights: dict[str, float] | None = None,
    ):
        """Create a retriever over any subset of the three memory layers.

        Args:
            working_memory: Working-memory layer, or ``None`` to skip it.
            episodic_memory: Episodic-memory layer, or ``None`` to skip it.
            semantic_memory: Semantic-memory layer, or ``None`` to skip it.
            weights: Per-layer score multipliers keyed by ``"working"``,
                ``"episodic"`` and ``"semantic"``. ``None`` or an empty dict
                selects the defaults (0.2 / 0.4 / 0.4); a layer missing from
                a non-empty dict is weighted 0.3.
        """
        self._working = working_memory
        self._episodic = episodic_memory
        self._semantic = semantic_memory
        # Copy so later mutation of the caller's dict cannot change our weights.
        self._weights = dict(weights) if weights else {
            "working": 0.2,
            "episodic": 0.4,
            "semantic": 0.4,
        }

    @staticmethod
    def _estimate_tokens(item: MemoryItem) -> int:
        """Rough token estimate for an item: one token per four characters.

        NOTE(review): this heuristic roughly fits English text; CJK text may
        need closer to one token per character, so it can under-count
        heavily. Confirm against the tokenizer actually in use.
        """
        return len(str(item.value)) // 4

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> list[MemoryItem]:
        """Retrieve up to ``top_k`` items from all configured layers.

        Args:
            query: Search query forwarded to every layer's ``search``.
            top_k: Maximum number of items to return; also forwarded to each
                layer. Values ``<= 0`` return ``[]`` without querying.
            token_budget: Upper bound on the summed estimated token count
                (``len(str(item.value)) // 4`` per item) of returned items.
            filters: Optional filters forwarded unchanged to each layer.

        Returns:
            Copies (via ``dataclasses.replace``) of the retrieved items with
            ``score`` multiplied by the layer weight, sorted by that weighted
            score, highest first. A layer whose search fails is logged and
            skipped rather than failing the whole retrieval.
        """
        if top_k <= 0:
            return []

        layers = [
            (name, memory)
            for name, memory in (
                ("working", self._working),
                ("episodic", self._episodic),
                ("semantic", self._semantic),
            )
            if memory is not None
        ]
        if not layers:
            return []

        # Query all layers concurrently; failures come back as exception
        # objects in ``results`` instead of aborting the gather.
        results = await asyncio.gather(
            *(
                memory.search(query, top_k=top_k, filters=filters)
                for _, memory in layers
            ),
            return_exceptions=True,
        )

        weighted = []
        for (layer_name, _), result in zip(layers, results):
            # BaseException, not Exception: a cancelled layer search yields
            # asyncio.CancelledError (a BaseException since Python 3.8), which
            # must not be iterated as if it were a result list.
            if isinstance(result, BaseException):
                logger.error("Memory search failed for %s: %s", layer_name, result)
                continue
            weight = self._weights.get(layer_name, self._FALLBACK_WEIGHT)
            weighted.extend(
                replace(item, score=item.score * weight) for item in result
            )

        # Stable sort: ties keep layer order (working, episodic, semantic)
        # and each layer's own result order.
        weighted.sort(key=lambda item: item.score, reverse=True)

        # Greedy fill: an item that would overflow the budget is skipped, but
        # later, smaller items may still fit.
        selected = []
        total_tokens = 0
        for item in weighted:
            estimated_tokens = self._estimate_tokens(item)
            if total_tokens + estimated_tokens > token_budget:
                continue
            selected.append(item)
            total_tokens += estimated_tokens
            if len(selected) >= top_k:
                break
        return selected

    async def get_context_string(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> str:
        """Return the retrieved items' values as one blank-line-separated string.

        The arguments mean the same as in ``retrieve``; ``filters`` is
        forwarded to it. Returns ``""`` when nothing is retrieved.
        """
        items = await self.retrieve(query, top_k, token_budget, filters)
        return "\n\n".join(str(item.value) for item in items)

    async def store_episode(
        self, key: str, value: Any, metadata: dict[str, Any] | None = None
    ) -> None:
        """Store an episode in episodic memory, if one is configured.

        Public API that delegates to the underlying ``EpisodicMemory`` so
        callers need not access the private ``_episodic`` attribute. Does
        nothing when no episodic memory was provided.
        """
        if self._episodic is not None:
            await self._episodic.store(key, value, metadata)