360 lines
13 KiB
Python
360 lines
13 KiB
Python
"""U4 测试: 记忆系统 - 三层记忆 + 混合检索 + BaseAgent 生命周期集成"""
|
||
|
||
import math
|
||
from datetime import datetime, timedelta, timezone
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.base import BaseAgent
|
||
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
|
||
from agentkit.memory.base import Memory, MemoryItem
|
||
from agentkit.memory.episodic import EpisodicMemory
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
from agentkit.memory.semantic import SemanticMemory
|
||
from agentkit.memory.working import WorkingMemory
|
||
|
||
|
||
# ── In-Memory Memory 实现(用于测试) ────────────────────
|
||
|
||
|
||
class InMemoryMemory(Memory):
    """Dict-backed Memory implementation used as a test double.

    Every stored entry receives a fixed score of 1.0, and search is a plain
    case-insensitive substring match over keys and stringified values.
    """

    def __init__(self):
        # key -> MemoryItem; dict preserves insertion order for search results.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Create or overwrite the entry for *key* (score fixed at 1.0)."""
        item_metadata = metadata or {}
        self._store[key] = MemoryItem(
            key=key,
            value=value,
            metadata=item_metadata,
            score=1.0,
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the entry stored under *key*, or None when it is absent."""
        if key in self._store:
            return self._store[key]
        return None

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return up to *top_k* entries whose value or key contains *query*.

        Matching is case-insensitive and results keep insertion order.
        ``filters`` is accepted for interface compatibility but ignored.
        """
        needle = query.lower()
        matches = [
            entry
            for entry in self._store.values()
            if needle in str(entry.value).lower() or needle in entry.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove *key*; return True only if an entry was actually removed."""
        removed = self._store.pop(key, None)
        return removed is not None
|
||
|
||
|
||
# ── Memory 基类测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestMemoryBase:
    """Memory base-class contract, exercised through InMemoryMemory."""

    async def test_in_memory_store_and_retrieve(self):
        """A stored value and its metadata come back unchanged."""
        memory = InMemoryMemory()
        await memory.store("key1", {"data": "hello"}, {"tag": "test"})

        fetched = await memory.retrieve("key1")

        assert fetched is not None
        assert fetched.value["data"] == "hello"
        assert fetched.metadata["tag"] == "test"

    async def test_in_memory_search(self):
        """Substring search returns only the entries that mention the query."""
        memory = InMemoryMemory()
        entries = {
            "agent:task1": "Generated content about AI",
            "agent:task2": "Analysis of trends",
            "agent:task3": "AI research summary",
        }
        for entry_key, text in entries.items():
            await memory.store(entry_key, text)

        hits = await memory.search("AI")

        assert len(hits) == 2

    async def test_in_memory_delete(self):
        """delete() removes an entry and reports False for unknown keys."""
        memory = InMemoryMemory()
        await memory.store("key1", "value1")

        assert await memory.delete("key1") is True
        assert await memory.retrieve("key1") is None
        # A missing key is reported with False rather than raising.
        assert await memory.delete("nonexistent") is False

    async def test_batch_store(self):
        """store_batch() accepts (key, value, metadata) triples."""
        memory = InMemoryMemory()
        await memory.store_batch([("k1", "v1", None), ("k2", "v2", {"tag": "t"})])

        for entry_key in ("k1", "k2"):
            assert await memory.retrieve(entry_key) is not None

    async def test_get_context(self):
        """get_context() output contains the matching stored text."""
        memory = InMemoryMemory()
        await memory.store("k1", "First context item about Python")
        await memory.store("k2", "Second context item about AI")

        context_text = await memory.get_context("Python")

        assert "Python" in context_text
|
||
|
||
|
||
# ── SemanticMemory 测试 ──────────────────────────────────
|
||
|
||
|
||
class TestSemanticMemory:
    """SemanticMemory backed by fake RAG and knowledge-graph services."""

    async def test_rag_search(self):
        """Knowledge is retrieved through the RAG service."""

        class FakeRAGService:
            async def search(self, query, knowledge_base_ids=None, top_k=5):
                hit = {"id": "doc1", "content": f"Knowledge about {query}", "score": 0.9, "source": "kb1"}
                return [hit]

        semantic = SemanticMemory(rag_service=FakeRAGService(), knowledge_base_ids=["kb1"])
        found = await semantic.search("AI trends")

        assert len(found) == 1
        top = found[0]
        assert top.value == "Knowledge about AI trends"
        assert top.score == 0.9

    async def test_graph_search(self):
        """Knowledge is retrieved through the graph service and tagged "graph"."""

        class FakeGraphService:
            async def query(self, query, depth=2):
                node = {"id": "node1", "content": f"Entity: {query}", "score": 0.7, "entities": ["AI"]}
                return [node]

        semantic = SemanticMemory(graph_service=FakeGraphService())
        found = await semantic.search("machine learning")

        assert len(found) == 1
        assert found[0].metadata["source"] == "graph"

    async def test_combined_rag_and_graph(self):
        """RAG and graph hits are merged and ordered by descending score."""

        class FakeRAG:
            async def search(self, query, **kwargs):
                return [{"id": "r1", "content": "RAG result", "score": 0.8, "source": "rag"}]

        class FakeGraph:
            async def query(self, query, **kwargs):
                return [{"id": "g1", "content": "Graph result", "score": 0.6, "source": "graph"}]

        semantic = SemanticMemory(rag_service=FakeRAG(), graph_service=FakeGraph())
        found = await semantic.search("test")

        assert len(found) == 2
        first, second = found
        assert first.score >= second.score

    async def test_semantic_read_only(self):
        """delete() on an unconfigured SemanticMemory reports False (read-only)."""
        semantic = SemanticMemory()
        assert await semantic.delete("any") is False
|
||
|
||
|
||
# ── EpisodicMemory tests (decay formula only; no ORM or mocks involved) ──
|
||
|
||
|
||
class TestEpisodicMemory:
    """Time-decay behaviour expected from episodic-memory scoring."""

    async def test_time_decay(self):
        """Recent experiences must outweigh old ones under exponential decay.

        Checks the formula ``score * exp(-rate * age_hours)`` directly; it does
        not call EpisodicMemory itself. The assumption that EpisodicMemory uses
        exactly this formula should be confirmed against its implementation.
        """
        decay_rate = 0.01
        base_score = 0.8
        now = datetime.now(timezone.utc)

        def decayed(created_at):
            # Derive the age in hours from real timestamps instead of literals.
            age_hours = (now - created_at).total_seconds() / 3600
            return base_score * math.exp(-decay_rate * age_hours)

        recent_score = decayed(now - timedelta(hours=1))
        old_score = decayed(now - timedelta(hours=100))

        assert recent_score > old_score
        assert recent_score > 0.7  # 0.8 * e^-0.01 ~= 0.792
        assert old_score < 0.5  # 0.8 * e^-1 ~= 0.294
|
||
|
||
|
||
# ── MemoryRetriever 测试 ─────────────────────────────────
|
||
|
||
|
||
class TestMemoryRetriever:
    """MemoryRetriever hybrid retrieval across memory layers."""

    async def test_retriever_with_in_memory(self):
        """The retriever surfaces matches from a working-memory layer."""
        working = InMemoryMemory()
        await working.store("current_task", "Working on AI content generation")

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("AI content")

        assert len(results) >= 1

    async def test_retriever_weights(self):
        """Per-layer weights decide ranking: the heavier semantic layer wins.

        Both layers hold one matching item with the same base score (1.0, set
        by InMemoryMemory.store), so only the configured weights can separate
        them.
        """
        working = InMemoryMemory()
        semantic = InMemoryMemory()
        await working.store("task1", "Working memory result")
        await semantic.store("doc1", "Semantic memory result")

        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.2, "semantic": 0.8},
        )
        results = await retriever.retrieve("result")

        # Both items must come back; otherwise the ordering checks are vacuous.
        assert len(results) >= 2
        assert results[0].score >= results[1].score
        # Semantic is weighted higher, so it must rank first.
        assert results[0].value == "Semantic memory result"

    async def test_retriever_token_budget(self):
        """Retrieved content stays roughly within the requested token budget."""
        working = InMemoryMemory()
        for i in range(20):
            await working.store(f"item_{i}", f"Long content item number {i} " * 50)

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("content", token_budget=200)

        total_chars = sum(len(str(r.value)) for r in results)
        # Rough estimate of ~4 chars per token; allow a small overshoot.
        assert total_chars // 4 <= 250

    async def test_get_context_string(self):
        """get_context_string() includes the matching stored text."""
        working = InMemoryMemory()
        await working.store("ctx1", "Context about Python programming")

        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("Python")

        assert "Python" in context

    async def test_empty_retriever(self):
        """A retriever with no memory layers returns an empty list without error."""
        retriever = MemoryRetriever()
        results = await retriever.retrieve("anything")
        assert results == []
|
||
|
||
|
||
# ── BaseAgent + Memory 生命周期集成测试 ───────────────────
|
||
|
||
|
||
def _capabilities_for(name):
    """Return the minimal AgentCapability shared by the test agents below."""
    return AgentCapability(
        agent_name=name,
        agent_type="test",
        version="1.0",
        supported_tasks=["test"],
        max_concurrency=1,
        description="test",
    )


def _make_task(task_id, agent_name):
    """Return a priority-1 "test" TaskMessage with empty input and no callback."""
    return TaskMessage(
        task_id=task_id,
        agent_name=agent_name,
        task_type="test",
        priority=1,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
    )


class TestAgentMemoryIntegration:
    """BaseAgent lifecycle hooks (start / complete / failed) working with Memory."""

    async def test_memory_injected_into_agent(self):
        """use_memory() attaches the given Memory instance to the agent."""

        class MinimalAgent(BaseAgent):
            async def handle_task(self, task):
                return {"result": "done"}

            def get_capabilities(self):
                return _capabilities_for(self.name)

        agent = MinimalAgent(name="test", agent_type="test")
        mem = InMemoryMemory()
        agent.use_memory(mem)

        assert agent.memory is mem

    async def test_on_task_complete_stores_memory(self):
        """The on_task_complete hook can persist the task output into memory."""

        class MemoryAwareAgent(BaseAgent):
            async def handle_task(self, task):
                return {"answer": "42"}

            def get_capabilities(self):
                return _capabilities_for(self.name)

            async def on_task_complete(self, task, output):
                if self.memory:
                    await self.memory.store(
                        f"task:{task.task_id}",
                        output,
                        {"agent_name": self.name, "task_type": task.task_type},
                    )

        mem = InMemoryMemory()
        agent = MemoryAwareAgent(name="mem_agent", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-001", "mem_agent"))
        assert result.status == TaskStatus.COMPLETED

        # The hook must have written both the output and its metadata.
        stored = await mem.retrieve("task:t-001")
        assert stored is not None
        assert stored.value["answer"] == "42"
        assert stored.metadata["agent_name"] == "mem_agent"
        assert stored.metadata["task_type"] == "test"

    async def test_on_task_start_loads_memory(self):
        """The on_task_start hook can load memory into the agent before handling."""

        class ContextAwareAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loaded_context = None

            async def handle_task(self, task):
                return {"context_used": self.loaded_context is not None}

            def get_capabilities(self):
                return _capabilities_for(self.name)

            async def on_task_start(self, task):
                if self.memory:
                    context = await self.memory.get_context(task.task_type)
                    self.loaded_context = context

        mem = InMemoryMemory()
        await mem.store("test", "Previous experience with similar tasks")

        agent = ContextAwareAgent(name="ctx_agent", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-002", "ctx_agent"))
        assert result.status == TaskStatus.COMPLETED
        assert result.output_data["context_used"] is True
        # An empty (but non-None) context would satisfy the flag above, so also
        # check that the stored experience actually reached the agent.
        assert "Previous experience" in agent.loaded_context

    async def test_on_task_failed_records_failure(self):
        """The on_task_failed hook can record the failure in memory."""

        class ResilientAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.failure_recorded = False

            async def handle_task(self, task):
                raise ValueError("simulated failure")

            def get_capabilities(self):
                return _capabilities_for(self.name)

            async def on_task_failed(self, task, error):
                if self.memory:
                    await self.memory.store(
                        f"failure:{task.task_id}",
                        {"error": str(error), "task_type": task.task_type},
                        {"outcome": "failure"},
                    )
                    self.failure_recorded = True

        mem = InMemoryMemory()
        agent = ResilientAgent(name="resilient", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-003", "resilient"))
        assert result.status == TaskStatus.FAILED
        assert agent.failure_recorded is True

        stored = await mem.retrieve("failure:t-003")
        assert stored is not None
        assert "simulated failure" in stored.value["error"]
        assert stored.metadata["outcome"] == "failure"
|