fischer-agentkit/tests/unit/test_memory_system.py

360 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U4 测试: 记忆系统 - 三层记忆 + 混合检索 + BaseAgent 生命周期集成"""
import math
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock
import pytest
from agentkit.core.base import BaseAgent
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.working import WorkingMemory
# ── In-memory Memory implementation (test double) ────────
class InMemoryMemory(Memory):
    """Dict-backed ``Memory`` implementation used as a test double.

    Items live in a plain dict keyed by ``key``; every stored item is given a
    fixed ``score`` of 1.0. ``search`` is a case-insensitive substring match
    against ``str(value)`` and then the key, returning at most ``top_k`` hits
    in insertion order. The ``filters`` argument is accepted for interface
    compatibility but ignored.
    """

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        # Falsy metadata (None or an empty mapping) becomes a fresh empty dict.
        item_metadata = metadata if metadata else {}
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=item_metadata, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        matches = [
            entry
            for entry in self._store.values()
            if query.lower() in str(entry.value).lower()
            or query.lower() in entry.key.lower()
        ]
        # Truncate after filtering so order reflects insertion order.
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        removed = self._store.pop(key, None)
        return removed is not None
# ── Memory base-class tests ──────────────────────────────
class TestMemoryBase:
    """Exercise the ``Memory`` contract through the in-memory test double."""

    async def test_in_memory_store_and_retrieve(self):
        mem = InMemoryMemory()
        await mem.store("key1", {"data": "hello"}, {"tag": "test"})
        item = await mem.retrieve("key1")
        assert item is not None
        assert item.key == "key1"
        assert item.value["data"] == "hello"
        assert item.metadata["tag"] == "test"

    async def test_in_memory_search(self):
        mem = InMemoryMemory()
        await mem.store("agent:task1", "Generated content about AI")
        await mem.store("agent:task2", "Analysis of trends")
        await mem.store("agent:task3", "AI research summary")
        results = await mem.search("AI")
        # Check *which* items matched, not just how many, so a search that
        # returns the wrong items cannot pass.
        assert sorted(item.key for item in results) == ["agent:task1", "agent:task3"]

    async def test_in_memory_search_top_k(self):
        """``top_k`` caps the number of hits; matching is case-insensitive."""
        mem = InMemoryMemory()
        for i in range(3):
            await mem.store(f"note:{i}", f"AI note {i}")
        results = await mem.search("ai", top_k=2)
        assert [item.key for item in results] == ["note:0", "note:1"]

    async def test_in_memory_delete(self):
        mem = InMemoryMemory()
        await mem.store("key1", "value1")
        assert await mem.delete("key1") is True
        assert await mem.retrieve("key1") is None
        assert await mem.delete("nonexistent") is False

    async def test_batch_store(self):
        mem = InMemoryMemory()
        items = [("k1", "v1", None), ("k2", "v2", {"tag": "t"})]
        await mem.store_batch(items)
        first = await mem.retrieve("k1")
        second = await mem.retrieve("k2")
        # Verify the stored payloads, not merely that the keys exist.
        assert first is not None
        assert first.value == "v1"
        assert second is not None
        assert second.value == "v2"
        assert second.metadata["tag"] == "t"

    async def test_get_context(self):
        mem = InMemoryMemory()
        await mem.store("k1", "First context item about Python")
        await mem.store("k2", "Second context item about AI")
        context = await mem.get_context("Python")
        assert "Python" in context
# ── SemanticMemory tests ─────────────────────────────────
class TestSemanticMemory:
    """SemanticMemory backed by RAG and/or knowledge-graph services."""

    async def test_rag_search(self):
        """Knowledge is retrieved through the RAG service."""

        class MockRAGService:
            async def search(self, query, knowledge_base_ids=None, top_k=5):
                hit = {"id": "doc1", "content": f"Knowledge about {query}", "score": 0.9, "source": "kb1"}
                return [hit]

        semantic = SemanticMemory(rag_service=MockRAGService(), knowledge_base_ids=["kb1"])
        hits = await semantic.search("AI trends")
        assert len(hits) == 1
        top = hits[0]
        assert top.value == "Knowledge about AI trends"
        assert top.score == 0.9

    async def test_graph_search(self):
        """Knowledge is retrieved through the knowledge graph."""

        class MockGraphService:
            async def query(self, query, depth=2):
                node = {"id": "node1", "content": f"Entity: {query}", "score": 0.7, "entities": ["AI"]}
                return [node]

        semantic = SemanticMemory(graph_service=MockGraphService())
        hits = await semantic.search("machine learning")
        assert len(hits) == 1
        top = hits[0]
        assert top.metadata["source"] == "graph"

    async def test_combined_rag_and_graph(self):
        """RAG and graph results are merged into a single list."""

        class MockRAG:
            async def search(self, query, **kwargs):
                hit = {"id": "r1", "content": "RAG result", "score": 0.8, "source": "rag"}
                return [hit]

        class MockGraph:
            async def query(self, query, **kwargs):
                hit = {"id": "g1", "content": "Graph result", "score": 0.6, "source": "graph"}
                return [hit]

        semantic = SemanticMemory(rag_service=MockRAG(), graph_service=MockGraph())
        merged = await semantic.search("test")
        assert len(merged) == 2
        # Merged results are ordered by descending score.
        first, second = merged
        assert first.score >= second.score

    async def test_semantic_read_only(self):
        """Semantic memory is normally read-only: delete reports nothing removed."""
        semantic = SemanticMemory()
        deleted = await semantic.delete("any")
        assert deleted is False
# ── EpisodicMemory tests ─────────────────────────────────
class TestEpisodicMemory:
    """Tests for episodic-memory scoring behaviour."""

    async def test_time_decay(self):
        """Time decay: recent experiences weigh more than older ones.

        Checks the exponential decay ``score * exp(-rate * age_hours)`` directly
        rather than through ``EpisodicMemory``, so no database is needed.
        NOTE(review): this assumes EpisodicMemory uses the same formula and an
        hourly decay rate of 0.01 — confirm against the implementation.
        """
        decay_rate = 0.01
        base_score = 0.8

        def decayed(age_hours: float) -> float:
            return base_score * math.exp(-decay_rate * age_hours)

        recent_score = decayed(1)  # 1 hour ago
        old_score = decayed(100)  # 100 hours ago
        assert recent_score > old_score
        assert recent_score > 0.7
        assert old_score < 0.5
# ── MemoryRetriever tests ────────────────────────────────
class TestMemoryRetriever:
    """Hybrid retrieval across memory layers via ``MemoryRetriever``."""

    async def test_retriever_with_in_memory(self):
        """The hybrid retriever works with an in-memory working layer."""
        working = InMemoryMemory()
        await working.store("current_task", "Working on AI content generation")
        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("AI content")
        assert len(results) >= 1

    async def test_retriever_weights(self):
        """Per-layer weights determine ranking."""
        working = InMemoryMemory()
        semantic = InMemoryMemory()
        await working.store("task1", "Working memory result")
        await semantic.store("doc1", "Semantic memory result")
        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.2, "semantic": 0.8},
        )
        results = await retriever.retrieve("result")
        # Both layers hold one matching item with the same raw score
        # (InMemoryMemory always stores score=1.0), so only the layer weights
        # can decide the order. Assert unconditionally so an empty result
        # cannot make the test pass vacuously.
        assert results, "expected at least one retrieved item"
        # The semantic layer has the higher weight and should rank first.
        assert "Semantic memory result" in str(results[0].value)
        scores = [item.score for item in results]
        assert scores == sorted(scores, reverse=True)

    async def test_retriever_token_budget(self):
        """Token budget management."""
        working = InMemoryMemory()
        for i in range(20):
            await working.store(f"item_{i}", f"Long content item number {i} " * 50)
        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("content", token_budget=200)
        total_chars = sum(len(str(r.value)) for r in results)
        # Rough estimate (~4 chars per token): must not far exceed the budget.
        assert total_chars // 4 <= 250  # allow a small overflow

    async def test_get_context_string(self):
        """A formatted context string is produced."""
        working = InMemoryMemory()
        await working.store("ctx1", "Context about Python programming")
        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("Python")
        assert "Python" in context

    async def test_empty_retriever(self):
        """No error when no memory layers are configured."""
        retriever = MemoryRetriever()
        results = await retriever.retrieve("anything")
        assert results == []
# ── BaseAgent + Memory lifecycle integration tests ───────
class TestAgentMemoryIntegration:
    """Memory injection and lifecycle hooks (start / complete / failed) on BaseAgent."""

    @staticmethod
    def _capabilities(agent_name: str) -> AgentCapability:
        """Build the minimal capability descriptor shared by the test agents."""
        return AgentCapability(
            agent_name=agent_name, agent_type="test", version="1.0",
            supported_tasks=["test"], max_concurrency=1, description="test",
        )

    @staticmethod
    def _make_task(task_id: str, agent_name: str) -> TaskMessage:
        """Build a minimal ``test`` task addressed to ``agent_name``."""
        return TaskMessage(
            task_id=task_id, agent_name=agent_name, task_type="test",
            priority=1, input_data={}, callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    async def test_memory_injected_into_agent(self):
        """Memory can be injected into an agent."""
        build_capabilities = self._capabilities

        class TestAgent(BaseAgent):
            async def handle_task(self, task):
                return {"result": "done"}

            def get_capabilities(self):
                return build_capabilities(self.name)

        agent = TestAgent(name="test", agent_type="test")
        mem = InMemoryMemory()
        agent.use_memory(mem)
        assert agent.memory is mem

    async def test_on_task_complete_stores_memory(self):
        """The on_task_complete hook can store the task output in memory."""
        build_capabilities = self._capabilities

        class MemoryAwareAgent(BaseAgent):
            async def handle_task(self, task):
                return {"answer": "42"}

            def get_capabilities(self):
                return build_capabilities(self.name)

            async def on_task_complete(self, task, output):
                if self.memory:
                    await self.memory.store(
                        f"task:{task.task_id}",
                        output,
                        {"agent_name": self.name, "task_type": task.task_type},
                    )

        mem = InMemoryMemory()
        agent = MemoryAwareAgent(name="mem_agent", agent_type="test")
        agent.use_memory(mem)
        result = await agent.execute(self._make_task("t-001", "mem_agent"))
        assert result.status == TaskStatus.COMPLETED
        # The hook must have persisted the output under the task key.
        stored = await mem.retrieve("task:t-001")
        assert stored is not None
        assert stored.value["answer"] == "42"

    async def test_on_task_start_loads_memory(self):
        """The on_task_start hook can load memory into the task context."""
        build_capabilities = self._capabilities

        class ContextAwareAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loaded_context = None

            async def handle_task(self, task):
                return {"context_used": self.loaded_context is not None}

            def get_capabilities(self):
                return build_capabilities(self.name)

            async def on_task_start(self, task):
                if self.memory:
                    context = await self.memory.get_context(task.task_type)
                    self.loaded_context = context

        mem = InMemoryMemory()
        await mem.store("test", "Previous experience with similar tasks")
        agent = ContextAwareAgent(name="ctx_agent", agent_type="test")
        agent.use_memory(mem)
        result = await agent.execute(self._make_task("t-002", "ctx_agent"))
        assert result.status == TaskStatus.COMPLETED
        assert result.output_data["context_used"] is True
        # `is not None` alone also passes for an empty context, so also check
        # that the stored experience actually reached the agent.
        assert "Previous experience" in agent.loaded_context

    async def test_on_task_failed_records_failure(self):
        """The on_task_failed hook can record the failure in memory."""
        build_capabilities = self._capabilities

        class ResilientAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.failure_recorded = False

            async def handle_task(self, task):
                raise ValueError("simulated failure")

            def get_capabilities(self):
                return build_capabilities(self.name)

            async def on_task_failed(self, task, error):
                if self.memory:
                    await self.memory.store(
                        f"failure:{task.task_id}",
                        {"error": str(error), "task_type": task.task_type},
                        {"outcome": "failure"},
                    )
                    self.failure_recorded = True

        mem = InMemoryMemory()
        agent = ResilientAgent(name="resilient", agent_type="test")
        agent.use_memory(mem)
        result = await agent.execute(self._make_task("t-003", "resilient"))
        assert result.status == TaskStatus.FAILED
        assert agent.failure_recorded is True
        stored = await mem.retrieve("failure:t-003")
        assert stored is not None
        assert "simulated failure" in stored.value["error"]