"""U2 tests: execute_stream evolution hook wiring (OQ6 fix).

Verifies that ConfigDrivenAgent.execute_stream() fires evolution hooks
(on_task_complete / on_task_failed) in its finally block with lifecycle
parity to the sync execute() path. Covers happy path, failure, cancellation,
early close, evolution-error suppression, backpressure cap, REST/stream
parity, and evolution-disabled no-op.
"""
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.config_driven import (
|
|
AgentConfig,
|
|
ConfigDrivenAgent,
|
|
drain_pending_evolution_tasks,
|
|
)
|
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
|
from agentkit.core.react import ReActEvent
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────
|
|
|
|
|
|
def _make_task(**overrides) -> TaskMessage:
    """Build a TaskMessage for the stream tests.

    Starts from fixed baseline fields (task id ``stream-task-001``, agent
    ``stream_agent``, a ``generate`` task with a ``{"query": "hello"}``
    payload) and lets the caller replace any of them by keyword.
    """
    base_fields = {
        "task_id": "stream-task-001",
        "agent_name": "stream_agent",
        "task_type": "generate",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": None,
    }
    # Later mapping wins, so overrides replace the baseline values.
    return TaskMessage.from_dict({**base_fields, **overrides})
|
|
|
|
|
|
def _make_agent(max_concurrency: int = 1) -> ConfigDrivenAgent:
    """Create an ``llm_generate`` ConfigDrivenAgent with evolution switched on.

    Args:
        max_concurrency: Forwarded into the agent config. The backpressure
            test presumes the pending-evolution cap is derived from it.

    Returns:
        A ConfigDrivenAgent named ``stream_agent`` whose private
        ``_evolution_enabled`` flag has been forced to True.
    """
    prompt_section = {
        "identity": "test agent",
        "instructions": "do the thing",
        "output_format": "text",
    }
    raw_config = {
        "name": "stream_agent",
        "agent_type": "content_generation",
        "task_mode": "llm_generate",
        "prompt": prompt_section,
        "max_concurrency": max_concurrency,
    }
    agent = ConfigDrivenAgent(config=AgentConfig.from_dict(raw_config))
    # Forced on so every test exercises the hook path by default; the
    # evolution-disabled test flips it back off explicitly.
    agent._evolution_enabled = True
    return agent
|
|
|
|
|
|
def _final_answer_event(output: str = "hello") -> ReActEvent:
    """Return a step-0 ``final_answer`` ReActEvent whose data carries *output*."""
    payload = {"output": output}
    return ReActEvent(event_type="final_answer", step=0, data=payload)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def _isolate_evolution_state():
    """Give every test a clean module-level evolution state.

    Setup: cancel and await any evolution tasks an earlier test left in
    ``_pending_evolution_tasks``, empty that set, and zero
    ``_evolution_dropped_count``. Stale stuck tasks would otherwise inflate
    the pending set and break backpressure assertions in later tests.

    Teardown: drain whatever evolution work the test left pending.
    """
    import agentkit.core.config_driven as cd

    leftovers = list(cd._pending_evolution_tasks)
    for leftover in leftovers:
        leftover.cancel()
    if leftovers:
        # return_exceptions=True: the CancelledErrors are expected, not failures.
        await asyncio.gather(*leftovers, return_exceptions=True)
    cd._pending_evolution_tasks.clear()
    cd._evolution_dropped_count = 0

    yield

    await drain_pending_evolution_tasks()
|
|
|
|
|
|
# ── Happy path ───────────────────────────────────────────
|
|
|
|
|
|
class TestExecuteStreamHooks:
    """execute_stream() fires evolve_after_task on both success and failure."""

    async def test_success_fires_on_task_complete(self):
        """Stream completion fires evolve_after_task with COMPLETED status."""
        agent = _make_agent()
        recorded: list[TaskResult] = []

        async def capture(task, result, memory_store=None):
            recorded.append(result)

        async def single_answer_stream(task):
            yield _final_answer_event("hello world")

        agent.evolve_after_task = capture
        agent.handle_task_stream = single_answer_stream

        events = [event async for event in agent.execute_stream(_make_task())]
        await drain_pending_evolution_tasks()

        assert [event.event_type for event in events] == ["final_answer"]
        assert len(recorded) == 1
        outcome = recorded[0]
        assert outcome.status == TaskStatus.COMPLETED
        # KTD-8: output_data includes trace_outcome for lifecycle._is_failure_path()
        assert outcome.output_data == {"content": "hello world", "trace_outcome": "success"}

    async def test_failure_fires_on_task_failed(self):
        """Stream exception fires evolve_after_task with FAILED status."""
        agent = _make_agent()
        recorded: list[TaskResult] = []

        async def capture(task, result, memory_store=None):
            recorded.append(result)

        async def exploding_stream(task):
            # Emit one event first so the failure happens mid-stream.
            yield _final_answer_event("partial")
            raise RuntimeError("stream blew up")

        agent.evolve_after_task = capture
        agent.handle_task_stream = exploding_stream

        with pytest.raises(RuntimeError, match="stream blew up"):
            async for _event in agent.execute_stream(_make_task()):
                pass
        await drain_pending_evolution_tasks()

        assert len(recorded) == 1
        outcome = recorded[0]
        assert outcome.status == TaskStatus.FAILED
        assert "stream blew up" in (outcome.error_message or "")
