feat(core): wire evolution hooks into execute_stream path (U2, OQ6 fix)
ConfigDrivenAgent.execute_stream() now fires on_task_complete/on_task_failed evolution hooks in its finally block, achieving lifecycle parity with the sync execute() path. This fixes the OQ6 gap where WebSocket-routed streaming tasks bypassed evolution entirely. Implementation: - Module-level backpressure manager (_schedule_evolution / drain_pending_evolution_tasks) with cap = max(2, max_concurrency * 2), drop + log + counter on exceed, and shutdown drain via asyncio.gather(return_exceptions=True). - _trigger_evolution_hooks / _evolve_safe methods on ConfigDrivenAgent: fire-and-forget via asyncio.create_task, evolution errors swallowed (never fail the stream). - execute_stream finally block distinguishes cancelled (CancelledError / TaskCancelledError -> CANCELLED), failed (Exception -> FAILED), completed (final_answer received -> COMPLETED), and early-close (no completion, no error -> CANCELLED "stream closed before completion"). - app.py shutdown drains pending evolution tasks. - plan_exec_engine.py / reflexion.py: doc comments noting hooks fire at the ConfigDrivenAgent layer (single chokepoint, no double-fire). - portal.py: verification comments at 3 execute_stream call sites (these call react_engine.execute_stream directly, bypassing ConfigDrivenAgent - known gap tracked separately). Tests (8 new in test_execute_stream_hooks.py): - Happy path: success fires COMPLETED, failure fires FAILED. - Edge cases: cancellation fires CANCELLED, early aclose fires CANCELLED, evolution error suppressed, backpressure cap drops + counts. - Parity: REST on_task_complete vs execute_stream both fire COMPLETED. - Disabled: _evolution_enabled=False fires no hooks.
This commit is contained in:
parent
2932ee51ed
commit
dd259153fa
|
|
@ -7,17 +7,25 @@
|
|||
- 新增 Agent 从写 150 行代码降为 10-20 行配置
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator, Awaitable
|
||||
from typing import Callable, Coroutine
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import yaml
|
||||
|
||||
from agentkit.core.base import BaseAgent
|
||||
from agentkit.core.exceptions import ConfigValidationError
|
||||
from agentkit.core.protocol import AgentCapability, CancellationToken, TaskMessage
|
||||
from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError
|
||||
from agentkit.core.protocol import (
|
||||
AgentCapability,
|
||||
CancellationToken,
|
||||
TaskMessage,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
)
|
||||
from agentkit.core.react import ReActEvent
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||||
from agentkit.evolution.reflector import Reflector
|
||||
|
|
@ -28,6 +36,37 @@ from agentkit.tools.registry import ToolRegistry
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Evolution hook backpressure for execute_stream(): fire-and-forget with a cap
|
||||
# and shutdown drain. ponytail: module-level set means the cap is global across
|
||||
# agents, not per-agent; upgrade path is a per-agent semaphore if fairness matters.
|
||||
_pending_evolution_tasks: set[asyncio.Task[None]] = set()
|
||||
_evolution_dropped_count: int = 0
|
||||
|
||||
|
||||
def _schedule_evolution(coro: Coroutine[Any, Any, None], cap: int) -> None:
|
||||
"""Schedule a fire-and-forget evolution task with backpressure.
|
||||
|
||||
Drops + logs + increments the dropped counter when pending tasks reach ``cap``,
|
||||
mirroring the portal webhook backpressure pattern (``max_concurrent * 2``).
|
||||
"""
|
||||
global _evolution_dropped_count
|
||||
if len(_pending_evolution_tasks) >= cap:
|
||||
_evolution_dropped_count += 1
|
||||
logger.warning("Evolution backpressure cap reached (%d pending), dropping task", cap)
|
||||
coro.close() # avoid 'coroutine never awaited' RuntimeWarning
|
||||
return
|
||||
task = asyncio.create_task(coro)
|
||||
_pending_evolution_tasks.add(task)
|
||||
task.add_done_callback(_pending_evolution_tasks.discard)
|
||||
|
||||
|
||||
async def drain_pending_evolution_tasks() -> None:
|
||||
"""Drain pending fire-and-forget evolution tasks on app shutdown."""
|
||||
if not _pending_evolution_tasks:
|
||||
return
|
||||
logger.info("Draining %d pending evolution tasks", len(_pending_evolution_tasks))
|
||||
await asyncio.gather(*_pending_evolution_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
class AgentConfig:
|
||||
"""Agent 配置模型,从 YAML 或 Dict 构建"""
|
||||
|
|
@ -510,6 +549,26 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
except Exception as e:
|
||||
logger.warning(f"Evolution after task failure failed: {e}")
|
||||
|
||||
def _trigger_evolution_hooks(self, task: TaskMessage, result: TaskResult) -> None:
|
||||
"""Schedule evolution after a streaming task (fire-and-forget, backpressure-capped).
|
||||
|
||||
Mirrors the sync on_task_complete/on_task_failed path but non-blocking so
|
||||
streaming latency is unaffected. Evolution errors are swallowed inside
|
||||
_evolve_safe and must never fail the stream. KTD-4: lifecycle parity with
|
||||
execute() for the streaming path.
|
||||
"""
|
||||
if not self._evolution_enabled:
|
||||
return
|
||||
cap = max(2, self._config.max_concurrency * 2)
|
||||
_schedule_evolution(self._evolve_safe(task, result), cap=cap)
|
||||
|
||||
async def _evolve_safe(self, task: TaskMessage, result: TaskResult) -> None:
|
||||
"""Run evolve_after_task, swallowing errors (evolution must not fail stream)."""
|
||||
try:
|
||||
await self.evolve_after_task(task, result)
|
||||
except Exception:
|
||||
logger.warning("Evolution after stream task failed", exc_info=True)
|
||||
|
||||
def _bind_tools(self) -> None:
|
||||
"""根据配置绑定工具"""
|
||||
for tool_name in self._config.tools:
|
||||
|
|
@ -658,9 +717,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
# ── 流式执行(U3) ────────────────────────────────────────
|
||||
|
||||
def _build_llm_messages(
|
||||
self, task: TaskMessage
|
||||
) -> tuple[str | None, list[dict[str, str]]]:
|
||||
def _build_llm_messages(self, task: TaskMessage) -> tuple[str | None, list[dict[str, str]]]:
|
||||
"""Build (system_prompt, user_messages) from task + prompt template.
|
||||
|
||||
Shared by all _handle_*_stream methods to avoid duplicating the
|
||||
|
|
@ -691,16 +748,78 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
P2 fix: 注册 CancellationToken 到 _active_tokens,使 cancel_task() 能
|
||||
协作式取消流式任务。原实现绕过 BaseAgent.execute(),未注册 token。
|
||||
|
||||
KTD-4: 在 finally 中触发 on_task_complete/on_task_failed 进化钩子,
|
||||
与 execute() 保持生命周期对等。使用 fire-and-forget + 背压上限,
|
||||
进化错误不得阻塞流式返回。PlanExec/Reflexion 等子引擎的异常会向上
|
||||
传播到此处 finally,因此钩子集中在此触发,子引擎无需重复触发。
|
||||
"""
|
||||
token = CancellationToken()
|
||||
self._active_tokens[task.task_id] = token
|
||||
_stream_output: dict = {}
|
||||
_stream_error: BaseException | None = None
|
||||
_stream_completed = False
|
||||
try:
|
||||
await self._register_mcp_tools()
|
||||
async for event in self.handle_task_stream(task):
|
||||
if event.event_type == "final_answer":
|
||||
_raw = event.data.get("output", "")
|
||||
_stream_output = {"content": _raw} if isinstance(_raw, str) else _raw
|
||||
yield event
|
||||
_stream_completed = True
|
||||
except asyncio.CancelledError as ce:
|
||||
# Cancellation must propagate, but hooks still fire (U2 edge case).
|
||||
_stream_error = ce
|
||||
raise
|
||||
except Exception as e:
|
||||
_stream_error = e
|
||||
raise
|
||||
finally:
|
||||
# async generator 的 finally 在 generator 关闭时执行(GC/aclose/正常结束)
|
||||
self._active_tokens.pop(task.task_id, None)
|
||||
# KTD-4: lifecycle parity — fire evolution hooks fire-and-forget.
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
if _stream_error is not None:
|
||||
if isinstance(_stream_error, (asyncio.CancelledError, TaskCancelledError)):
|
||||
status = TaskStatus.CANCELLED
|
||||
err_msg = f"stream cancelled: {_stream_error}"
|
||||
else:
|
||||
status = TaskStatus.FAILED
|
||||
err_msg = str(_stream_error)
|
||||
result = TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=status,
|
||||
output_data=None,
|
||||
error_message=err_msg,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
elif _stream_completed:
|
||||
result = TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=_stream_output,
|
||||
error_message=None,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
else:
|
||||
# Stream closed before completion (consumer aclose / GC).
|
||||
result = TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.CANCELLED,
|
||||
output_data=None,
|
||||
error_message="stream closed before completion",
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
self._trigger_evolution_hooks(task, result)
|
||||
except Exception:
|
||||
logger.debug("evolution hook scheduling failed", exc_info=True)
|
||||
|
||||
async def handle_task_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
|
||||
"""根据 execution_mode / task_mode 流式分派,镜像 handle_task()。"""
|
||||
|
|
|
|||
|
|
@ -126,7 +126,9 @@ class PlanExecEngine:
|
|||
3. Replanner Phase: 失败时重规划
|
||||
"""
|
||||
self._confirmation_handler = confirmation_handler
|
||||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
effective_timeout = (
|
||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
)
|
||||
|
||||
try:
|
||||
if effective_timeout > 0:
|
||||
|
|
@ -198,6 +200,11 @@ class PlanExecEngine:
|
|||
- "step_completed": 步骤执行完成
|
||||
- "replanning": 触发重规划
|
||||
- "final_answer": 最终结果
|
||||
|
||||
U2: 进化钩子(on_task_complete/on_task_failed)由外层
|
||||
ConfigDrivenAgent.execute_stream() 的 finally 集中触发——本引擎仅向上
|
||||
传播异常与 final_answer 事件,不重复触发钩子(避免双重进化)。
|
||||
ponytail: 引擎无 evolution 上下文,钩子上移至 agent 层是单触发点。
|
||||
"""
|
||||
self._confirmation_handler = confirmation_handler
|
||||
# Memory retrieval
|
||||
|
|
@ -207,7 +214,9 @@ class PlanExecEngine:
|
|||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query, top_k=top_k, token_budget=token_budget,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
|
|
@ -258,12 +267,14 @@ class PlanExecEngine:
|
|||
},
|
||||
)
|
||||
|
||||
state.trajectory.append(ReActStep(
|
||||
state.trajectory.append(
|
||||
ReActStep(
|
||||
step=state.step_counter,
|
||||
action="plan_generated",
|
||||
content=f"Generated plan with {len(plan.steps)} steps",
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Persist plan as Spec if spec_manager is provided
|
||||
if self._spec_manager is not None:
|
||||
|
|
@ -335,18 +346,24 @@ class PlanExecEngine:
|
|||
},
|
||||
)
|
||||
|
||||
state.trajectory.append(ReActStep(
|
||||
state.trajectory.append(
|
||||
ReActStep(
|
||||
step=state.step_counter,
|
||||
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
|
||||
action="step_completed"
|
||||
if step_result.status == PlanStepStatus.COMPLETED
|
||||
else "step_failed",
|
||||
tool_name=step_name,
|
||||
result=step_result.result,
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=state.step_counter,
|
||||
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
|
||||
action="step_completed"
|
||||
if step_result.status == PlanStepStatus.COMPLETED
|
||||
else "step_failed",
|
||||
tool_name=step_name,
|
||||
output_data=step_result.result,
|
||||
error=step_result.error,
|
||||
|
|
@ -372,19 +389,27 @@ class PlanExecEngine:
|
|||
)
|
||||
|
||||
pipeline = self._plan_to_pipeline(current_plan, agent_name)
|
||||
pipeline_result = self._plan_result_to_pipeline_result(current_plan, plan_result)
|
||||
pipeline_result = self._plan_result_to_pipeline_result(
|
||||
current_plan, plan_result
|
||||
)
|
||||
|
||||
reflection_report = await self._reflector.reflect(pipeline, pipeline_result, replan_count)
|
||||
revised_pipeline = await self._replanner.replan(pipeline, pipeline_result, reflection_report)
|
||||
reflection_report = await self._reflector.reflect(
|
||||
pipeline, pipeline_result, replan_count
|
||||
)
|
||||
revised_pipeline = await self._replanner.replan(
|
||||
pipeline, pipeline_result, reflection_report
|
||||
)
|
||||
current_plan = self._pipeline_to_plan(revised_pipeline, plan.goal)
|
||||
self._merge_completed_results(current_plan, plan_result)
|
||||
|
||||
state.trajectory.append(ReActStep(
|
||||
state.trajectory.append(
|
||||
ReActStep(
|
||||
step=state.step_counter,
|
||||
action="replanning",
|
||||
content=f"Replanned (attempt {replan_count}): {reflection_report.root_cause}",
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
|
|
@ -404,12 +429,14 @@ class PlanExecEngine:
|
|||
|
||||
# 最终步骤
|
||||
state.step_counter += 1
|
||||
state.trajectory.append(ReActStep(
|
||||
state.trajectory.append(
|
||||
ReActStep(
|
||||
step=state.step_counter,
|
||||
action="final_answer",
|
||||
content=output,
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
|
|
@ -470,7 +497,9 @@ class PlanExecEngine:
|
|||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query, top_k=top_k, token_budget=token_budget,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
|
|
@ -509,20 +538,33 @@ class PlanExecEngine:
|
|||
# Emit plan_generated event
|
||||
if self._step_event_callback:
|
||||
try:
|
||||
await self._step_event_callback("plan_generated", {
|
||||
await self._step_event_callback(
|
||||
"plan_generated",
|
||||
{
|
||||
"plan_id": plan.plan_id,
|
||||
"goal": plan.goal,
|
||||
"steps": [s.to_dict() for s in plan.steps],
|
||||
})
|
||||
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
|
||||
},
|
||||
)
|
||||
except (
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
KeyError,
|
||||
AttributeError,
|
||||
ConnectionError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
logger.warning(f"Step event callback failed: {e}")
|
||||
|
||||
trajectory.append(ReActStep(
|
||||
trajectory.append(
|
||||
ReActStep(
|
||||
step=1,
|
||||
action="plan_generated",
|
||||
content=f"Generated plan with {len(plan.steps)} steps",
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Persist plan as Spec if spec_manager is provided
|
||||
if self._spec_manager is not None:
|
||||
|
|
@ -530,12 +572,23 @@ class PlanExecEngine:
|
|||
self._spec_manager.create(spec)
|
||||
if self._step_event_callback:
|
||||
try:
|
||||
await self._step_event_callback("spec_created", {
|
||||
await self._step_event_callback(
|
||||
"spec_created",
|
||||
{
|
||||
"spec_id": spec.spec_id,
|
||||
"goal": spec.goal,
|
||||
"num_steps": len(spec.steps),
|
||||
})
|
||||
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
|
||||
},
|
||||
)
|
||||
except (
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
KeyError,
|
||||
AttributeError,
|
||||
ConnectionError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
logger.warning(f"Step event callback failed: {e}")
|
||||
|
||||
if trace_recorder is not None:
|
||||
|
|
@ -572,12 +625,14 @@ class PlanExecEngine:
|
|||
else:
|
||||
trace_outcome = "success"
|
||||
|
||||
trajectory.append(ReActStep(
|
||||
trajectory.append(
|
||||
ReActStep(
|
||||
step=len(trajectory) + 1,
|
||||
action="final_answer",
|
||||
content=output,
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
return ReActResult(
|
||||
output=output,
|
||||
|
|
@ -666,32 +721,53 @@ class PlanExecEngine:
|
|||
for sid, step_result in plan_result.step_results.items():
|
||||
plan_step = current_plan.get_step(sid)
|
||||
step_name = plan_step.name if plan_step else sid
|
||||
trajectory.append(ReActStep(
|
||||
trajectory.append(
|
||||
ReActStep(
|
||||
step=len(trajectory) + 1,
|
||||
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
|
||||
action="step_completed"
|
||||
if step_result.status == PlanStepStatus.COMPLETED
|
||||
else "step_failed",
|
||||
tool_name=step_name,
|
||||
result=step_result.result,
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Emit step event callback
|
||||
if self._step_event_callback:
|
||||
event_type = "step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed"
|
||||
event_type = (
|
||||
"step_completed"
|
||||
if step_result.status == PlanStepStatus.COMPLETED
|
||||
else "step_failed"
|
||||
)
|
||||
try:
|
||||
await self._step_event_callback(event_type, {
|
||||
await self._step_event_callback(
|
||||
event_type,
|
||||
{
|
||||
"step_id": sid,
|
||||
"step_name": step_name,
|
||||
"status": step_result.status.value,
|
||||
"result": step_result.result,
|
||||
"error": step_result.error,
|
||||
})
|
||||
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
|
||||
},
|
||||
)
|
||||
except (
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
KeyError,
|
||||
AttributeError,
|
||||
ConnectionError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
logger.warning(f"Step event callback failed: {e}")
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=len(trajectory),
|
||||
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
|
||||
action="step_completed"
|
||||
if step_result.status == PlanStepStatus.COMPLETED
|
||||
else "step_failed",
|
||||
tool_name=step_name,
|
||||
output_data=step_result.result,
|
||||
error=step_result.error,
|
||||
|
|
@ -714,10 +790,14 @@ class PlanExecEngine:
|
|||
pipeline_result = self._plan_result_to_pipeline_result(current_plan, plan_result)
|
||||
|
||||
# 反思
|
||||
reflection_report = await self._reflector.reflect(pipeline, pipeline_result, replan_count)
|
||||
reflection_report = await self._reflector.reflect(
|
||||
pipeline, pipeline_result, replan_count
|
||||
)
|
||||
|
||||
# 重规划
|
||||
revised_pipeline = await self._replanner.replan(pipeline, pipeline_result, reflection_report)
|
||||
revised_pipeline = await self._replanner.replan(
|
||||
pipeline, pipeline_result, reflection_report
|
||||
)
|
||||
|
||||
# 将修正后的 Pipeline 转回 ExecutionPlan
|
||||
current_plan = self._pipeline_to_plan(revised_pipeline, plan.goal)
|
||||
|
|
@ -728,20 +808,33 @@ class PlanExecEngine:
|
|||
# Emit replanning event
|
||||
if self._step_event_callback:
|
||||
try:
|
||||
await self._step_event_callback("replanning", {
|
||||
await self._step_event_callback(
|
||||
"replanning",
|
||||
{
|
||||
"replan_count": replan_count,
|
||||
"root_cause": reflection_report.root_cause,
|
||||
"new_plan_id": current_plan.plan_id,
|
||||
})
|
||||
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
|
||||
},
|
||||
)
|
||||
except (
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
KeyError,
|
||||
AttributeError,
|
||||
ConnectionError,
|
||||
asyncio.TimeoutError,
|
||||
) as e:
|
||||
logger.warning(f"Step event callback failed: {e}")
|
||||
|
||||
trajectory.append(ReActStep(
|
||||
trajectory.append(
|
||||
ReActStep(
|
||||
step=len(trajectory) + 1,
|
||||
action="replanning",
|
||||
content=f"Replanned (attempt {replan_count}): {reflection_report.root_cause}",
|
||||
tokens=0,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
|
|
@ -850,13 +943,15 @@ class PlanExecEngine:
|
|||
|
||||
stages = []
|
||||
for step in plan.steps:
|
||||
stages.append(PipelineStage(
|
||||
stages.append(
|
||||
PipelineStage(
|
||||
name=step.step_id,
|
||||
agent=agent_name,
|
||||
action=step.description,
|
||||
depends_on=step.dependencies,
|
||||
inputs=step.input_data,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
return Pipeline(
|
||||
name=f"plan_{plan.plan_id}",
|
||||
|
|
@ -904,14 +999,16 @@ class PlanExecEngine:
|
|||
"""将修正后的 Pipeline 转回 ExecutionPlan"""
|
||||
steps = []
|
||||
for stage in pipeline.stages:
|
||||
steps.append(PlanStep(
|
||||
steps.append(
|
||||
PlanStep(
|
||||
step_id=stage.name,
|
||||
name=stage.name,
|
||||
description=stage.action,
|
||||
dependencies=stage.depends_on,
|
||||
input_data=stage.inputs,
|
||||
required_skills=[],
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
plan = ExecutionPlan(
|
||||
goal=goal,
|
||||
|
|
@ -944,10 +1041,12 @@ class PlanExecEngine:
|
|||
for step in plan.steps:
|
||||
sr = plan_result.step_results.get(step.step_id)
|
||||
if sr and sr.status == PlanStepStatus.COMPLETED and sr.result:
|
||||
completed_results.append({
|
||||
completed_results.append(
|
||||
{
|
||||
"step": step.name,
|
||||
"result": sr.result,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
if not completed_results:
|
||||
# 没有成功步骤
|
||||
|
|
@ -964,7 +1063,11 @@ class PlanExecEngine:
|
|||
# 简单聚合:将所有成功步骤结果格式化
|
||||
parts = []
|
||||
for item in completed_results:
|
||||
result_str = json.dumps(item["result"], ensure_ascii=False) if isinstance(item["result"], dict) else str(item["result"])
|
||||
result_str = (
|
||||
json.dumps(item["result"], ensure_ascii=False)
|
||||
if isinstance(item["result"], dict)
|
||||
else str(item["result"])
|
||||
)
|
||||
parts.append(f"**{item['step']}**: {result_str}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
|
@ -1068,7 +1171,9 @@ class _ReActStepAgent:
|
|||
dep_results = input_data.get("dependency_results", {})
|
||||
|
||||
# 构建步骤 prompt
|
||||
prompt_parts = [f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"]
|
||||
prompt_parts = [
|
||||
f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"
|
||||
]
|
||||
if dep_results:
|
||||
prompt_parts.append(
|
||||
f"\nResults from previous steps:\n{json.dumps(dep_results, ensure_ascii=False, indent=2)}"
|
||||
|
|
|
|||
|
|
@ -78,7 +78,9 @@ class ReflexionEngine:
|
|||
if max_reflections < 1:
|
||||
raise ValueError(f"max_reflections must be >= 1, got {max_reflections}")
|
||||
if not 0.0 <= quality_threshold <= 1.0:
|
||||
raise ValueError(f"quality_threshold must be between 0.0 and 1.0, got {quality_threshold}")
|
||||
raise ValueError(
|
||||
f"quality_threshold must be between 0.0 and 1.0, got {quality_threshold}"
|
||||
)
|
||||
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_steps = max_steps
|
||||
|
|
@ -116,7 +118,9 @@ class ReflexionEngine:
|
|||
reflect_model: 用于生成反思的模型,默认与 evaluate_model 相同
|
||||
其余参数与 ReActEngine.execute() 相同
|
||||
"""
|
||||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
effective_timeout = (
|
||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
)
|
||||
act_model = model
|
||||
effective_evaluate_model = evaluate_model or act_model
|
||||
effective_reflect_model = reflect_model or effective_evaluate_model
|
||||
|
|
@ -187,7 +191,9 @@ class ReflexionEngine:
|
|||
reflect_model: str = "default",
|
||||
) -> ReflexionResult:
|
||||
# Telemetry
|
||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "reflexion"})
|
||||
agent_request_counter().add(
|
||||
1, {"agent.name": agent_name, "agent.type": task_type or "reflexion"}
|
||||
)
|
||||
|
||||
_span_cm = None
|
||||
_span = None
|
||||
|
|
@ -348,6 +354,11 @@ class ReflexionEngine:
|
|||
"""执行 Reflexion 循环,以流式事件形式返回
|
||||
|
||||
在每次 ReAct 执行、评估、反思和重试时发出事件。
|
||||
|
||||
U2: 进化钩子(on_task_complete/on_task_failed)由外层
|
||||
ConfigDrivenAgent.execute_stream() 的 finally 集中触发——本引擎仅向上
|
||||
传播异常与 final_answer 事件,不重复触发钩子(避免双重进化)。
|
||||
ponytail: 引擎无 evolution 上下文,钩子上移至 agent 层是单触发点。
|
||||
"""
|
||||
act_model = model
|
||||
effective_evaluate_model = evaluate_model or act_model
|
||||
|
|
@ -600,9 +611,7 @@ class ReflexionEngine:
|
|||
def _parse_evaluation_score(self, content: str) -> float:
|
||||
"""从 LLM 响应中解析评估分数"""
|
||||
# 尝试从代码块中提取 JSON
|
||||
json_match = re.search(
|
||||
r"```(?:json)?\s*\n?(.*?)\n?```", content, re.DOTALL
|
||||
)
|
||||
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group(1))
|
||||
|
|
|
|||
|
|
@ -81,7 +81,14 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
|||
backend=config.usage_store.get("backend", "memory"),
|
||||
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
|
||||
)
|
||||
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
|
||||
except (
|
||||
ConnectionError,
|
||||
OSError,
|
||||
asyncio.TimeoutError,
|
||||
ValueError,
|
||||
KeyError,
|
||||
RuntimeError,
|
||||
) as e:
|
||||
logger.warning(f"Failed to initialize usage store: {e}, using in-memory")
|
||||
|
||||
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
||||
|
|
@ -478,7 +485,14 @@ async def lifespan(app: FastAPI):
|
|||
_row = await _cur.fetchone()
|
||||
if _row is not None:
|
||||
default_cal_user_id = str(_row["id"])
|
||||
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||
except (
|
||||
ConnectionError,
|
||||
OSError,
|
||||
asyncio.TimeoutError,
|
||||
ValueError,
|
||||
KeyError,
|
||||
RuntimeError,
|
||||
):
|
||||
logger.debug("Could not resolve default user_id for CalendarTool", exc_info=True)
|
||||
|
||||
calendar_tool = CalendarTool(
|
||||
|
|
@ -505,7 +519,9 @@ async def lifespan(app: FastAPI):
|
|||
except (ValueError, KeyError, RuntimeError, AttributeError):
|
||||
# ponytail: log at debug — CalendarTool double-registration
|
||||
# is expected on reload, but silent pass hides real errors.
|
||||
logger.debug("CalendarTool already registered or registration failed", exc_info=True)
|
||||
logger.debug(
|
||||
"CalendarTool already registered or registration failed", exc_info=True
|
||||
)
|
||||
# Strip any existing "## 可用工具" section to avoid
|
||||
# duplicate tool blocks in the system prompt.
|
||||
base_prompt = getattr(default_agent, "_system_prompt", None) or (
|
||||
|
|
@ -570,7 +586,14 @@ async def lifespan(app: FastAPI):
|
|||
from agentkit.rag_platform.store import ensure_tables
|
||||
|
||||
await ensure_tables(rag_database_url)
|
||||
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||
except (
|
||||
ConnectionError,
|
||||
OSError,
|
||||
asyncio.TimeoutError,
|
||||
ValueError,
|
||||
KeyError,
|
||||
RuntimeError,
|
||||
):
|
||||
logger.exception("Failed to ensure rag_platform tables")
|
||||
|
||||
# KBStore — KB/Document persistence
|
||||
|
|
@ -693,6 +716,14 @@ async def lifespan(app: FastAPI):
|
|||
except (RuntimeError, asyncio.TimeoutError, ConnectionError, OSError):
|
||||
logger.debug("close_all_adapters 异常已忽略")
|
||||
|
||||
# U2: drain pending fire-and-forget evolution tasks from execute_stream()
|
||||
try:
|
||||
from agentkit.core.config_driven import drain_pending_evolution_tasks
|
||||
|
||||
await drain_pending_evolution_tasks()
|
||||
except Exception:
|
||||
logger.debug("drain_pending_evolution_tasks 异常已忽略", exc_info=True)
|
||||
|
||||
|
||||
def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
||||
"""Handle config change by reloading affected components.
|
||||
|
|
@ -736,7 +767,14 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
|||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._llm_gateway = new_gateway
|
||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||
except (ValueError, TypeError, KeyError, RuntimeError, ConnectionError, OSError) as e:
|
||||
except (
|
||||
ValueError,
|
||||
TypeError,
|
||||
KeyError,
|
||||
RuntimeError,
|
||||
ConnectionError,
|
||||
OSError,
|
||||
) as e:
|
||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||
|
||||
# Reload skills if skill paths changed
|
||||
|
|
@ -1185,7 +1223,15 @@ def create_app(
|
|||
try:
|
||||
epi_session_factory = create_episodic_session_factory(database_url)
|
||||
epi_model = EpisodeModel
|
||||
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError, ImportError) as db_err:
|
||||
except (
|
||||
ConnectionError,
|
||||
OSError,
|
||||
asyncio.TimeoutError,
|
||||
ValueError,
|
||||
KeyError,
|
||||
RuntimeError,
|
||||
ImportError,
|
||||
) as db_err:
|
||||
import logging as _log
|
||||
|
||||
_log.getLogger(__name__).warning(
|
||||
|
|
|
|||
|
|
@ -577,6 +577,9 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
|||
|
||||
collected_output: list[str] = []
|
||||
try:
|
||||
# U2 verify: calls react_engine.execute_stream directly, bypassing
|
||||
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
|
||||
# here. Routing through agent.execute_stream is tracked separately.
|
||||
async for event in react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
|
|
@ -698,6 +701,9 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
|||
|
||||
collected_output: list[str] = []
|
||||
try:
|
||||
# U2 verify: calls react_engine.execute_stream directly, bypassing
|
||||
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
|
||||
# here. Routing through agent.execute_stream is tracked separately.
|
||||
async for event in react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
|
|
@ -812,9 +818,7 @@ async def _conversation_has_board_started(conversation_id: str) -> bool:
|
|||
list endpoint.
|
||||
"""
|
||||
try:
|
||||
return await _conversation_store.has_message_with_type(
|
||||
conversation_id, "board_started"
|
||||
)
|
||||
return await _conversation_store.has_message_with_type(conversation_id, "board_started")
|
||||
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||
logger.warning("is_board lookup failed for %s", conversation_id, exc_info=True)
|
||||
return False
|
||||
|
|
@ -881,10 +885,7 @@ async def get_conversation(
|
|||
"messages": [_hydrate_persisted_message(conv.id, i, m) for i, m in enumerate(history)],
|
||||
"created_at": conv.created_at.isoformat(),
|
||||
"updated_at": conv.updated_at.isoformat(),
|
||||
"is_board": any(
|
||||
(m.metadata or {}).get("message_type") == "board_started"
|
||||
for m in history
|
||||
),
|
||||
"is_board": any((m.metadata or {}).get("message_type") == "board_started" for m in history),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -998,6 +999,9 @@ async def _execute_react_background(
|
|||
):
|
||||
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
|
||||
|
||||
# U2 verify: calls react_engine.execute_stream directly, bypassing
|
||||
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
|
||||
# here. Routing through agent.execute_stream is tracked separately.
|
||||
async for event in react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,320 @@
|
|||
"""U2 tests: execute_stream evolution hook wiring (OQ6 fix).
|
||||
|
||||
Verifies that ConfigDrivenAgent.execute_stream() fires evolution hooks
|
||||
(on_task_complete / on_task_failed) in its finally block with lifecycle
|
||||
parity to the sync execute() path. Covers happy path, failure, cancellation,
|
||||
early close, evolution-error suppression, backpressure cap, REST/stream
|
||||
parity, and evolution-disabled no-op.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.config_driven import (
|
||||
AgentConfig,
|
||||
ConfigDrivenAgent,
|
||||
drain_pending_evolution_tasks,
|
||||
)
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_task(**overrides) -> TaskMessage:
|
||||
defaults = dict(
|
||||
task_id="stream-task-001",
|
||||
agent_name="stream_agent",
|
||||
task_type="generate",
|
||||
priority=1,
|
||||
input_data={"query": "hello"},
|
||||
callback_url=None,
|
||||
created_at=None,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return TaskMessage.from_dict(defaults)
|
||||
|
||||
|
||||
def _make_agent(max_concurrency: int = 1) -> ConfigDrivenAgent:
|
||||
config = AgentConfig.from_dict(
|
||||
{
|
||||
"name": "stream_agent",
|
||||
"agent_type": "content_generation",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {
|
||||
"identity": "test agent",
|
||||
"instructions": "do the thing",
|
||||
"output_format": "text",
|
||||
},
|
||||
"max_concurrency": max_concurrency,
|
||||
}
|
||||
)
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
agent._evolution_enabled = True
|
||||
return agent
|
||||
|
||||
|
||||
def _final_answer_event(output: str = "hello") -> ReActEvent:
|
||||
return ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=0,
|
||||
data={"output": output},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def _isolate_evolution_state():
|
||||
"""Reset module-level evolution state before each test, drain after.
|
||||
|
||||
Without this, stuck tasks from a prior test would inflate the pending
|
||||
set and break backpressure assertions in later tests.
|
||||
"""
|
||||
import agentkit.core.config_driven as cd
|
||||
|
||||
for task in list(cd._pending_evolution_tasks):
|
||||
task.cancel()
|
||||
if cd._pending_evolution_tasks:
|
||||
await asyncio.gather(*cd._pending_evolution_tasks, return_exceptions=True)
|
||||
cd._pending_evolution_tasks.clear()
|
||||
cd._evolution_dropped_count = 0
|
||||
yield
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
|
||||
# ── Happy path ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExecuteStreamHooks:
|
||||
async def test_success_fires_on_task_complete(self):
|
||||
"""Stream completion fires evolve_after_task with COMPLETED status."""
|
||||
agent = _make_agent()
|
||||
fired: list[TaskResult] = []
|
||||
|
||||
async def record_evolve(task, result, memory_store=None):
|
||||
fired.append(result)
|
||||
|
||||
agent.evolve_after_task = record_evolve
|
||||
|
||||
async def good_stream(task):
|
||||
yield _final_answer_event("hello world")
|
||||
|
||||
agent.handle_task_stream = good_stream
|
||||
|
||||
events = []
|
||||
async for event in agent.execute_stream(_make_task()):
|
||||
events.append(event)
|
||||
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == "final_answer"
|
||||
assert len(fired) == 1
|
||||
assert fired[0].status == TaskStatus.COMPLETED
|
||||
assert fired[0].output_data == {"content": "hello world"}
|
||||
|
||||
async def test_failure_fires_on_task_failed(self):
|
||||
"""Stream exception fires evolve_after_task with FAILED status."""
|
||||
agent = _make_agent()
|
||||
fired: list[TaskResult] = []
|
||||
|
||||
async def record_evolve(task, result, memory_store=None):
|
||||
fired.append(result)
|
||||
|
||||
agent.evolve_after_task = record_evolve
|
||||
|
||||
async def failing_stream(task):
|
||||
yield _final_answer_event("partial") # yield once before failing
|
||||
raise RuntimeError("stream blew up")
|
||||
|
||||
agent.handle_task_stream = failing_stream
|
||||
|
||||
with pytest.raises(RuntimeError, match="stream blew up"):
|
||||
async for _ in agent.execute_stream(_make_task()):
|
||||
pass
|
||||
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(fired) == 1
|
||||
assert fired[0].status == TaskStatus.FAILED
|
||||
assert "stream blew up" in (fired[0].error_message or "")
|
||||
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExecuteStreamEdgeCases:
|
||||
async def test_cancellation_fires_cancelled_status(self):
|
||||
"""Stream cancelled mid-flight fires hooks with CANCELLED status."""
|
||||
agent = _make_agent()
|
||||
fired: list[TaskResult] = []
|
||||
|
||||
async def record_evolve(task, result, memory_store=None):
|
||||
fired.append(result)
|
||||
|
||||
agent.evolve_after_task = record_evolve
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def slow_stream(task):
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
yield _final_answer_event("never reached")
|
||||
|
||||
agent.handle_task_stream = slow_stream
|
||||
|
||||
async def consume():
|
||||
async for _ in agent.execute_stream(_make_task()):
|
||||
pass
|
||||
|
||||
consumer = asyncio.create_task(consume())
|
||||
await started.wait()
|
||||
await asyncio.sleep(0.05) # let it settle into sleep(60)
|
||||
consumer.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await consumer
|
||||
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(fired) == 1
|
||||
assert fired[0].status == TaskStatus.CANCELLED
|
||||
|
||||
async def test_stream_closed_early_fires_cancelled(self):
|
||||
"""Consumer aclose() before final_answer fires CANCELLED status."""
|
||||
agent = _make_agent()
|
||||
fired: list[TaskResult] = []
|
||||
|
||||
async def record_evolve(task, result, memory_store=None):
|
||||
fired.append(result)
|
||||
|
||||
agent.evolve_after_task = record_evolve
|
||||
|
||||
async def blocking_stream(task):
|
||||
yield ReActEvent(event_type="thinking", step=0, data={"content": "thinking..."})
|
||||
await asyncio.sleep(60)
|
||||
yield _final_answer_event("late")
|
||||
|
||||
agent.handle_task_stream = blocking_stream
|
||||
|
||||
gen = agent.execute_stream(_make_task())
|
||||
first = await gen.__anext__()
|
||||
assert first.event_type == "thinking"
|
||||
await gen.aclose()
|
||||
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(fired) == 1
|
||||
assert fired[0].status == TaskStatus.CANCELLED
|
||||
assert "stream closed before completion" in (fired[0].error_message or "")
|
||||
|
||||
async def test_evolution_error_does_not_propagate(self):
|
||||
"""Evolution task error is swallowed — stream completes normally."""
|
||||
agent = _make_agent()
|
||||
|
||||
async def failing_evolve(task, result, memory_store=None):
|
||||
raise RuntimeError("evolution exploded")
|
||||
|
||||
agent.evolve_after_task = failing_evolve
|
||||
|
||||
async def good_stream(task):
|
||||
yield _final_answer_event("ok")
|
||||
|
||||
agent.handle_task_stream = good_stream
|
||||
|
||||
events = []
|
||||
async for event in agent.execute_stream(_make_task()):
|
||||
events.append(event)
|
||||
|
||||
# drain must not raise despite evolution error
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].data.get("output") == "ok"
|
||||
|
||||
async def test_backpressure_cap_drops(self):
|
||||
"""When pending evolution tasks hit cap, excess is dropped + counted."""
|
||||
agent = _make_agent(max_concurrency=1) # cap = max(2, 1*2) = 2
|
||||
block = asyncio.Event()
|
||||
|
||||
async def stuck_evolve(task, result, memory_store=None):
|
||||
await block.wait()
|
||||
|
||||
agent.evolve_after_task = stuck_evolve
|
||||
|
||||
async def good_stream(task):
|
||||
yield _final_answer_event("ok")
|
||||
|
||||
agent.handle_task_stream = good_stream
|
||||
|
||||
import agentkit.core.config_driven as cd
|
||||
|
||||
# Fire 3 streams — first 2 fill the cap (stuck), 3rd is dropped
|
||||
for i in range(3):
|
||||
async for _ in agent.execute_stream(_make_task(task_id=f"bp-{i}")):
|
||||
pass
|
||||
await asyncio.sleep(0) # yield to let evolution tasks start
|
||||
|
||||
assert cd._evolution_dropped_count == 1
|
||||
|
||||
# Cleanup: release stuck tasks and drain
|
||||
block.set()
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
|
||||
# ── Parity & disabled ────────────────────────────────────
|
||||
|
||||
|
||||
class TestExecuteStreamParity:
|
||||
async def test_parity_rest_vs_stream(self):
|
||||
"""Both REST on_task_complete and execute_stream fire COMPLETED evolve."""
|
||||
agent = _make_agent()
|
||||
stream_fired: list[TaskResult] = []
|
||||
rest_fired: list[TaskResult] = []
|
||||
|
||||
async def good_stream(task):
|
||||
yield _final_answer_event("hello")
|
||||
|
||||
agent.handle_task_stream = good_stream
|
||||
|
||||
async def stream_evolve(task, result, memory_store=None):
|
||||
stream_fired.append(result)
|
||||
|
||||
agent.evolve_after_task = stream_evolve
|
||||
|
||||
async for _ in agent.execute_stream(_make_task(task_id="stream-1")):
|
||||
pass
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
async def rest_evolve(task, result, memory_store=None):
|
||||
rest_fired.append(result)
|
||||
|
||||
agent.evolve_after_task = rest_evolve
|
||||
await agent.on_task_complete(_make_task(task_id="rest-1"), {"content": "hello"})
|
||||
|
||||
assert len(stream_fired) == 1
|
||||
assert stream_fired[0].status == TaskStatus.COMPLETED
|
||||
assert len(rest_fired) == 1
|
||||
assert rest_fired[0].status == TaskStatus.COMPLETED
|
||||
|
||||
async def test_evolution_disabled_no_hooks(self):
|
||||
"""When _evolution_enabled is False, no hooks fire."""
|
||||
agent = _make_agent()
|
||||
agent._evolution_enabled = False
|
||||
fired: list[TaskResult] = []
|
||||
|
||||
async def record_evolve(task, result, memory_store=None):
|
||||
fired.append(result)
|
||||
|
||||
agent.evolve_after_task = record_evolve
|
||||
|
||||
async def good_stream(task):
|
||||
yield _final_answer_event("hello")
|
||||
|
||||
agent.handle_task_stream = good_stream
|
||||
|
||||
async for _ in agent.execute_stream(_make_task()):
|
||||
pass
|
||||
await drain_pending_evolution_tasks()
|
||||
|
||||
assert len(fired) == 0
|
||||
Loading…
Reference in New Issue