# fischer-agentkit/tests/unit/test_execute_stream_hooks.py
#
# 321 lines
# 10 KiB
# Python

"""U2 tests: execute_stream evolution hook wiring (OQ6 fix).
Verifies that ConfigDrivenAgent.execute_stream() fires evolution hooks
(on_task_complete / on_task_failed) in its finally block with lifecycle
parity to the sync execute() path. Covers happy path, failure, cancellation,
early close, evolution-error suppression, backpressure cap, REST/stream
parity, and evolution-disabled no-op.
"""
import asyncio
import pytest
from agentkit.core.config_driven import (
AgentConfig,
ConfigDrivenAgent,
drain_pending_evolution_tasks,
)
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.core.react import ReActEvent
# ── Helpers ──────────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
    """Build a TaskMessage addressed to ``stream_agent``.

    Any keyword passed in *overrides* replaces the matching default field
    (or adds a new one) before the payload is parsed by
    ``TaskMessage.from_dict``.
    """
    payload = {
        "task_id": "stream-task-001",
        "agent_name": "stream_agent",
        "task_type": "generate",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": None,
        **overrides,
    }
    return TaskMessage.from_dict(payload)


def _make_agent(max_concurrency: int = 1) -> ConfigDrivenAgent:
    """Create an ``llm_generate`` ConfigDrivenAgent with evolution forced on.

    Args:
        max_concurrency: copied verbatim into the agent config's
            ``max_concurrency`` field.
    """
    prompt_section = {
        "identity": "test agent",
        "instructions": "do the thing",
        "output_format": "text",
    }
    raw_config = {
        "name": "stream_agent",
        "agent_type": "content_generation",
        "task_mode": "llm_generate",
        "prompt": prompt_section,
        "max_concurrency": max_concurrency,
    }
    agent = ConfigDrivenAgent(config=AgentConfig.from_dict(raw_config))
    # Tests exercise the evolution hooks, so switch them on explicitly.
    agent._evolution_enabled = True
    return agent


def _final_answer_event(output: str = "hello") -> ReActEvent:
    """Return a step-0 ``final_answer`` ReActEvent whose data carries *output*."""
    payload = {"output": output}
    return ReActEvent(event_type="final_answer", step=0, data=payload)


@pytest.fixture(autouse=True)
async def _isolate_evolution_state():
    """Give every test a clean slate of module-level evolution state.

    Setup cancels any evolution tasks a previous test left behind, waits for
    them to settle, then empties the pending set and zeroes the dropped
    counter. Teardown drains whatever the current test scheduled. Without
    this, leftover stuck tasks would inflate the pending set and skew the
    backpressure assertions of later tests.
    """
    import agentkit.core.config_driven as cd

    # Snapshot first: no await happens before the gather, so this matches
    # the live set at that point.
    leftovers = list(cd._pending_evolution_tasks)
    for leftover in leftovers:
        leftover.cancel()
    if leftovers:
        await asyncio.gather(*leftovers, return_exceptions=True)

    cd._pending_evolution_tasks.clear()
    cd._evolution_dropped_count = 0

    yield

    await drain_pending_evolution_tasks()


# ── Happy path ───────────────────────────────────────────
class TestExecuteStreamHooks:
    """Success and failure paths of execute_stream() evolution hook firing."""

    async def test_success_fires_on_task_complete(self):
        """Stream completion fires evolve_after_task with COMPLETED status."""
        agent = _make_agent()
        results: list[TaskResult] = []

        async def capture(task, result, memory_store=None):
            results.append(result)

        async def single_answer(task):
            yield _final_answer_event("hello world")

        agent.evolve_after_task = capture
        agent.handle_task_stream = single_answer

        streamed = [event async for event in agent.execute_stream(_make_task())]
        await drain_pending_evolution_tasks()

        assert [event.event_type for event in streamed] == ["final_answer"]
        assert len(results) == 1
        completed = results[0]
        assert completed.status == TaskStatus.COMPLETED
        assert completed.output_data == {"content": "hello world"}

    async def test_failure_fires_on_task_failed(self):
        """Stream exception fires evolve_after_task with FAILED status."""
        agent = _make_agent()
        results: list[TaskResult] = []

        async def capture(task, result, memory_store=None):
            results.append(result)

        async def explode_after_one(task):
            # Emit one event first so the failure happens mid-stream.
            yield _final_answer_event("partial")
            raise RuntimeError("stream blew up")

        agent.evolve_after_task = capture
        agent.handle_task_stream = explode_after_one

        with pytest.raises(RuntimeError, match="stream blew up"):
            async for _event in agent.execute_stream(_make_task()):
                pass
        await drain_pending_evolution_tasks()

        assert len(results) == 1
        failed = results[0]
        assert failed.status == TaskStatus.FAILED
        assert "stream blew up" in (failed.error_message or "")


