87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
"""Memory 抽象基类 - 统一记忆接口"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import TypeAlias
|
||
|
||
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
|
||
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
|
||
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
|
||
MetadataValue: TypeAlias = str | int | float | bool | None
|
||
MetadataDict: TypeAlias = dict[str, MetadataValue]
|
||
MemoryValue: TypeAlias = (
|
||
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
|
||
)
|
||
|
||
|
||
@dataclass
class MemoryItem:
    """A single memory entry.

    Fields:
        key: Identifier of the entry.
        value: Stored payload; annotated as ``object``, so any value is accepted.
        metadata: Metadata mapping; every instance gets its own empty dict by default.
        score: Numeric score attached to the entry; defaults to 1.0.
        created_at: Creation timestamp; defaults to the current time in UTC.
    """

    key: str
    value: object
    metadata: MetadataDict = field(default_factory=dict)
    score: float = 1.0
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))

    def to_dict(self) -> dict[str, object]:
        """Return the entry as a plain dict with ``created_at`` as an ISO 8601 string.

        ``value`` and ``metadata`` are placed in the result as-is (not copied),
        so the returned dict shares those objects with this instance.
        """
        created_iso = self.created_at.isoformat()
        payload: dict[str, object] = {"key": self.key, "value": self.value}
        payload["metadata"] = self.metadata
        payload["score"] = self.score
        payload["created_at"] = created_iso
        return payload
|
||
|
||
|
||
class Memory(ABC):
    """Abstract base class for memory backends.

    Unified interface for the three-tier memory system:
    - WorkingMemory: context of the current task (Redis, short-lived)
    - EpisodicMemory: task experience (pgvector+PG, permanent)
    - SemanticMemory: knowledge base (RAG+Graph, permanent)
    """

    @abstractmethod
    async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
        """Store a memory under ``key``, optionally with metadata."""

    @abstractmethod
    async def retrieve(self, key: str) -> MemoryItem | None:
        """Look up a memory by exact key."""

    @abstractmethod
    async def search(
        self, query: str, top_k: int = 5, filters: MetadataDict | None = None
    ) -> list[MemoryItem]:
        """Run a semantic search for ``query``."""

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete the memory stored under ``key``."""

    async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
        """Store several entries by awaiting ``store`` once per entry, in order.

        Each entry must be a ``(key, value, metadata)`` triple.
        """
        for entry_key, payload, meta in items:
            await self.store(entry_key, payload, meta)

    async def get_context(self, query: str, token_budget: int = 3000) -> str:
        """Build a newline-joined context string for prompt injection.

        Searches for ``query`` with ``top_k=10`` and appends ``str(item.value)``
        for each hit, in the order returned, until the next hit would push the
        estimated total above ``token_budget``. The first hit that does not fit
        ends the scan, so later hits are not considered.

        Args:
            query: Search query passed to ``search``.
            token_budget: Upper bound on the summed token estimates.

        Returns:
            The accepted texts joined by ``"\n"``; an empty string if none fit.
        """
        hits = await self.search(query, top_k=10)
        remaining = token_budget
        accepted: list[str] = []
        for hit in hits:
            rendered = str(hit.value)
            # Rough estimate: four characters per token, rounded down.
            cost = len(rendered) // 4
            if cost > remaining:
                break
            remaining -= cost
            accepted.append(rendered)
        return "\n".join(accepted)
|