fischer-agentkit/src/agentkit/memory/base.py

87 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Memory 抽象基类 - 统一记忆接口"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TypeAlias
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
MetadataValue: TypeAlias = str | int | float | bool | None
MetadataDict: TypeAlias = dict[str, MetadataValue]
MemoryValue: TypeAlias = (
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
)
@dataclass
class MemoryItem:
    """A single memory entry as produced/consumed by ``Memory`` backends."""

    # Identifier the entry is stored under.
    key: str
    # Payload; any object is accepted.
    value: object
    # Primitive-valued tags; a fresh dict is created per instance.
    metadata: MetadataDict = field(default_factory=dict)
    # Defaults to 1.0. NOTE(review): presumably a relevance score filled in by
    # search backends — confirm against implementations.
    score: float = 1.0
    # Creation time; defaults to the current time as a timezone-aware UTC datetime.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, object]:
        """Serialize the entry to a plain dict.

        ``created_at`` is rendered with ``datetime.isoformat()``. ``metadata``
        is returned as a shallow copy so callers may mutate the result (e.g.
        add fields before JSON encoding) without altering this item; because
        metadata values are typed as primitive scalars, the copy does not
        share mutable state with the item. ``value`` is returned as-is.

        Returns:
            A dict with keys ``key``, ``value``, ``metadata``, ``score`` and
            ``created_at``.
        """
        return {
            "key": self.key,
            "value": self.value,
            # Copy to avoid aliasing the item's internal metadata dict.
            "metadata": dict(self.metadata),
            "score": self.score,
            "created_at": self.created_at.isoformat(),
        }
class Memory(ABC):
    """Abstract base class for memory stores.

    Unified interface for the three-layer memory system:

    - WorkingMemory: current task context (Redis, short-lived)
    - EpisodicMemory: task experience (pgvector + PG, permanent)
    - SemanticMemory: knowledge base (RAG + Graph, permanent)

    Subclasses must implement ``store``, ``retrieve``, ``search`` and
    ``delete``; ``store_batch`` and ``get_context`` are built on top of them.
    """

    @abstractmethod
    async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
        """Store ``value`` under ``key`` with optional ``metadata``."""
        ...

    @abstractmethod
    async def retrieve(self, key: str) -> MemoryItem | None:
        """Retrieve the item stored under exactly ``key``.

        Returns ``None`` presumably when no such key exists — confirm per backend.
        """
        ...

    @abstractmethod
    async def search(
        self, query: str, top_k: int = 5, filters: MetadataDict | None = None
    ) -> list[MemoryItem]:
        """Semantic search for up to ``top_k`` items matching ``query``.

        ``filters`` optionally restricts results by metadata.
        """
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete the item stored under ``key``.

        NOTE(review): presumably returns True if something was removed — confirm.
        """
        ...

    async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
        """Store several ``(key, value, metadata)`` tuples.

        The default awaits ``store`` once per tuple, sequentially and in order;
        subclasses may override this with a bulk operation.
        """
        for key, value, metadata in items:
            await self.store(key, value, metadata)

    async def get_context(self, query: str, token_budget: int = 3000, top_k: int = 10) -> str:
        """Build a formatted context string for injection into a prompt.

        Calls ``search(query, top_k=top_k)`` and joins ``str(item.value)`` of
        each hit with newline characters, in the order ``search`` returned
        them. Selection stops at the first hit whose estimated token cost
        would push the running total past ``token_budget``; later hits are
        not back-filled even if they would fit.

        Args:
            query: Query forwarded to ``search``.
            token_budget: Maximum total of estimated tokens for the selected
                values (newline separators are not counted).
            top_k: Number of candidates requested from ``search``.

        Returns:
            The selected values joined by newlines, or ``""`` if none fit.
        """
        items = await self.search(query, top_k=top_k)
        parts: list[str] = []
        used_tokens = 0
        for item in items:
            text = str(item.value)
            cost = self._estimate_tokens(text)
            if used_tokens + cost > token_budget:
                # Stop instead of skipping so the output stays a prefix of
                # search's result order.
                break
            parts.append(text)
            used_tokens += cost
        return "\n".join(parts)

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Roughly estimate the token count of ``text`` as ceil(len(text) / 4).

        Rounding up guarantees every non-empty string costs at least one
        token, so strings shorter than four characters cannot bypass the
        budget. Subclasses may override this with a real tokenizer.
        """
        # NOTE(review): ~4 characters per token is an English-oriented
        # heuristic; CJK text is likely under-counted.
        return (len(text) + 3) // 4