"""Memory 抽象基类 - 统一记忆接口""" from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TypeAlias # 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。 # MetadataValue 覆盖 metadata dict 中实际出现的原始类型; # MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。 MetadataValue: TypeAlias = str | int | float | bool | None MetadataDict: TypeAlias = dict[str, MetadataValue] MemoryValue: TypeAlias = ( str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue] ) @dataclass class MemoryItem: """记忆条目""" key: str value: object metadata: MetadataDict = field(default_factory=dict) score: float = 1.0 created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) def to_dict(self) -> dict[str, object]: return { "key": self.key, "value": self.value, "metadata": self.metadata, "score": self.score, "created_at": self.created_at.isoformat(), } class Memory(ABC): """记忆抽象基类 三层记忆系统的统一接口: - WorkingMemory: 当前任务上下文(Redis, 短生命周期) - EpisodicMemory: 任务经验(pgvector+PG, 永久) - SemanticMemory: 知识库(RAG+Graph, 永久) """ @abstractmethod async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None: """存储记忆""" ... @abstractmethod async def retrieve(self, key: str) -> MemoryItem | None: """按 key 精确检索""" ... @abstractmethod async def search( self, query: str, top_k: int = 5, filters: MetadataDict | None = None ) -> list[MemoryItem]: """语义检索""" ... @abstractmethod async def delete(self, key: str) -> bool: """删除记忆""" ... async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None: """批量存储""" for key, value, metadata in items: await self.store(key, value, metadata) async def get_context(self, query: str, token_budget: int = 3000) -> str: """获取格式化的上下文字符串(用于注入 Prompt)""" items = await self.search(query, top_k=10) context_parts = [] total_tokens = 0 for item in items: text = str(item.value) estimated_tokens = len(text) // 4 # 粗略估算 if total_tokens + estimated_tokens > token_budget: break context_parts.append(text) total_tokens += estimated_tokens return "\n".join(context_parts)