|
|
|
|
|
|
# ── Edge cases ───────────────────────────────────────────
|
|
|
|
|
|
class TestExecuteStreamEdgeCases:
    """Cancellation, early close, hook-error suppression, and backpressure."""

    async def test_cancellation_fires_cancelled_status(self):
        """Stream cancelled mid-flight fires hooks with CANCELLED status."""
        agent = _make_agent()
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve

        started = asyncio.Event()

        async def slow_stream(task):
            started.set()
            await asyncio.sleep(60)
            yield _final_answer_event("never reached")

        agent.handle_task_stream = slow_stream

        async def consume():
            async for _ in agent.execute_stream(_make_task()):
                pass

        consumer = asyncio.create_task(consume())
        await started.wait()
        await asyncio.sleep(0.05)  # let it settle into sleep(60)
        consumer.cancel()
        with pytest.raises(asyncio.CancelledError):
            await consumer

        await drain_pending_evolution_tasks()

        assert len(fired) == 1
        assert fired[0].status == TaskStatus.CANCELLED

    async def test_stream_closed_early_fires_cancelled(self):
        """Consumer aclose() before final_answer fires CANCELLED status."""
        agent = _make_agent()
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve

        async def blocking_stream(task):
            yield ReActEvent(event_type="thinking", step=0, data={"content": "thinking..."})
            await asyncio.sleep(60)
            yield _final_answer_event("late")

        agent.handle_task_stream = blocking_stream

        gen = agent.execute_stream(_make_task())
        try:
            first = await gen.__anext__()
            assert first.event_type == "thinking"
        finally:
            # Always close, even if the assertion above fails, so the
            # generator is not left suspended past the end of the test.
            await gen.aclose()

        await drain_pending_evolution_tasks()

        assert len(fired) == 1
        assert fired[0].status == TaskStatus.CANCELLED
        assert "stream closed before completion" in (fired[0].error_message or "")

    async def test_evolution_error_does_not_propagate(self):
        """Evolution task error is swallowed — stream completes normally."""
        agent = _make_agent()
        # Record each invocation so the test cannot pass vacuously when the
        # hook is never called at all.
        attempts: list[TaskResult] = []

        async def failing_evolve(task, result, memory_store=None):
            attempts.append(result)
            raise RuntimeError("evolution exploded")

        agent.evolve_after_task = failing_evolve

        async def good_stream(task):
            yield _final_answer_event("ok")

        agent.handle_task_stream = good_stream

        events = []
        async for event in agent.execute_stream(_make_task()):
            events.append(event)

        # drain must not raise despite evolution error
        await drain_pending_evolution_tasks()

        assert len(events) == 1
        assert events[0].data.get("output") == "ok"
        # The hook really ran (and raised); its error was suppressed.
        assert len(attempts) == 1

    async def test_backpressure_cap_drops(self):
        """When pending evolution tasks hit cap, excess is dropped + counted."""
        import agentkit.core.config_driven as cd

        agent = _make_agent(max_concurrency=1)  # cap = max(2, 1*2) = 2
        block = asyncio.Event()

        async def stuck_evolve(task, result, memory_store=None):
            await block.wait()

        agent.evolve_after_task = stuck_evolve

        async def good_stream(task):
            yield _final_answer_event("ok")

        agent.handle_task_stream = good_stream

        try:
            # Fire 3 streams — first 2 fill the cap (stuck), 3rd is dropped
            for i in range(3):
                async for _ in agent.execute_stream(_make_task(task_id=f"bp-{i}")):
                    pass
                await asyncio.sleep(0)  # yield to let evolution tasks start

            assert cd._evolution_dropped_count == 1
        finally:
            # Release the stuck hooks even when the assertion fails. Otherwise
            # they keep waiting on `block`, and draining them (here or in the
            # fixture teardown) would presumably never finish.
            block.set()
            await drain_pending_evolution_tasks()
|
|
|
|
|
|
# ── Parity & disabled ────────────────────────────────────
|
|
|
|
|
|
class TestExecuteStreamParity:
    """REST/stream hook parity and the evolution-disabled no-op."""

    async def test_parity_rest_vs_stream(self):
        """Both REST on_task_complete and execute_stream fire COMPLETED evolve."""
        agent = _make_agent()
        stream_fired: list[TaskResult] = []
        rest_fired: list[TaskResult] = []

        async def good_stream(task):
            yield _final_answer_event("hello")

        agent.handle_task_stream = good_stream

        # Phase 1: streaming path. evolve_after_task is swapped per phase so
        # each path records into its own list.
        async def stream_evolve(task, result, memory_store=None):
            stream_fired.append(result)

        agent.evolve_after_task = stream_evolve
        async for _ in agent.execute_stream(_make_task(task_id="stream-1")):
            pass
        await drain_pending_evolution_tasks()

        # Phase 2: REST path.
        async def rest_evolve(task, result, memory_store=None):
            rest_fired.append(result)

        agent.evolve_after_task = rest_evolve
        await agent.on_task_complete(_make_task(task_id="rest-1"), {"content": "hello"})
        # Drain here too, mirroring the stream phase. The assertion below then
        # holds whether on_task_complete awaits the hook directly or schedules
        # it as pending evolution work.
        await drain_pending_evolution_tasks()

        assert len(stream_fired) == 1
        assert stream_fired[0].status == TaskStatus.COMPLETED
        assert len(rest_fired) == 1
        assert rest_fired[0].status == TaskStatus.COMPLETED

    async def test_evolution_disabled_no_hooks(self):
        """When _evolution_enabled is False, no hooks fire."""
        agent = _make_agent()
        agent._evolution_enabled = False
        fired: list[TaskResult] = []

        async def record_evolve(task, result, memory_store=None):
            fired.append(result)

        agent.evolve_after_task = record_evolve

        async def good_stream(task):
            yield _final_answer_event("hello")

        agent.handle_task_stream = good_stream

        async for _ in agent.execute_stream(_make_task()):
            pass
        await drain_pending_evolution_tasks()

        assert len(fired) == 0
|