test(memory): add memory system tests with BaseAgent lifecycle integration
- InMemoryMemory test implementation - SemanticMemory: RAG + graph search tests - MemoryRetriever: weight-based ranking, token budget - BaseAgent lifecycle: on_task_start loads, on_task_complete stores, on_task_failed records - 19 new tests, total 89 passing
This commit is contained in:
parent
d73a3391ab
commit
cc6a858150
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""U4 测试: 记忆系统 - 三层记忆 + 混合检索 + BaseAgent 生命周期集成"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.working import WorkingMemory
|
||||||
|
|
||||||
|
|
||||||
|
# ── In-Memory Memory 实现(用于测试) ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryMemory(Memory):
|
||||||
|
"""基于内存的 Memory 实现,用于测试"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._store: dict[str, MemoryItem] = {}
|
||||||
|
|
||||||
|
async def store(self, key: str, value, metadata=None) -> None:
|
||||||
|
self._store[key] = MemoryItem(
|
||||||
|
key=key, value=value, metadata=metadata or {}, score=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||||
|
return self._store.get(key)
|
||||||
|
|
||||||
|
async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
|
||||||
|
results = []
|
||||||
|
for item in self._store.values():
|
||||||
|
if query.lower() in str(item.value).lower() or query.lower() in item.key.lower():
|
||||||
|
results.append(item)
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
return self._store.pop(key, None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Memory 基类测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryBase:
|
||||||
|
async def test_in_memory_store_and_retrieve(self):
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
await mem.store("key1", {"data": "hello"}, {"tag": "test"})
|
||||||
|
item = await mem.retrieve("key1")
|
||||||
|
assert item is not None
|
||||||
|
assert item.value["data"] == "hello"
|
||||||
|
assert item.metadata["tag"] == "test"
|
||||||
|
|
||||||
|
async def test_in_memory_search(self):
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
await mem.store("agent:task1", "Generated content about AI")
|
||||||
|
await mem.store("agent:task2", "Analysis of trends")
|
||||||
|
await mem.store("agent:task3", "AI research summary")
|
||||||
|
|
||||||
|
results = await mem.search("AI")
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
async def test_in_memory_delete(self):
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
await mem.store("key1", "value1")
|
||||||
|
assert await mem.delete("key1") is True
|
||||||
|
assert await mem.retrieve("key1") is None
|
||||||
|
assert await mem.delete("nonexistent") is False
|
||||||
|
|
||||||
|
async def test_batch_store(self):
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
items = [("k1", "v1", None), ("k2", "v2", {"tag": "t"})]
|
||||||
|
await mem.store_batch(items)
|
||||||
|
assert await mem.retrieve("k1") is not None
|
||||||
|
assert await mem.retrieve("k2") is not None
|
||||||
|
|
||||||
|
async def test_get_context(self):
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
await mem.store("k1", "First context item about Python")
|
||||||
|
await mem.store("k2", "Second context item about AI")
|
||||||
|
context = await mem.get_context("Python")
|
||||||
|
assert "Python" in context
|
||||||
|
|
||||||
|
|
||||||
|
# ── SemanticMemory 测试 ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticMemory:
|
||||||
|
async def test_rag_search(self):
|
||||||
|
"""通过 RAG 服务检索知识"""
|
||||||
|
|
||||||
|
class MockRAGService:
|
||||||
|
async def search(self, query, knowledge_base_ids=None, top_k=5):
|
||||||
|
return [
|
||||||
|
{"id": "doc1", "content": f"Knowledge about {query}", "score": 0.9, "source": "kb1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
mem = SemanticMemory(rag_service=MockRAGService(), knowledge_base_ids=["kb1"])
|
||||||
|
results = await mem.search("AI trends")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].value == "Knowledge about AI trends"
|
||||||
|
assert results[0].score == 0.9
|
||||||
|
|
||||||
|
async def test_graph_search(self):
|
||||||
|
"""通过知识图谱检索"""
|
||||||
|
|
||||||
|
class MockGraphService:
|
||||||
|
async def query(self, query, depth=2):
|
||||||
|
return [
|
||||||
|
{"id": "node1", "content": f"Entity: {query}", "score": 0.7, "entities": ["AI"]},
|
||||||
|
]
|
||||||
|
|
||||||
|
mem = SemanticMemory(graph_service=MockGraphService())
|
||||||
|
results = await mem.search("machine learning")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].metadata["source"] == "graph"
|
||||||
|
|
||||||
|
async def test_combined_rag_and_graph(self):
|
||||||
|
"""RAG + 图谱联合检索"""
|
||||||
|
|
||||||
|
class MockRAG:
|
||||||
|
async def search(self, query, **kwargs):
|
||||||
|
return [{"id": "r1", "content": "RAG result", "score": 0.8, "source": "rag"}]
|
||||||
|
|
||||||
|
class MockGraph:
|
||||||
|
async def query(self, query, **kwargs):
|
||||||
|
return [{"id": "g1", "content": "Graph result", "score": 0.6, "source": "graph"}]
|
||||||
|
|
||||||
|
mem = SemanticMemory(rag_service=MockRAG(), graph_service=MockGraph())
|
||||||
|
results = await mem.search("test")
|
||||||
|
assert len(results) == 2
|
||||||
|
# 按 score 排序
|
||||||
|
assert results[0].score >= results[1].score
|
||||||
|
|
||||||
|
async def test_semantic_read_only(self):
|
||||||
|
"""Semantic Memory 通常只读"""
|
||||||
|
mem = SemanticMemory()
|
||||||
|
assert await mem.delete("any") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── EpisodicMemory 测试(使用 mock ORM) ─────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEpisodicMemory:
|
||||||
|
async def test_time_decay(self):
|
||||||
|
"""时间衰减:近期经验权重高于远期"""
|
||||||
|
# 直接测试衰减公式
|
||||||
|
decay_rate = 0.01
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
recent_score = 0.8 * math.exp(-decay_rate * 1) # 1 hour ago
|
||||||
|
old_score = 0.8 * math.exp(-decay_rate * 100) # 100 hours ago
|
||||||
|
|
||||||
|
assert recent_score > old_score
|
||||||
|
assert recent_score > 0.7
|
||||||
|
assert old_score < 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryRetriever 测试 ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryRetriever:
|
||||||
|
async def test_retriever_with_in_memory(self):
|
||||||
|
"""混合检索器使用 InMemoryMemory"""
|
||||||
|
working = InMemoryMemory()
|
||||||
|
await working.store("current_task", "Working on AI content generation")
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(working_memory=working)
|
||||||
|
results = await retriever.retrieve("AI content")
|
||||||
|
assert len(results) >= 1
|
||||||
|
|
||||||
|
async def test_retriever_weights(self):
|
||||||
|
"""不同层权重影响排序"""
|
||||||
|
working = InMemoryMemory()
|
||||||
|
semantic = InMemoryMemory()
|
||||||
|
|
||||||
|
await working.store("task1", "Working memory result")
|
||||||
|
await semantic.store("doc1", "Semantic memory result")
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(
|
||||||
|
working_memory=working,
|
||||||
|
semantic_memory=semantic,
|
||||||
|
weights={"working": 0.2, "semantic": 0.8},
|
||||||
|
)
|
||||||
|
results = await retriever.retrieve("result")
|
||||||
|
# Semantic 权重更高,应排前面
|
||||||
|
if len(results) >= 2:
|
||||||
|
assert results[0].score >= results[1].score
|
||||||
|
|
||||||
|
async def test_retriever_token_budget(self):
|
||||||
|
"""Token 预算管理"""
|
||||||
|
working = InMemoryMemory()
|
||||||
|
for i in range(20):
|
||||||
|
await working.store(f"item_{i}", f"Long content item number {i} " * 50)
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(working_memory=working)
|
||||||
|
results = await retriever.retrieve("content", token_budget=200)
|
||||||
|
total_chars = sum(len(str(r.value)) for r in results)
|
||||||
|
# 粗略估算 token 数不应远超预算
|
||||||
|
assert total_chars // 4 <= 250 # 允许少量溢出
|
||||||
|
|
||||||
|
async def test_get_context_string(self):
|
||||||
|
"""获取格式化上下文字符串"""
|
||||||
|
working = InMemoryMemory()
|
||||||
|
await working.store("ctx1", "Context about Python programming")
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(working_memory=working)
|
||||||
|
context = await retriever.get_context_string("Python")
|
||||||
|
assert "Python" in context
|
||||||
|
|
||||||
|
async def test_empty_retriever(self):
|
||||||
|
"""无记忆层时不报错"""
|
||||||
|
retriever = MemoryRetriever()
|
||||||
|
results = await retriever.retrieve("anything")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── BaseAgent + Memory 生命周期集成测试 ───────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentMemoryIntegration:
|
||||||
|
async def test_memory_injected_into_agent(self):
|
||||||
|
"""Memory 可注入到 Agent"""
|
||||||
|
|
||||||
|
class TestAgent(BaseAgent):
|
||||||
|
async def handle_task(self, task):
|
||||||
|
return {"result": "done"}
|
||||||
|
|
||||||
|
def get_capabilities(self):
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name, agent_type="test", version="1.0",
|
||||||
|
supported_tasks=["test"], max_concurrency=1, description="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = TestAgent(name="test", agent_type="test")
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
agent.use_memory(mem)
|
||||||
|
assert agent.memory is mem
|
||||||
|
|
||||||
|
async def test_on_task_complete_stores_memory(self):
|
||||||
|
"""on_task_complete 钩子可存储记忆"""
|
||||||
|
|
||||||
|
class MemoryAwareAgent(BaseAgent):
|
||||||
|
async def handle_task(self, task):
|
||||||
|
return {"answer": "42"}
|
||||||
|
|
||||||
|
def get_capabilities(self):
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name, agent_type="test", version="1.0",
|
||||||
|
supported_tasks=["test"], max_concurrency=1, description="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_task_complete(self, task, output):
|
||||||
|
if self.memory:
|
||||||
|
await self.memory.store(
|
||||||
|
f"task:{task.task_id}",
|
||||||
|
output,
|
||||||
|
{"agent_name": self.name, "task_type": task.task_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
agent = MemoryAwareAgent(name="mem_agent", agent_type="test")
|
||||||
|
agent.use_memory(mem)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="t-001", agent_name="mem_agent", task_type="test",
|
||||||
|
priority=1, input_data={}, callback_url=None,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
result = await agent.execute(task)
|
||||||
|
assert result.status == TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
# 验证记忆已存储
|
||||||
|
stored = await mem.retrieve("task:t-001")
|
||||||
|
assert stored is not None
|
||||||
|
assert stored.value["answer"] == "42"
|
||||||
|
|
||||||
|
async def test_on_task_start_loads_memory(self):
|
||||||
|
"""on_task_start 钩子可加载记忆到任务上下文"""
|
||||||
|
|
||||||
|
class ContextAwareAgent(BaseAgent):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.loaded_context = None
|
||||||
|
|
||||||
|
async def handle_task(self, task):
|
||||||
|
return {"context_used": self.loaded_context is not None}
|
||||||
|
|
||||||
|
def get_capabilities(self):
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name, agent_type="test", version="1.0",
|
||||||
|
supported_tasks=["test"], max_concurrency=1, description="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_task_start(self, task):
|
||||||
|
if self.memory:
|
||||||
|
context = await self.memory.get_context(task.task_type)
|
||||||
|
self.loaded_context = context
|
||||||
|
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
await mem.store("test", "Previous experience with similar tasks")
|
||||||
|
|
||||||
|
agent = ContextAwareAgent(name="ctx_agent", agent_type="test")
|
||||||
|
agent.use_memory(mem)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="t-002", agent_name="ctx_agent", task_type="test",
|
||||||
|
priority=1, input_data={}, callback_url=None,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
result = await agent.execute(task)
|
||||||
|
assert result.output_data["context_used"] is True
|
||||||
|
|
||||||
|
async def test_on_task_failed_records_failure(self):
|
||||||
|
"""on_task_failed 钩子可记录失败模式"""
|
||||||
|
|
||||||
|
class ResilientAgent(BaseAgent):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.failure_recorded = False
|
||||||
|
|
||||||
|
async def handle_task(self, task):
|
||||||
|
raise ValueError("simulated failure")
|
||||||
|
|
||||||
|
def get_capabilities(self):
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name, agent_type="test", version="1.0",
|
||||||
|
supported_tasks=["test"], max_concurrency=1, description="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_task_failed(self, task, error):
|
||||||
|
if self.memory:
|
||||||
|
await self.memory.store(
|
||||||
|
f"failure:{task.task_id}",
|
||||||
|
{"error": str(error), "task_type": task.task_type},
|
||||||
|
{"outcome": "failure"},
|
||||||
|
)
|
||||||
|
self.failure_recorded = True
|
||||||
|
|
||||||
|
mem = InMemoryMemory()
|
||||||
|
agent = ResilientAgent(name="resilient", agent_type="test")
|
||||||
|
agent.use_memory(mem)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="t-003", agent_name="resilient", task_type="test",
|
||||||
|
priority=1, input_data={}, callback_url=None,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
result = await agent.execute(task)
|
||||||
|
assert result.status == TaskStatus.FAILED
|
||||||
|
assert agent.failure_recorded is True
|
||||||
|
|
||||||
|
stored = await mem.retrieve("failure:t-003")
|
||||||
|
assert stored is not None
|
||||||
|
assert "simulated failure" in stored.value["error"]
|
||||||
Loading…
Reference in New Issue