278 lines
9.5 KiB
Python
278 lines
9.5 KiB
Python
"""Integration tests for Agent lifecycle: start → execute task → return result → stop"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock
|
|
|
|
from agentkit.core.base import BaseAgent
|
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
|
from agentkit.core.protocol import (
|
|
AgentCapability,
|
|
AgentStatus,
|
|
TaskMessage,
|
|
TaskResult,
|
|
TaskStatus,
|
|
)
|
|
from agentkit.memory.base import Memory, MemoryItem
|
|
from agentkit.tools.function_tool import FunctionTool
|
|
|
|
|
|
# ── Helpers ────────────────────────────────────────────────
|
|
|
|
|
|
class InMemoryMemory(Memory):
    """Dictionary-backed ``Memory`` so the tests need no Redis/PG.

    Items are kept in a plain dict keyed by memory key and live only as long
    as the instance does.
    """

    def __init__(self):
        # Maps memory key -> stored MemoryItem.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Create or overwrite the item under ``key``.

        A falsy ``metadata`` is replaced by an empty dict, and the item is
        stamped with the current UTC time.
        """
        item = MemoryItem(
            key=key,
            value=value,
            metadata=metadata or {},
            created_at=datetime.now(timezone.utc),
        )
        self._store[key] = item

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` when absent."""
        try:
            return self._store[key]
        except KeyError:
            return None

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return up to ``top_k`` items whose value or key contains ``query``.

        Matching is a case-insensitive substring test against ``str(value)``
        and the key. ``filters`` is accepted for interface compatibility but
        is not consulted.
        """

        def _is_hit(entry: MemoryItem) -> bool:
            needle = query.lower()
            return needle in str(entry.value).lower() or needle in entry.key.lower()

        hits = list(filter(_is_hit, self._store.values()))
        return hits[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key``; return ``True`` if it existed, ``False`` otherwise."""
        try:
            del self._store[key]
        except KeyError:
            return False
        return True
|
|
|
|
|
|
class TrackingAgent(BaseAgent):
    """``BaseAgent`` subclass that logs which lifecycle hooks were invoked.

    Each hook appends its own name to ``hook_calls``; ``handle_task`` itself
    is not recorded. With ``should_fail=True`` the handler raises instead of
    returning a result.
    """

    def __init__(self, should_fail: bool = False):
        super().__init__(name="tracking_agent", agent_type="tracking")
        self.hook_calls: list[str] = []
        self.should_fail = should_fail

    def _record_hook_call(self, hook_name: str) -> None:
        """Append ``hook_name`` to the ordered log of invoked hooks."""
        self.hook_calls.append(hook_name)

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: a single supported task type, concurrency of one."""
        capability = AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["tracking"],
            max_concurrency=1,
            description="Tracking test agent",
        )
        return capability

    async def on_task_start(self, task: TaskMessage) -> None:
        """Record that the start hook ran."""
        self._record_hook_call("on_task_start")

    async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
        """Record that the completion hook ran."""
        self._record_hook_call("on_task_complete")

    async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
        """Record that the failure hook ran."""
        self._record_hook_call("on_task_failed")

    async def handle_task(self, task: TaskMessage) -> dict:
        """Return a message naming the task, or raise when configured to fail."""
        if not self.should_fail:
            return {"message": f"Handled task {task.task_id}"}
        raise RuntimeError("Intentional failure for testing")
|
|
|
|
|
|
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` from test defaults.

    Any keyword in ``overrides`` replaces the default of the same name;
    unknown keywords are passed through to ``TaskMessage`` as-is.
    """
    base_fields = {
        "task_id": "task-001",
        "agent_name": "test_agent",
        "task_type": "test_task",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": datetime.now(timezone.utc),
    }
    # Right-hand operand wins on key collisions, so overrides take precedence.
    return TaskMessage(**(base_fields | overrides))
|
|
|
|
|
|
# ── Tests ──────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.integration
async def test_config_driven_agent_lifecycle():
    """ConfigDrivenAgent from config → start → execute task → return TaskResult → stop.

    The agent is always stopped, even when an intermediate assertion fails,
    so a failing run does not leave a started agent behind.
    """
    config = AgentConfig(
        name="lifecycle_agent",
        agent_type="lifecycle_test",
        task_mode="llm_generate",
        description="Test lifecycle agent",
        prompt={
            "identity": "You are a test agent",
            "instructions": "Process the input",
            "output_format": "JSON",
        },
    )

    # Stub LLM client whose chat() resolves to a fixed JSON string.
    mock_llm = AsyncMock()
    mock_llm.chat = AsyncMock(return_value='{"result": "processed"}')

    agent = ConfigDrivenAgent(config=config, llm_client=mock_llm)

    # Start without Redis (local mode)
    await agent.start()
    try:
        assert agent.status == AgentStatus.ONLINE

        # Execute a task
        task = _make_task(agent_name="lifecycle_agent", task_type="lifecycle_test")
        result = await agent.execute(task)

        assert isinstance(result, TaskResult)
        assert result.task_id == "task-001"
        assert result.status == TaskStatus.COMPLETED
        assert result.output_data is not None
        assert result.error_message is None
    finally:
        # Stop runs on both the success and the failure path.
        await agent.stop()

    assert agent.status == AgentStatus.OFFLINE
|
|
|
|
|
|
@pytest.mark.integration
async def test_lifecycle_hooks_called_in_order():
    """On success the recorded hooks are exactly on_task_start → on_task_complete.

    ``TrackingAgent`` does not record ``handle_task`` itself, so only the two
    hooks are compared. The agent is stopped in ``finally`` so a failing
    assertion does not leave it running.
    """
    agent = TrackingAgent(should_fail=False)
    await agent.start()
    try:
        task = _make_task(agent_name="tracking_agent", task_type="tracking")
        result = await agent.execute(task)

        assert result.status == TaskStatus.COMPLETED
        assert agent.hook_calls == ["on_task_start", "on_task_complete"]
    finally:
        await agent.stop()
|
|
|
|
|
|
@pytest.mark.integration
async def test_task_failure_triggers_on_task_failed():
    """Task failure triggers on_task_failed, TaskResult status is FAILED.

    The agent is stopped in ``finally`` so a failing assertion does not leave
    it running.
    """
    agent = TrackingAgent(should_fail=True)
    await agent.start()
    try:
        task = _make_task(agent_name="tracking_agent", task_type="tracking")
        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        assert result.error_message == "Intentional failure for testing"
        assert "on_task_failed" in agent.hook_calls
        # on_task_start should be called before on_task_failed
        assert agent.hook_calls.index("on_task_start") < agent.hook_calls.index("on_task_failed")
    finally:
        await agent.stop()
|
|
|
|
|
|
@pytest.mark.integration
async def test_agent_with_working_memory():
    """Agent with an attached memory stores context at task start and reads it back in the handler.

    ``InMemoryMemory`` stands in for the memory backend. The agent is stopped
    in ``finally`` so a failing assertion does not leave it running.
    """

    class MemoryAgent(BaseAgent):
        """Agent that writes task context in ``on_task_start`` and returns it from ``handle_task``."""

        def __init__(self, memory: Memory):
            super().__init__(name="memory_agent", agent_type="memory_test")
            self.use_memory(memory)

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["memory_test"],
                max_concurrency=1,
                description="Memory test agent",
            )

        async def on_task_start(self, task: TaskMessage) -> None:
            # Store context at task start
            if self.memory:
                await self.memory.store(
                    f"ctx:{task.task_id}",
                    {"task_type": task.task_type, "input": task.input_data},
                )

        async def handle_task(self, task: TaskMessage) -> dict:
            # Retrieve stored context
            if self.memory:
                item = await self.memory.retrieve(f"ctx:{task.task_id}")
                if item:
                    return {"retrieved_context": item.value, "processed": True}
            return {"processed": True, "retrieved_context": None}

    memory = InMemoryMemory()
    agent = MemoryAgent(memory=memory)
    await agent.start()
    try:
        task = _make_task(agent_name="memory_agent", task_type="memory_test")
        result = await agent.execute(task)

        assert result.status == TaskStatus.COMPLETED
        assert result.output_data["processed"] is True
        assert result.output_data["retrieved_context"] is not None
        assert result.output_data["retrieved_context"]["task_type"] == "memory_test"

        # Verify memory still has the data; the key is derived from the task
        # so it stays in sync with what on_task_start wrote.
        stored = await memory.retrieve(f"ctx:{task.task_id}")
        assert stored is not None
    finally:
        await agent.stop()
|
|
|
|
|
|
@pytest.mark.integration
async def test_agent_with_episodic_memory():
    """Agent with an attached memory records an experience entry after task completion.

    ``InMemoryMemory`` stands in for the memory backend. The agent is stopped
    in ``finally`` so a failing assertion does not leave it running.
    """

    class EpisodicAgent(BaseAgent):
        """Agent that stores input, output and task type in ``on_task_complete``."""

        def __init__(self, memory: Memory):
            super().__init__(name="episodic_agent", agent_type="episodic_test")
            self.use_memory(memory)

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["episodic_test"],
                max_concurrency=1,
                description="Episodic test agent",
            )

        async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
            # Record experience after task completion
            if self.memory:
                await self.memory.store(
                    f"experience:{task.task_id}",
                    {
                        "input": task.input_data,
                        "output": output,
                        "task_type": task.task_type,
                    },
                    metadata={"outcome": "success"},
                )

        async def handle_task(self, task: TaskMessage) -> dict:
            return {"answer": "42", "confidence": 0.95}

    memory = InMemoryMemory()
    agent = EpisodicAgent(memory=memory)
    await agent.start()
    try:
        task = _make_task(agent_name="episodic_agent", task_type="episodic_test")
        result = await agent.execute(task)

        assert result.status == TaskStatus.COMPLETED

        # Verify experience was recorded; the key is derived from the task so
        # it stays in sync with what on_task_complete wrote.
        experience = await memory.retrieve(f"experience:{task.task_id}")
        assert experience is not None
        assert experience.value["output"]["answer"] == "42"
        assert experience.metadata["outcome"] == "success"
    finally:
        await agent.stop()
|