feat: complex-task-quality-loop (R1-R12) #22

Merged
fischer merged 13 commits from feat/complex-task-quality-loop into main 2026-07-05 22:31:22 +08:00
6 changed files with 736 additions and 133 deletions
Showing only changes of commit dd259153fa - Show all commits

View File

@ -7,17 +7,25 @@
- 新增 Agent 从写 150 行代码降为 10-20 行配置
"""
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator, Awaitable
from typing import Callable, Coroutine
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
import yaml
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import ConfigValidationError
from agentkit.core.protocol import AgentCapability, CancellationToken, TaskMessage
from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError
from agentkit.core.protocol import (
AgentCapability,
CancellationToken,
TaskMessage,
TaskResult,
TaskStatus,
)
from agentkit.core.react import ReActEvent
from agentkit.evolution.lifecycle import EvolutionMixin
from agentkit.evolution.reflector import Reflector
@ -28,6 +36,37 @@ from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
# Evolution hook backpressure for execute_stream(): fire-and-forget with a cap
# and shutdown drain. ponytail: module-level set means the cap is global across
# agents, not per-agent; upgrade path is a per-agent semaphore if fairness matters.
_pending_evolution_tasks: set[asyncio.Task[None]] = set()
_evolution_dropped_count: int = 0
def _schedule_evolution(coro: Coroutine[Any, Any, None], cap: int) -> None:
"""Schedule a fire-and-forget evolution task with backpressure.
Drops + logs + increments the dropped counter when pending tasks reach ``cap``,
mirroring the portal webhook backpressure pattern (``max_concurrent * 2``).
"""
global _evolution_dropped_count
if len(_pending_evolution_tasks) >= cap:
_evolution_dropped_count += 1
logger.warning("Evolution backpressure cap reached (%d pending), dropping task", cap)
coro.close() # avoid 'coroutine never awaited' RuntimeWarning
return
task = asyncio.create_task(coro)
_pending_evolution_tasks.add(task)
task.add_done_callback(_pending_evolution_tasks.discard)
async def drain_pending_evolution_tasks() -> None:
"""Drain pending fire-and-forget evolution tasks on app shutdown."""
if not _pending_evolution_tasks:
return
logger.info("Draining %d pending evolution tasks", len(_pending_evolution_tasks))
await asyncio.gather(*_pending_evolution_tasks, return_exceptions=True)
class AgentConfig:
"""Agent 配置模型,从 YAML 或 Dict 构建"""
@ -510,6 +549,26 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
except Exception as e:
logger.warning(f"Evolution after task failure failed: {e}")
def _trigger_evolution_hooks(self, task: TaskMessage, result: TaskResult) -> None:
"""Schedule evolution after a streaming task (fire-and-forget, backpressure-capped).
Mirrors the sync on_task_complete/on_task_failed path but non-blocking so
streaming latency is unaffected. Evolution errors are swallowed inside
_evolve_safe and must never fail the stream. KTD-4: lifecycle parity with
execute() for the streaming path.
"""
if not self._evolution_enabled:
return
cap = max(2, self._config.max_concurrency * 2)
_schedule_evolution(self._evolve_safe(task, result), cap=cap)
async def _evolve_safe(self, task: TaskMessage, result: TaskResult) -> None:
"""Run evolve_after_task, swallowing errors (evolution must not fail stream)."""
try:
await self.evolve_after_task(task, result)
except Exception:
logger.warning("Evolution after stream task failed", exc_info=True)
def _bind_tools(self) -> None:
"""根据配置绑定工具"""
for tool_name in self._config.tools:
@ -658,9 +717,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
# ── 流式执行U3 ────────────────────────────────────────
def _build_llm_messages(
self, task: TaskMessage
) -> tuple[str | None, list[dict[str, str]]]:
def _build_llm_messages(self, task: TaskMessage) -> tuple[str | None, list[dict[str, str]]]:
"""Build (system_prompt, user_messages) from task + prompt template.
Shared by all _handle_*_stream methods to avoid duplicating the
@ -691,16 +748,78 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
P2 fix: 注册 CancellationToken _active_tokens使 cancel_task()
协作式取消流式任务原实现绕过 BaseAgent.execute()未注册 token
KTD-4: finally 中触发 on_task_complete/on_task_failed 进化钩子
execute() 保持生命周期对等使用 fire-and-forget + 背压上限
进化错误不得阻塞流式返回PlanExec/Reflexion 等子引擎的异常会向上
传播到此处 finally因此钩子集中在此触发子引擎无需重复触发
"""
token = CancellationToken()
self._active_tokens[task.task_id] = token
_stream_output: dict = {}
_stream_error: BaseException | None = None
_stream_completed = False
try:
await self._register_mcp_tools()
async for event in self.handle_task_stream(task):
if event.event_type == "final_answer":
_raw = event.data.get("output", "")
_stream_output = {"content": _raw} if isinstance(_raw, str) else _raw
yield event
_stream_completed = True
except asyncio.CancelledError as ce:
# Cancellation must propagate, but hooks still fire (U2 edge case).
_stream_error = ce
raise
except Exception as e:
_stream_error = e
raise
finally:
# async generator 的 finally 在 generator 关闭时执行GC/aclose/正常结束)
self._active_tokens.pop(task.task_id, None)
# KTD-4: lifecycle parity — fire evolution hooks fire-and-forget.
try:
now = datetime.now(timezone.utc)
if _stream_error is not None:
if isinstance(_stream_error, (asyncio.CancelledError, TaskCancelledError)):
status = TaskStatus.CANCELLED
err_msg = f"stream cancelled: {_stream_error}"
else:
status = TaskStatus.FAILED
err_msg = str(_stream_error)
result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=status,
output_data=None,
error_message=err_msg,
started_at=now,
completed_at=now,
)
elif _stream_completed:
result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=_stream_output,
error_message=None,
started_at=now,
completed_at=now,
)
else:
# Stream closed before completion (consumer aclose / GC).
result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.CANCELLED,
output_data=None,
error_message="stream closed before completion",
started_at=now,
completed_at=now,
)
self._trigger_evolution_hooks(task, result)
except Exception:
logger.debug("evolution hook scheduling failed", exc_info=True)
async def handle_task_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
"""根据 execution_mode / task_mode 流式分派,镜像 handle_task()。"""

View File

@ -126,7 +126,9 @@ class PlanExecEngine:
3. Replanner Phase: 失败时重规划
"""
self._confirmation_handler = confirmation_handler
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
effective_timeout = (
timeout_seconds if timeout_seconds is not None else self._default_timeout
)
try:
if effective_timeout > 0:
@ -198,6 +200,11 @@ class PlanExecEngine:
- "step_completed": 步骤执行完成
- "replanning": 触发重规划
- "final_answer": 最终结果
U2: 进化钩子on_task_complete/on_task_failed由外层
ConfigDrivenAgent.execute_stream() finally 集中触发本引擎仅向上
传播异常与 final_answer 事件不重复触发钩子避免双重进化
ponytail: 引擎无 evolution 上下文钩子上移至 agent 层是单触发点
"""
self._confirmation_handler = confirmation_handler
# Memory retrieval
@ -207,7 +214,9 @@ class PlanExecEngine:
top_k = (retrieval_config or {}).get("top_k", 5)
token_budget = (retrieval_config or {}).get("token_budget", 2000)
memory_context = await memory_retriever.get_context_string(
query=query, top_k=top_k, token_budget=token_budget,
query=query,
top_k=top_k,
token_budget=token_budget,
)
if memory_context:
if system_prompt:
@ -258,12 +267,14 @@ class PlanExecEngine:
},
)
state.trajectory.append(ReActStep(
state.trajectory.append(
ReActStep(
step=state.step_counter,
action="plan_generated",
content=f"Generated plan with {len(plan.steps)} steps",
tokens=0,
))
)
)
# Persist plan as Spec if spec_manager is provided
if self._spec_manager is not None:
@ -335,18 +346,24 @@ class PlanExecEngine:
},
)
state.trajectory.append(ReActStep(
state.trajectory.append(
ReActStep(
step=state.step_counter,
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
action="step_completed"
if step_result.status == PlanStepStatus.COMPLETED
else "step_failed",
tool_name=step_name,
result=step_result.result,
tokens=0,
))
)
)
if trace_recorder is not None:
trace_recorder.record_step(
step=state.step_counter,
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
action="step_completed"
if step_result.status == PlanStepStatus.COMPLETED
else "step_failed",
tool_name=step_name,
output_data=step_result.result,
error=step_result.error,
@ -372,19 +389,27 @@ class PlanExecEngine:
)
pipeline = self._plan_to_pipeline(current_plan, agent_name)
pipeline_result = self._plan_result_to_pipeline_result(current_plan, plan_result)
pipeline_result = self._plan_result_to_pipeline_result(
current_plan, plan_result
)
reflection_report = await self._reflector.reflect(pipeline, pipeline_result, replan_count)
revised_pipeline = await self._replanner.replan(pipeline, pipeline_result, reflection_report)
reflection_report = await self._reflector.reflect(
pipeline, pipeline_result, replan_count
)
revised_pipeline = await self._replanner.replan(
pipeline, pipeline_result, reflection_report
)
current_plan = self._pipeline_to_plan(revised_pipeline, plan.goal)
self._merge_completed_results(current_plan, plan_result)
state.trajectory.append(ReActStep(
state.trajectory.append(
ReActStep(
step=state.step_counter,
action="replanning",
content=f"Replanned (attempt {replan_count}): {reflection_report.root_cause}",
tokens=0,
))
)
)
continue
@ -404,12 +429,14 @@ class PlanExecEngine:
# 最终步骤
state.step_counter += 1
state.trajectory.append(ReActStep(
state.trajectory.append(
ReActStep(
step=state.step_counter,
action="final_answer",
content=output,
tokens=0,
))
)
)
yield ReActEvent(
event_type="final_answer",
@ -470,7 +497,9 @@ class PlanExecEngine:
top_k = (retrieval_config or {}).get("top_k", 5)
token_budget = (retrieval_config or {}).get("token_budget", 2000)
memory_context = await memory_retriever.get_context_string(
query=query, top_k=top_k, token_budget=token_budget,
query=query,
top_k=top_k,
token_budget=token_budget,
)
if memory_context:
if system_prompt:
@ -509,20 +538,33 @@ class PlanExecEngine:
# Emit plan_generated event
if self._step_event_callback:
try:
await self._step_event_callback("plan_generated", {
await self._step_event_callback(
"plan_generated",
{
"plan_id": plan.plan_id,
"goal": plan.goal,
"steps": [s.to_dict() for s in plan.steps],
})
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
},
)
except (
RuntimeError,
ValueError,
TypeError,
KeyError,
AttributeError,
ConnectionError,
asyncio.TimeoutError,
) as e:
logger.warning(f"Step event callback failed: {e}")
trajectory.append(ReActStep(
trajectory.append(
ReActStep(
step=1,
action="plan_generated",
content=f"Generated plan with {len(plan.steps)} steps",
tokens=0,
))
)
)
# Persist plan as Spec if spec_manager is provided
if self._spec_manager is not None:
@ -530,12 +572,23 @@ class PlanExecEngine:
self._spec_manager.create(spec)
if self._step_event_callback:
try:
await self._step_event_callback("spec_created", {
await self._step_event_callback(
"spec_created",
{
"spec_id": spec.spec_id,
"goal": spec.goal,
"num_steps": len(spec.steps),
})
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
},
)
except (
RuntimeError,
ValueError,
TypeError,
KeyError,
AttributeError,
ConnectionError,
asyncio.TimeoutError,
) as e:
logger.warning(f"Step event callback failed: {e}")
if trace_recorder is not None:
@ -572,12 +625,14 @@ class PlanExecEngine:
else:
trace_outcome = "success"
trajectory.append(ReActStep(
trajectory.append(
ReActStep(
step=len(trajectory) + 1,
action="final_answer",
content=output,
tokens=0,
))
)
)
return ReActResult(
output=output,
@ -666,32 +721,53 @@ class PlanExecEngine:
for sid, step_result in plan_result.step_results.items():
plan_step = current_plan.get_step(sid)
step_name = plan_step.name if plan_step else sid
trajectory.append(ReActStep(
trajectory.append(
ReActStep(
step=len(trajectory) + 1,
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
action="step_completed"
if step_result.status == PlanStepStatus.COMPLETED
else "step_failed",
tool_name=step_name,
result=step_result.result,
tokens=0,
))
)
)
# Emit step event callback
if self._step_event_callback:
event_type = "step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed"
event_type = (
"step_completed"
if step_result.status == PlanStepStatus.COMPLETED
else "step_failed"
)
try:
await self._step_event_callback(event_type, {
await self._step_event_callback(
event_type,
{
"step_id": sid,
"step_name": step_name,
"status": step_result.status.value,
"result": step_result.result,
"error": step_result.error,
})
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
},
)
except (
RuntimeError,
ValueError,
TypeError,
KeyError,
AttributeError,
ConnectionError,
asyncio.TimeoutError,
) as e:
logger.warning(f"Step event callback failed: {e}")
if trace_recorder is not None:
trace_recorder.record_step(
step=len(trajectory),
action="step_completed" if step_result.status == PlanStepStatus.COMPLETED else "step_failed",
action="step_completed"
if step_result.status == PlanStepStatus.COMPLETED
else "step_failed",
tool_name=step_name,
output_data=step_result.result,
error=step_result.error,
@ -714,10 +790,14 @@ class PlanExecEngine:
pipeline_result = self._plan_result_to_pipeline_result(current_plan, plan_result)
# 反思
reflection_report = await self._reflector.reflect(pipeline, pipeline_result, replan_count)
reflection_report = await self._reflector.reflect(
pipeline, pipeline_result, replan_count
)
# 重规划
revised_pipeline = await self._replanner.replan(pipeline, pipeline_result, reflection_report)
revised_pipeline = await self._replanner.replan(
pipeline, pipeline_result, reflection_report
)
# 将修正后的 Pipeline 转回 ExecutionPlan
current_plan = self._pipeline_to_plan(revised_pipeline, plan.goal)
@ -728,20 +808,33 @@ class PlanExecEngine:
# Emit replanning event
if self._step_event_callback:
try:
await self._step_event_callback("replanning", {
await self._step_event_callback(
"replanning",
{
"replan_count": replan_count,
"root_cause": reflection_report.root_cause,
"new_plan_id": current_plan.plan_id,
})
except (RuntimeError, ValueError, TypeError, KeyError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
},
)
except (
RuntimeError,
ValueError,
TypeError,
KeyError,
AttributeError,
ConnectionError,
asyncio.TimeoutError,
) as e:
logger.warning(f"Step event callback failed: {e}")
trajectory.append(ReActStep(
trajectory.append(
ReActStep(
step=len(trajectory) + 1,
action="replanning",
content=f"Replanned (attempt {replan_count}): {reflection_report.root_cause}",
tokens=0,
))
)
)
if trace_recorder is not None:
trace_recorder.record_step(
@ -850,13 +943,15 @@ class PlanExecEngine:
stages = []
for step in plan.steps:
stages.append(PipelineStage(
stages.append(
PipelineStage(
name=step.step_id,
agent=agent_name,
action=step.description,
depends_on=step.dependencies,
inputs=step.input_data,
))
)
)
return Pipeline(
name=f"plan_{plan.plan_id}",
@ -904,14 +999,16 @@ class PlanExecEngine:
"""将修正后的 Pipeline 转回 ExecutionPlan"""
steps = []
for stage in pipeline.stages:
steps.append(PlanStep(
steps.append(
PlanStep(
step_id=stage.name,
name=stage.name,
description=stage.action,
dependencies=stage.depends_on,
input_data=stage.inputs,
required_skills=[],
))
)
)
plan = ExecutionPlan(
goal=goal,
@ -944,10 +1041,12 @@ class PlanExecEngine:
for step in plan.steps:
sr = plan_result.step_results.get(step.step_id)
if sr and sr.status == PlanStepStatus.COMPLETED and sr.result:
completed_results.append({
completed_results.append(
{
"step": step.name,
"result": sr.result,
})
}
)
if not completed_results:
# 没有成功步骤
@ -964,7 +1063,11 @@ class PlanExecEngine:
# 简单聚合:将所有成功步骤结果格式化
parts = []
for item in completed_results:
result_str = json.dumps(item["result"], ensure_ascii=False) if isinstance(item["result"], dict) else str(item["result"])
result_str = (
json.dumps(item["result"], ensure_ascii=False)
if isinstance(item["result"], dict)
else str(item["result"])
)
parts.append(f"**{item['step']}**: {result_str}")
return "\n\n".join(parts)
@ -1068,7 +1171,9 @@ class _ReActStepAgent:
dep_results = input_data.get("dependency_results", {})
# 构建步骤 prompt
prompt_parts = [f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"]
prompt_parts = [
f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"
]
if dep_results:
prompt_parts.append(
f"\nResults from previous steps:\n{json.dumps(dep_results, ensure_ascii=False, indent=2)}"

View File

@ -78,7 +78,9 @@ class ReflexionEngine:
if max_reflections < 1:
raise ValueError(f"max_reflections must be >= 1, got {max_reflections}")
if not 0.0 <= quality_threshold <= 1.0:
raise ValueError(f"quality_threshold must be between 0.0 and 1.0, got {quality_threshold}")
raise ValueError(
f"quality_threshold must be between 0.0 and 1.0, got {quality_threshold}"
)
self._llm_gateway = llm_gateway
self._max_steps = max_steps
@ -116,7 +118,9 @@ class ReflexionEngine:
reflect_model: 用于生成反思的模型默认与 evaluate_model 相同
其余参数与 ReActEngine.execute() 相同
"""
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
effective_timeout = (
timeout_seconds if timeout_seconds is not None else self._default_timeout
)
act_model = model
effective_evaluate_model = evaluate_model or act_model
effective_reflect_model = reflect_model or effective_evaluate_model
@ -187,7 +191,9 @@ class ReflexionEngine:
reflect_model: str = "default",
) -> ReflexionResult:
# Telemetry
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "reflexion"})
agent_request_counter().add(
1, {"agent.name": agent_name, "agent.type": task_type or "reflexion"}
)
_span_cm = None
_span = None
@ -348,6 +354,11 @@ class ReflexionEngine:
"""执行 Reflexion 循环,以流式事件形式返回
在每次 ReAct 执行评估反思和重试时发出事件
U2: 进化钩子on_task_complete/on_task_failed由外层
ConfigDrivenAgent.execute_stream() finally 集中触发本引擎仅向上
传播异常与 final_answer 事件不重复触发钩子避免双重进化
ponytail: 引擎无 evolution 上下文钩子上移至 agent 层是单触发点
"""
act_model = model
effective_evaluate_model = evaluate_model or act_model
@ -600,9 +611,7 @@ class ReflexionEngine:
def _parse_evaluation_score(self, content: str) -> float:
"""从 LLM 响应中解析评估分数"""
# 尝试从代码块中提取 JSON
json_match = re.search(
r"```(?:json)?\s*\n?(.*?)\n?```", content, re.DOTALL
)
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", content, re.DOTALL)
if json_match:
try:
data = json.loads(json_match.group(1))