# ── Edge cases ───────────────────────────────────────────
class TestExecuteStreamEdgeCases:
    """Cancellation, early close, hook errors and backpressure for execute_stream()."""

    async def test_cancellation_fires_cancelled_status(self):
        """Stream cancelled mid-flight fires hooks with CANCELLED status."""
        agent = _make_agent()
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve
        started = asyncio.Event()

        async def slow_stream(task):
            started.set()
            await asyncio.sleep(60)
            yield _final_answer_event("never reached")

        agent.handle_task_stream = slow_stream

        async def consume():
            async for _ in agent.execute_stream(_make_task()):
                pass

        consumer = asyncio.create_task(consume())
        await started.wait()
        await asyncio.sleep(0.05)  # let it settle into sleep(60)
        consumer.cancel()
        with pytest.raises(asyncio.CancelledError):
            await consumer
        await drain_pending_evolution_tasks()

        assert len(fired) == 1
        assert fired[0].status == TaskStatus.CANCELLED

    async def test_stream_closed_early_fires_cancelled(self):
        """Consumer aclose() before final_answer fires CANCELLED status."""
        agent = _make_agent()
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve

        async def blocking_stream(task):
            yield ReActEvent(event_type="thinking", step=0, data={"content": "thinking..."})
            await asyncio.sleep(60)
            yield _final_answer_event("late")

        agent.handle_task_stream = blocking_stream

        gen = agent.execute_stream(_make_task())
        try:
            first = await gen.__anext__()
            assert first.event_type == "thinking"
        finally:
            # Close in ``finally`` so a failed assertion above does not leave
            # the generator suspended; on success this is the close under test.
            await gen.aclose()
        await drain_pending_evolution_tasks()

        assert len(fired) == 1
        assert fired[0].status == TaskStatus.CANCELLED
        assert "stream closed before completion" in (fired[0].error_message or "")

    async def test_evolution_error_does_not_propagate(self):
        """Evolution task error is swallowed — stream completes normally."""
        agent = _make_agent()
        fired: list[TaskResult] = []

        async def failing_evolve(task, result, memory_store=None):
            # Record the call before raising so the test can prove the hook
            # actually ran; otherwise a broken wiring would pass vacuously.
            fired.append(result)
            raise RuntimeError("evolution exploded")

        agent.evolve_after_task = failing_evolve

        async def good_stream(task):
            yield _final_answer_event("ok")

        agent.handle_task_stream = good_stream

        events = []
        async for event in agent.execute_stream(_make_task()):
            events.append(event)
        # drain must not raise despite evolution error
        await drain_pending_evolution_tasks()

        assert len(events) == 1
        assert events[0].data.get("output") == "ok"
        assert len(fired) == 1
        assert fired[0].status == TaskStatus.COMPLETED

    async def test_backpressure_cap_drops(self):
        """When pending evolution tasks hit cap, excess is dropped + counted."""
        import agentkit.core.config_driven as cd

        agent = _make_agent(max_concurrency=1)  # cap = max(2, 1*2) = 2
        block = asyncio.Event()

        async def stuck_evolve(task, result, memory_store=None):
            await block.wait()

        agent.evolve_after_task = stuck_evolve

        async def good_stream(task):
            yield _final_answer_event("ok")

        agent.handle_task_stream = good_stream

        try:
            # Fire 3 streams — first 2 fill the cap (stuck), 3rd is dropped
            for i in range(3):
                async for _ in agent.execute_stream(_make_task(task_id=f"bp-{i}")):
                    pass
            await asyncio.sleep(0)  # yield to let evolution tasks start
            assert cd._evolution_dropped_count == 1
        finally:
            # Always release the stuck hooks; if an assertion above fails,
            # the fixture's teardown drain would otherwise wait on them forever.
            block.set()
        await drain_pending_evolution_tasks()


# ── Parity & disabled ────────────────────────────────────
class TestExecuteStreamParity:
    """REST/stream hook parity and the evolution-disabled no-op."""

    async def test_parity_rest_vs_stream(self):
        """Both REST on_task_complete and execute_stream fire COMPLETED evolve."""
        agent = _make_agent()
        stream_fired: list[TaskResult] = []
        rest_fired: list[TaskResult] = []

        async def good_stream(task):
            yield _final_answer_event("hello")

        agent.handle_task_stream = good_stream

        async def stream_evolve(task, result, memory_store=None):
            stream_fired.append(result)

        agent.evolve_after_task = stream_evolve
        async for _ in agent.execute_stream(_make_task(task_id="stream-1")):
            pass
        await drain_pending_evolution_tasks()

        async def rest_evolve(task, result, memory_store=None):
            rest_fired.append(result)

        agent.evolve_after_task = rest_evolve
        await agent.on_task_complete(_make_task(task_id="rest-1"), {"content": "hello"})
        # Drain here as well, treating both paths the same. If
        # on_task_complete schedules evolution in the background, this avoids
        # a race; if it runs the hook inline, this is a no-op.
        await drain_pending_evolution_tasks()

        assert len(stream_fired) == 1
        assert stream_fired[0].status == TaskStatus.COMPLETED
        assert len(rest_fired) == 1
        assert rest_fired[0].status == TaskStatus.COMPLETED

    async def test_evolution_disabled_no_hooks(self):
        """When _evolution_enabled is False, no hooks fire."""
        agent = _make_agent()
        agent._evolution_enabled = False
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve

        async def good_stream(task):
            yield _final_answer_event("hello")

        agent.handle_task_stream = good_stream

        events = []
        async for event in agent.execute_stream(_make_task()):
            events.append(event)
        await drain_pending_evolution_tasks()

        # The stream itself must still run; only the hooks are skipped. This
        # keeps the "no hooks" assertion from passing on an empty stream.
        assert len(events) == 1
        assert len(fired) == 0