View File

@ -81,7 +81,14 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
backend=config.usage_store.get("backend", "memory"),
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
)
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
except (
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
) as e:
logger.warning(f"Failed to initialize usage store: {e}, using in-memory")
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
@ -478,7 +485,14 @@ async def lifespan(app: FastAPI):
_row = await _cur.fetchone()
if _row is not None:
default_cal_user_id = str(_row["id"])
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
except (
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.debug("Could not resolve default user_id for CalendarTool", exc_info=True)
calendar_tool = CalendarTool(
@ -505,7 +519,9 @@ async def lifespan(app: FastAPI):
except (ValueError, KeyError, RuntimeError, AttributeError):
# ponytail: log at debug — CalendarTool double-registration
# is expected on reload, but silent pass hides real errors.
logger.debug("CalendarTool already registered or registration failed", exc_info=True)
logger.debug(
"CalendarTool already registered or registration failed", exc_info=True
)
# Strip any existing "## 可用工具" section to avoid
# duplicate tool blocks in the system prompt.
base_prompt = getattr(default_agent, "_system_prompt", None) or (
@ -570,7 +586,14 @@ async def lifespan(app: FastAPI):
from agentkit.rag_platform.store import ensure_tables
await ensure_tables(rag_database_url)
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
except (
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.exception("Failed to ensure rag_platform tables")
# KBStore — KB/Document persistence
@ -693,6 +716,14 @@ async def lifespan(app: FastAPI):
except (RuntimeError, asyncio.TimeoutError, ConnectionError, OSError):
logger.debug("close_all_adapters 异常已忽略")
# U2: drain pending fire-and-forget evolution tasks from execute_stream()
try:
from agentkit.core.config_driven import drain_pending_evolution_tasks
await drain_pending_evolution_tasks()
except Exception:
logger.debug("drain_pending_evolution_tasks 异常已忽略", exc_info=True)
def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
"""Handle config change by reloading affected components.
@ -736,7 +767,14 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._llm_gateway = new_gateway
logger.info(f"LLM Gateway reloaded (config v{current_version})")
except (ValueError, TypeError, KeyError, RuntimeError, ConnectionError, OSError) as e:
except (
ValueError,
TypeError,
KeyError,
RuntimeError,
ConnectionError,
OSError,
) as e:
logger.error(f"Failed to reload LLM Gateway: {e}")
# Reload skills if skill paths changed
@ -1185,7 +1223,15 @@ def create_app(
try:
epi_session_factory = create_episodic_session_factory(database_url)
epi_model = EpisodeModel
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError, ImportError) as db_err:
except (
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
ImportError,
) as db_err:
import logging as _log
_log.getLogger(__name__).warning(

View File

@ -577,6 +577,9 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
collected_output: list[str] = []
try:
# U2 verify: calls react_engine.execute_stream directly, bypassing
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
# here. Routing through agent.execute_stream is tracked separately.
async for event in react_engine.execute_stream(
messages=messages,
tools=tools,
@ -698,6 +701,9 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
collected_output: list[str] = []
try:
# U2 verify: calls react_engine.execute_stream directly, bypassing
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
# here. Routing through agent.execute_stream is tracked separately.
async for event in react_engine.execute_stream(
messages=messages,
tools=tools,
@ -812,9 +818,7 @@ async def _conversation_has_board_started(conversation_id: str) -> bool:
list endpoint.
"""
try:
return await _conversation_store.has_message_with_type(
conversation_id, "board_started"
)
return await _conversation_store.has_message_with_type(conversation_id, "board_started")
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("is_board lookup failed for %s", conversation_id, exc_info=True)
return False
@ -881,10 +885,7 @@ async def get_conversation(
"messages": [_hydrate_persisted_message(conv.id, i, m) for i, m in enumerate(history)],
"created_at": conv.created_at.isoformat(),
"updated_at": conv.updated_at.isoformat(),
"is_board": any(
(m.metadata or {}).get("message_type") == "board_started"
for m in history
),
"is_board": any((m.metadata or {}).get("message_type") == "board_started" for m in history),
}
@ -998,6 +999,9 @@ async def _execute_react_background(
):
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
# U2 verify: calls react_engine.execute_stream directly, bypassing
# ConfigDrivenAgent.execute_stream — evolution hooks NOT propagated
# here. Routing through agent.execute_stream is tracked separately.
async for event in react_engine.execute_stream(
messages=messages,
tools=tools,

View File

@ -0,0 +1,320 @@
"""U2 tests: execute_stream evolution hook wiring (OQ6 fix).
Verifies that ConfigDrivenAgent.execute_stream() fires evolution hooks
(on_task_complete / on_task_failed) in its finally block with lifecycle
parity to the sync execute() path. Covers happy path, failure, cancellation,
early close, evolution-error suppression, backpressure cap, REST/stream
parity, and evolution-disabled no-op.
"""
import asyncio
import pytest
from agentkit.core.config_driven import (
AgentConfig,
ConfigDrivenAgent,
drain_pending_evolution_tasks,
)
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.core.react import ReActEvent
# ── Helpers ──────────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
defaults = dict(
task_id="stream-task-001",
agent_name="stream_agent",
task_type="generate",
priority=1,
input_data={"query": "hello"},
callback_url=None,
created_at=None,
)
defaults.update(overrides)
return TaskMessage.from_dict(defaults)
def _make_agent(max_concurrency: int = 1) -> ConfigDrivenAgent:
config = AgentConfig.from_dict(
{
"name": "stream_agent",
"agent_type": "content_generation",
"task_mode": "llm_generate",
"prompt": {
"identity": "test agent",
"instructions": "do the thing",
"output_format": "text",
},
"max_concurrency": max_concurrency,
}
)
agent = ConfigDrivenAgent(config=config)
agent._evolution_enabled = True
return agent
def _final_answer_event(output: str = "hello") -> ReActEvent:
return ReActEvent(
event_type="final_answer",
step=0,
data={"output": output},
)
@pytest.fixture(autouse=True)
async def _isolate_evolution_state():
"""Reset module-level evolution state before each test, drain after.
Without this, stuck tasks from a prior test would inflate the pending
set and break backpressure assertions in later tests.
"""
import agentkit.core.config_driven as cd
for task in list(cd._pending_evolution_tasks):
task.cancel()
if cd._pending_evolution_tasks:
await asyncio.gather(*cd._pending_evolution_tasks, return_exceptions=True)
cd._pending_evolution_tasks.clear()
cd._evolution_dropped_count = 0
yield
await drain_pending_evolution_tasks()
# ── Happy path ───────────────────────────────────────────
class TestExecuteStreamHooks:
async def test_success_fires_on_task_complete(self):
"""Stream completion fires evolve_after_task with COMPLETED status."""
agent = _make_agent()
fired: list[TaskResult] = []
async def record_evolve(task, result, memory_store=None):
fired.append(result)
agent.evolve_after_task = record_evolve
async def good_stream(task):
yield _final_answer_event("hello world")
agent.handle_task_stream = good_stream
events = []
async for event in agent.execute_stream(_make_task()):
events.append(event)
await drain_pending_evolution_tasks()
assert len(events) == 1
assert events[0].event_type == "final_answer"
assert len(fired) == 1
assert fired[0].status == TaskStatus.COMPLETED
assert fired[0].output_data == {"content": "hello world"}
async def test_failure_fires_on_task_failed(self):
"""Stream exception fires evolve_after_task with FAILED status."""
agent = _make_agent()
fired: list[TaskResult] = []
async def record_evolve(task, result, memory_store=None):
fired.append(result)
agent.evolve_after_task = record_evolve
async def failing_stream(task):
yield _final_answer_event("partial") # yield once before failing
raise RuntimeError("stream blew up")
agent.handle_task_stream = failing_stream
with pytest.raises(RuntimeError, match="stream blew up"):
async for _ in agent.execute_stream(_make_task()):
pass
await drain_pending_evolution_tasks()
assert len(fired) == 1
assert fired[0].status == TaskStatus.FAILED
assert "stream blew up" in (fired[0].error_message or "")
# ── Edge cases ───────────────────────────────────────────
class TestExecuteStreamEdgeCases:
async def test_cancellation_fires_cancelled_status(self):
"""Stream cancelled mid-flight fires hooks with CANCELLED status."""
agent = _make_agent()
fired: list[TaskResult] = []
async def record_evolve(task, result, memory_store=None):
fired.append(result)
agent.evolve_after_task = record_evolve
started = asyncio.Event()
async def slow_stream(task):
started.set()
await asyncio.sleep(60)
yield _final_answer_event("never reached")
agent.handle_task_stream = slow_stream
async def consume():
async for _ in agent.execute_stream(_make_task()):
pass
consumer = asyncio.create_task(consume())
await started.wait()
await asyncio.sleep(0.05) # let it settle into sleep(60)
consumer.cancel()
with pytest.raises(asyncio.CancelledError):
await consumer
await drain_pending_evolution_tasks()
assert len(fired) == 1
assert fired[0].status == TaskStatus.CANCELLED
async def test_stream_closed_early_fires_cancelled(self):
"""Consumer aclose() before final_answer fires CANCELLED status."""
agent = _make_agent()
fired: list[TaskResult] = []
async def record_evolve(task, result, memory_store=None):
fired.append(result)
agent.evolve_after_task = record_evolve
async def blocking_stream(task):
yield ReActEvent(event_type="thinking", step=0, data={"content": "thinking..."})
await asyncio.sleep(60)
yield _final_answer_event("late")
agent.handle_task_stream = blocking_stream
gen = agent.execute_stream(_make_task())
first = await gen.__anext__()
assert first.event_type == "thinking"
await gen.aclose()
await drain_pending_evolution_tasks()
assert len(fired) == 1
assert fired[0].status == TaskStatus.CANCELLED
assert "stream closed before completion" in (fired[0].error_message or "")
async def test_evolution_error_does_not_propagate(self):
"""Evolution task error is swallowed — stream completes normally."""
agent = _make_agent()
async def failing_evolve(task, result, memory_store=None):
raise RuntimeError("evolution exploded")
agent.evolve_after_task = failing_evolve
async def good_stream(task):
yield _final_answer_event("ok")
agent.handle_task_stream = good_stream
events = []
async for event in agent.execute_stream(_make_task()):
events.append(event)
# drain must not raise despite evolution error
await drain_pending_evolution_tasks()
assert len(events) == 1
assert events[0].data.get("output") == "ok"
async def test_backpressure_cap_drops(self):
"""When pending evolution tasks hit cap, excess is dropped + counted."""
agent = _make_agent(max_concurrency=1) # cap = max(2, 1*2) = 2
block = asyncio.Event()
async def stuck_evolve(task, result, memory_store=None):
await block.wait()
agent.evolve_after_task = stuck_evolve
async def good_stream(task):
yield _final_answer_event("ok")
agent.handle_task_stream = good_stream
import agentkit.core.config_driven as cd
# Fire 3 streams — first 2 fill the cap (stuck), 3rd is dropped
for i in range(3):
async for _ in agent.execute_stream(_make_task(task_id=f"bp-{i}")):
pass
await asyncio.sleep(0) # yield to let evolution tasks start
assert cd._evolution_dropped_count == 1
# Cleanup: release stuck tasks and drain
block.set()
await drain_pending_evolution_tasks()
# ── Parity & disabled ────────────────────────────────────
class TestExecuteStreamParity:
async def test_parity_rest_vs_stream(self):
"""Both REST on_task_complete and execute_stream fire COMPLETED evolve."""
agent = _make_agent()
stream_fired: list[TaskResult] = []
rest_fired: list[TaskResult] = []
async def good_stream(task):
yield _final_answer_event("hello")
agent.handle_task_stream = good_stream
async def stream_evolve(task, result, memory_store=None):
stream_fired.append(result)
agent.evolve_after_task = stream_evolve
async for _ in agent.execute_stream(_make_task(task_id="stream-1")):
pass
await drain_pending_evolution_tasks()
async def rest_evolve(task, result, memory_store=None):
rest_fired.append(result)
agent.evolve_after_task = rest_evolve
await agent.on_task_complete(_make_task(task_id="rest-1"), {"content": "hello"})
assert len(stream_fired) == 1
assert stream_fired[0].status == TaskStatus.COMPLETED
assert len(rest_fired) == 1
assert rest_fired[0].status == TaskStatus.COMPLETED
async def test_evolution_disabled_no_hooks(self):
"""When _evolution_enabled is False, no hooks fire."""
agent = _make_agent()
agent._evolution_enabled = False
fired: list[TaskResult] = []
async def record_evolve(task, result, memory_store=None):
fired.append(result)
agent.evolve_after_task = record_evolve
async def good_stream(task):
yield _final_answer_event("hello")
agent.handle_task_stream = good_stream
async for _ in agent.execute_stream(_make_task()):
pass
await drain_pending_evolution_tasks()
assert len(fired) == 0