fischer-agentkit/src/agentkit/core/config_driven.py

1450 lines
59 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ConfigDrivenAgent - 配置驱动的 Agent 定义
核心设计:
- 从 YAML/Dict 配置自动组装 AgentPrompt + LLM + Tool + Memory
- 支持三种任务模式:llm_generate / tool_call / custom
- v2: 支持 SkillConfig + ReAct 执行模式 + LLMGateway + Quality Gate
- 新增 Agent 从写 150 行代码降为 10-20 行配置
"""
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator, Awaitable
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
import yaml
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError
from agentkit.core.protocol import (
AgentCapability,
CancellationToken,
TaskMessage,
TaskResult,
TaskStatus,
)
from agentkit.core.react import ReActEvent
from agentkit.evolution.lifecycle import EvolutionMixin
from agentkit.evolution.reflector import Reflector
from agentkit.prompts.section import PromptSection
from agentkit.prompts.template import PromptTemplate
from agentkit.tools.base import Tool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
# Evolution hook backpressure for execute_stream(): evolution runs are scheduled
# fire-and-forget, limited by a pending-task cap, and drained at shutdown.
# NOTE: this state is module-level, so the cap is checked against the tasks
# pending across ALL agents, not per agent; the upgrade path is a per-agent
# semaphore if fairness matters.
# Tasks created by _schedule_evolution() that have not finished yet.
_pending_evolution_tasks: set[asyncio.Task[None]] = set()
# Number of evolution coroutines dropped because the cap had been reached.
_evolution_dropped_count: int = 0
def _schedule_evolution(coro: Coroutine[Any, Any, None], cap: int) -> None:
"""Schedule a fire-and-forget evolution task with backpressure.
Drops + logs + increments the dropped counter when pending tasks reach ``cap``,
mirroring the portal webhook backpressure pattern (``max_concurrent * 2``).
"""
global _evolution_dropped_count
if len(_pending_evolution_tasks) >= cap:
_evolution_dropped_count += 1
logger.warning("Evolution backpressure cap reached (%d pending), dropping task", cap)
coro.close() # avoid 'coroutine never awaited' RuntimeWarning
return
task = asyncio.create_task(coro)
_pending_evolution_tasks.add(task)
task.add_done_callback(_pending_evolution_tasks.discard)
async def drain_pending_evolution_tasks() -> None:
    """Wait for the tracked fire-and-forget evolution tasks (app shutdown).

    Does nothing when no task is pending. Exceptions raised by individual tasks
    are collected by ``gather`` rather than propagated.
    """
    if _pending_evolution_tasks:
        logger.info("Draining %d pending evolution tasks", len(_pending_evolution_tasks))
        await asyncio.gather(*_pending_evolution_tasks, return_exceptions=True)
class AgentConfig:
"""Agent 配置模型,从 YAML 或 Dict 构建"""
def __init__(
self,
name: str,
agent_type: str,
version: str = "1.0.0",
description: str = "",
task_mode: str = "llm_generate",
supported_tasks: list[str] | None = None,
max_concurrency: int = 1,
input_schema: dict[str, object] | None = None,
output_schema: dict[str, object] | None = None,
prompt: dict[str, str] | None = None,
llm: dict[str, object] | None = None,
tools: list[str] | None = None,
memory: dict[str, object] | None = None,
custom_handler: str | None = None,
):
self.name = name
self.agent_type = agent_type
self.version = version
self.description = description
self.task_mode = task_mode
self.supported_tasks = supported_tasks or [agent_type]
self.max_concurrency = max_concurrency
self.input_schema = input_schema
self.output_schema = output_schema
self.prompt = prompt or {}
self.llm = llm or {}
self.tools = tools or []
self.memory = memory or {}
self.custom_handler = custom_handler
self._validate()
def _validate(self) -> None:
"""校验配置合法性"""
valid_modes = {"llm_generate", "tool_call", "custom"}
if self.task_mode not in valid_modes:
raise ConfigValidationError(
agent_name=self.name,
key="task_mode",
reason=f"Invalid task_mode '{self.task_mode}', must be one of {valid_modes}",
)
if self.task_mode == "llm_generate" and not self.prompt:
raise ConfigValidationError(
agent_name=self.name,
key="prompt",
reason="llm_generate mode requires 'prompt' configuration",
)
if self.task_mode == "tool_call" and not self.tools:
raise ConfigValidationError(
agent_name=self.name,
key="tools",
reason="tool_call mode requires at least one tool in 'tools' list",
)
if self.task_mode == "custom" and not self.custom_handler:
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason="custom mode requires 'custom_handler' (dotted path to callable)",
)
@classmethod
def from_dict(cls, data: dict[str, object]) -> "AgentConfig":
"""从字典创建配置"""
return cls(
name=data["name"],
agent_type=data["agent_type"],
version=data.get("version", "1.0.0"),
description=data.get("description", ""),
task_mode=data.get("task_mode", "llm_generate"),
supported_tasks=data.get("supported_tasks"),
max_concurrency=data.get("max_concurrency", 1),
input_schema=data.get("input_schema"),
output_schema=data.get("output_schema"),
prompt=data.get("prompt"),
llm=data.get("llm"),
tools=data.get("tools"),
memory=data.get("memory"),
custom_handler=data.get("custom_handler"),
)
@classmethod
def from_yaml(cls, path: str) -> "AgentConfig":
"""从 YAML 文件加载配置"""
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ConfigValidationError(
agent_name="unknown",
key="config",
reason=f"YAML config must be a mapping, got {type(data)}",
)
return cls.from_dict(data)
def to_dict(self) -> dict[str, object]:
"""序列化为字典"""
d = {
"name": self.name,
"agent_type": self.agent_type,
"version": self.version,
"description": self.description,
"task_mode": self.task_mode,
"supported_tasks": self.supported_tasks,
"max_concurrency": self.max_concurrency,
}
if self.input_schema:
d["input_schema"] = self.input_schema
if self.output_schema:
d["output_schema"] = self.output_schema
if self.prompt:
d["prompt"] = self.prompt
if self.llm:
d["llm"] = self.llm
if self.tools:
d["tools"] = self.tools
if self.memory:
d["memory"] = self.memory
if self.custom_handler:
d["custom_handler"] = self.custom_handler
return d
class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
    """Configuration-driven agent.

    Assembled automatically from a YAML/dict configuration; supports three task modes:

    - llm_generate: render the prompt -> call the LLM -> parse the JSON output
    - tool_call: call a registered Tool and return its result
    - custom: run a custom handler function

    v2 enhancements:

    - accepts a SkillConfig, automatically creating a Skill and enabling ReAct mode
    - the ``llm_gateway`` parameter passes an LLMGateway in directly
    - the ``llm_client`` parameter is wrapped into an LLMGateway automatically
      (backward compatible)
    - Quality Gate is integrated automatically

    Example YAML configuration::

        name: content_generator
        agent_type: content_generation
        task_mode: llm_generate
        description: "Content generation agent"
        prompt:
          identity: "You are a professional content generation assistant"
          instructions: "Generate high-quality content based on the user's needs"
          output_format: "Respond in JSON format"
        llm:
          model: "gpt-4"
          temperature: 0.7
        tools:
          - retrieve_knowledge
    """

    # Security: whitelist of allowed module prefixes for dynamic handler import
    _ALLOWED_HANDLER_PREFIXES = (
        "agentkit.",
        "app.agent_framework.",
    )
def __init__(
    self,
    config: AgentConfig,
    tool_registry: ToolRegistry | None = None,
    llm_client: object | None = None,
    custom_handlers: dict[str, Callable[..., Coroutine]] | None = None,
    llm_gateway: object | None = None,  # NEW v2 param: LLMGateway
    mcp_servers: dict[str, str] | None = None,  # NEW v2 param: MCP server URLs
    compressor: object | None = None,  # CompressionStrategy | None
):
    """Assemble the agent from ``config`` and the optional collaborators.

    Args:
        config: Agent configuration. When it is a ``SkillConfig`` instance a
            ``Skill`` is created from it and skill-based routing is enabled.
        tool_registry: Registry used to resolve the tool names listed in
            ``config.tools``; a fresh ``ToolRegistry`` is created when omitted.
        llm_client: Legacy LLM client; wrapped via ``_wrap_llm_client`` when no
            ``llm_gateway`` is given.
        custom_handlers: Mapping of custom handlers, stored as given
            (empty dict when omitted).
        llm_gateway: Gateway used directly when provided; takes precedence
            over ``llm_client``.
        mcp_servers: Mapping of MCP server name to URL; the servers are
            contacted lazily by ``_register_mcp_tools``.
        compressor: Stored and later passed to the ReAct engine calls.
    """
    # v2: If SkillConfig, extract skill info
    from agentkit.skills.base import SkillConfig, Skill

    self._skill_config: SkillConfig | None = None
    self._skill_instance: Skill | None = None
    if isinstance(config, SkillConfig):
        self._skill_config = config
        self._skill_instance = Skill(config=config)
    self._config = config
    self._tool_registry = tool_registry or ToolRegistry()
    self._llm_client = llm_client
    self._custom_handlers = custom_handlers or {}
    self._prompt_template: PromptTemplate | None = None
    # Call super().__init__() before anything that relies on base-agent state
    # (e.g. use_tool() below).
    super().__init__(
        name=config.name,
        agent_type=config.agent_type,
        version=config.version,
    )
    # v2: Backward compat — wrap llm_client into LLMGateway if no gateway provided
    if llm_gateway is not None:
        self._llm_gateway = llm_gateway
    elif llm_client is not None:
        self._llm_gateway = self._wrap_llm_client(llm_client)
    else:
        self._llm_gateway = None
    # v2: Set skill on base agent
    if self._skill_instance:
        self._skill = self._skill_instance
    # v2: Initialize ReAct engine if gateway available
    self._react_engine = None
    if self._llm_gateway:
        from agentkit.core.react import ReActEngine

        self._react_engine = ReActEngine(
            llm_gateway=self._llm_gateway,
            # Plain AgentConfig has no max_steps attribute, hence the default of 5.
            max_steps=getattr(config, "max_steps", 5),
        )
    # v2: Initialize Quality Gate (always available)
    from agentkit.quality.gate import QualityGate

    self._quality_gate = QualityGate()
    # v2: Initialize Evolution if configured
    evolution_config = getattr(config, "evolution", None)
    if evolution_config is not None:
        # Support both dict and EvolutionConfig
        if isinstance(evolution_config, dict):
            is_enabled = evolution_config.get("enabled", False)
        else:
            is_enabled = getattr(evolution_config, "enabled", False)
    else:
        is_enabled = False
    if is_enabled:
        reflector = Reflector()
        EvolutionMixin.__init__(
            self,
            reflector=reflector,
        )
        self._evolution_enabled = True
    else:
        EvolutionMixin.__init__(self)  # Initialize with no components
        self._evolution_enabled = False
    # v2: Initialize Output Standardizer
    from agentkit.quality.output import OutputStandardizer

    self._output_standardizer = OutputStandardizer()
    # v2: Store compressor for ReAct engine
    self._compressor = compressor
    # Build the prompt template from the configuration
    if config.prompt:
        sections = PromptSection(
            identity=config.prompt.get("identity", ""),
            context=config.prompt.get("context", ""),
            instructions=config.prompt.get("instructions", ""),
            constraints=config.prompt.get("constraints", ""),
            output_format=config.prompt.get("output_format", ""),
            examples=config.prompt.get("examples", ""),
        )
        self._prompt_template = PromptTemplate(
            sections=sections,
            name=config.name,
            version=config.version,
        )
    # Bind the tools named in the configuration
    self._bind_tools()
    # v2: Merge Skill-bound tools into Agent's tool list
    if self._skill_instance and self._skill_instance.tools:
        for tool in self._skill_instance.tools:
            # Skip skill tools whose name is already bound to the agent.
            if not any(t.name == tool.name for t in self._tools):
                self.use_tool(tool)
                logger.info(f"Merged skill tool '{tool.name}' into agent '{self.name}'")
    # v2: Remember the MCP servers; their tools are registered lazily by
    # _register_mcp_tools() because listing them requires async calls.
    self._mcp_clients: list[object] = []
    self._mcp_servers: dict[str, str] = mcp_servers or {}
    self._mcp_tools_registered = False
    # Memory integration: instantiate a MemoryRetriever from config.memory
    self._memory_retriever: object | None = None
    if config.memory:
        try:
            from agentkit.memory.retriever import MemoryRetriever
            from agentkit.memory.working import WorkingMemory
            from agentkit.memory.semantic import SemanticMemory
            from agentkit.memory.http_rag import HttpRAGService

            working = None
            episodic = None
            semantic = None
            if config.memory.get("working", {}).get("enabled"):
                import redis.asyncio as aioredis

                redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379")
                redis_client = aioredis.from_url(redis_url, decode_responses=True)
                working = WorkingMemory(redis=redis_client)
            if config.memory.get("episodic", {}).get("enabled"):
                from agentkit.memory.episodic import EpisodicMemory
                from agentkit.memory.embedder import OpenAIEmbedder, EmbeddingCache

                epi_conf = config.memory["episodic"]
                embedder = None
                # An embedder is only built when an API key is configured or
                # present in the environment.
                if epi_conf.get("embedder_api_key") or os.environ.get("OPENAI_API_KEY"):
                    cache = EmbeddingCache(
                        max_size=epi_conf.get("cache_max_size", 1000),
                        ttl=epi_conf.get("cache_ttl", 3600),
                    )
                    embedder = OpenAIEmbedder(
                        api_key=epi_conf.get("embedder_api_key"),
                        model=epi_conf.get("embedder_model", "text-embedding-3-small"),
                        base_url=epi_conf.get("embedder_base_url"),
                        cache=cache,
                    )
                episodic = EpisodicMemory(
                    session_factory=None,  # Set externally when DB session is available
                    episodic_model=None,  # Set externally when ORM model is available
                    embedder=embedder,
                    decay_rate=epi_conf.get("decay_rate", 0.01),
                    alpha=epi_conf.get("alpha", 0.7),
                    retrieve_limit=epi_conf.get("retrieve_limit", 200),
                    pgvector_enabled=epi_conf.get("pgvector_enabled", True),
                    table_name=epi_conf.get("table_name", "episodic_memories"),
                )
            if config.memory.get("semantic", {}).get("enabled"):
                sem_conf = config.memory["semantic"]
                rag_service = HttpRAGService(
                    base_url=sem_conf["base_url"],  # required: a missing key fails memory init
                    api_key=sem_conf.get("api_key"),
                    knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                    timeout=sem_conf.get("timeout", 30),
                )
                semantic = SemanticMemory(
                    rag_service=rag_service,
                    knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                    search_mode=sem_conf.get("search_mode", "standard"),
                    use_rerank=sem_conf.get("use_rerank", True),
                    use_compression=sem_conf.get("use_compression", False),
                    kb_weights=sem_conf.get("kb_weights"),
                )
            self._memory_retriever = MemoryRetriever(
                working_memory=working,
                episodic_memory=episodic,
                semantic_memory=semantic,
            )
            # Inject into BaseAgent (presumably read there; not visible in this file)
            self._memory_retriever_ref = self._memory_retriever
            logger.info(f"ConfigDrivenAgent '{self.name}' initialized memory system")
        except Exception as e:
            # Memory is best-effort: any failure leaves the agent without a retriever.
            logger.warning(f"Failed to initialize memory system: {e}")
            self._memory_retriever = None
    # Auto-register the tool offered by the retriever, if it provides one
    if self._memory_retriever:
        retrieve_tool = self._memory_retriever.create_retrieve_tool()
        if retrieve_tool:
            self.use_tool(retrieve_tool)
def get_tools(self) -> list[Tool]:
    """Return the agent's tools: bound tools first, then registry-only tools.

    A registry tool is appended only when no tool with the same name is already
    in the result, so tools bound to the agent win over registry entries.
    """
    merged = list(self._tools)
    registry = self._tool_registry
    if not registry:
        return merged
    for candidate in registry.list_tools():
        if all(known.name != candidate.name for known in merged):
            merged.append(candidate)
    return merged
def get_model(self) -> str:
    """Return the configured LLM model name, or ``"default"`` when unset."""
    llm_settings = self._config.llm
    if not llm_settings:
        return "default"
    return llm_settings.get("model", "default")
def get_system_prompt(self) -> str | None:
    """Return the system prompt, including a description of the available tools.

    The prompt is the non-empty template sections (identity, context,
    instructions, constraints, output_format) followed by a tools section
    when at least one tool is known. Returns ``None`` when there is nothing
    to include.
    """
    pieces = []
    template = self._prompt_template
    if template:
        sections = template._sections
        section_names = ("identity", "context", "instructions", "constraints", "output_format")
        section_texts = (getattr(sections, section, "") for section in section_names)
        pieces.extend(text for text in section_texts if text)
    # Describe both the tools bound from config and any registered on the
    # registry after init, so the LLM knows everything it can use.
    available = list(self._tools)
    registry = self._tool_registry
    if registry:
        for candidate in registry.list_tools():
            if all(known.name != candidate.name for known in available):
                available.append(candidate)
    if available:
        tools_desc = self._build_tools_description(available)
        pieces.append(f"\n\n## 可用工具\n{tools_desc}")
    if not pieces:
        return None
    return "\n".join(pieces)
def _build_tools_description(self, tools: list | None = None) -> str:
"""Build a text description of available tools for the system prompt."""
tools = tools or self._tools
lines = []
for tool in tools:
lines.append(f"- **{tool.name}**: {tool.description}")
if tool.input_schema and "properties" in tool.input_schema:
params = list(tool.input_schema["properties"].keys())
if params:
lines[-1] += f" (参数: {', '.join(params)})"
return "\n".join(lines)
def get_react_config(self) -> dict:
    """Return the ReAct engine settings (``max_steps`` and ``timeout_seconds``).

    Values come from the SkillConfig when one is present; otherwise the
    defaults are 10 steps and no timeout.
    """
    skill = self._skill_config
    if not skill:
        return {"max_steps": 10, "timeout_seconds": None}
    return {
        "max_steps": skill.max_steps,
        "timeout_seconds": getattr(skill, "timeout_seconds", None),
    }
@property
def config(self) -> AgentConfig:
    """The configuration object this agent was constructed with."""
    return self._config
@property
def prompt_template(self) -> PromptTemplate | None:
    """The prompt template built from ``config.prompt``, or ``None`` if there is none."""
    return self._prompt_template
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
    """Task-complete hook: run evolution on a COMPLETED result when enabled.

    Any error raised while building the result or evolving is logged as a
    warning and swallowed.
    """
    if not self._evolution_enabled:
        return
    try:
        from agentkit.core.protocol import TaskResult, TaskStatus
        from datetime import datetime, timezone

        completed_result = TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=output,
            error_message=None,
            # The hook does not know when the task began, so both timestamps
            # are taken here.
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )
        await self.evolve_after_task(task, completed_result)
    except Exception as e:
        logger.warning(f"Evolution after task failed: {e}")
async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
    """Task-failed hook: run evolution on a FAILED result when enabled.

    The failure is recorded with ``str(error)`` as the error message. Any error
    raised while building the result or evolving is logged and swallowed.
    """
    if not self._evolution_enabled:
        return
    try:
        from agentkit.core.protocol import TaskResult, TaskStatus
        from datetime import datetime, timezone

        failed_result = TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.FAILED,
            output_data=None,
            error_message=str(error),
            # The hook does not know when the task began, so both timestamps
            # are taken here.
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )
        await self.evolve_after_task(task, failed_result)
    except Exception as e:
        logger.warning(f"Evolution after task failure failed: {e}")
def _trigger_evolution_hooks(self, task: TaskMessage, result: TaskResult) -> None:
"""Schedule evolution after a streaming task (fire-and-forget, backpressure-capped).
Mirrors the sync on_task_complete/on_task_failed path but non-blocking so
streaming latency is unaffected. Evolution errors are swallowed inside
_evolve_safe and must never fail the stream. KTD-4: lifecycle parity with
execute() for the streaming path.
"""
if not self._evolution_enabled:
return
cap = max(2, self._config.max_concurrency * 2)
_schedule_evolution(self._evolve_safe(task, result), cap=cap)
async def _evolve_safe(self, task: TaskMessage, result: TaskResult) -> None:
"""Run evolve_after_task, swallowing errors (evolution must not fail stream)."""
try:
await self.evolve_after_task(task, result)
except Exception:
logger.warning("Evolution after stream task failed", exc_info=True)
def _bind_tools(self) -> None:
"""根据配置绑定工具"""
for tool_name in self._config.tools:
try:
tool = self._tool_registry.get(tool_name)
self.use_tool(tool)
logger.info(f"ConfigDrivenAgent '{self.name}' bound tool '{tool_name}'")
except Exception as e:
logger.warning(
f"ConfigDrivenAgent '{self.name}' failed to bind tool '{tool_name}': {e}"
)
def _auto_set_current_module(self) -> None:
    """Derive the evolution module from the agent configuration and set it.

    Builds a ``Module`` whose signature instruction is the non-empty
    identity/instructions/constraints prompt sections joined by newlines, and
    whose input/output fields come from the configured schemas (non-string
    schema entries are stringified). Placeholder fields are used when a schema
    is missing or empty, so prompt optimization always has a target.
    """
    from agentkit.evolution.prompt_optimizer import Module, Signature

    prompt = self._config.prompt or {}
    section_texts = (prompt.get(key, "") for key in ("identity", "instructions", "constraints"))
    instruction = "\n".join(text for text in section_texts if text)

    def _as_field_map(schema):
        # Schema entries may be arbitrary objects; the signature wants strings.
        if not schema:
            return {}
        return {
            field_name: field_info if isinstance(field_info, str) else str(field_info)
            for field_name, field_info in schema.items()
        }

    input_fields = _as_field_map(self._config.input_schema)
    output_fields = _as_field_map(self._config.output_schema)
    signature = Signature(
        input_fields=input_fields or {"input": "task input"},
        output_fields=output_fields or {"output": "task output"},
        instruction=instruction,
    )
    self.set_current_module(Module(name=self.name, signature=signature))
    logger.debug(f"Auto-set _current_module for agent '{self.name}'")
async def _register_mcp_tools(self) -> None:
"""Lazily register tools from MCP servers as agent tools.
Called on first task execution to allow async MCP client operations.
"""
if self._mcp_tools_registered or not self._mcp_servers:
return
self._mcp_tools_registered = True
from agentkit.mcp.client import MCPClient
for server_name, base_url in self._mcp_servers.items():
try:
client = MCPClient(server_url=base_url)
self._mcp_clients.append(client)
# List available tools from the MCP server
tools = await client.list_tools()
for tool_info in tools:
tool_name = tool_info.get("name", "")
tool_desc = tool_info.get("description", "")
if not tool_name:
continue
# Create MCPTool and register it
mcp_tool = client.as_tool(tool_name, tool_desc)
self.use_tool(mcp_tool)
logger.info(
f"Agent '{self.name}' registered MCP tool '{tool_name}' "
f"from server '{server_name}'"
)
except Exception as e:
logger.warning(
f"Agent '{self.name}' failed to connect to MCP server "
f"'{server_name}' at {base_url}: {e}"
)
def get_capabilities(self) -> AgentCapability:
    """Describe the agent as an ``AgentCapability``.

    Identity fields come from the agent itself; the task list, concurrency,
    description and schemas come from the configuration.
    """
    identity = {
        "agent_name": self.name,
        "agent_type": self.agent_type,
        "version": self.version,
    }
    cfg = self._config
    return AgentCapability(
        **identity,
        supported_tasks=cfg.supported_tasks,
        max_concurrency=cfg.max_concurrency,
        description=cfg.description,
        input_schema=cfg.input_schema,
        output_schema=cfg.output_schema,
    )
async def handle_task(self, task: TaskMessage) -> dict:
    """Execute a task according to ``execution_mode`` and ``task_mode``.

    With a SkillConfig, ``execution_mode`` is tried first:

    - react / rewoo / plan_exec / reflexion: engine-backed handlers, used only
      when a ReAct engine is available
    - direct: call the LLM directly (no ReAct loop)
    - custom: use the custom handler

    When there is no SkillConfig, or its mode did not select a handler,
    the classic ``task_mode`` dispatch is used.

    Raises:
        ConfigValidationError: If ``task_mode`` is not a known mode.
    """
    # Lazy-register MCP tools on first task execution
    await self._register_mcp_tools()

    skill = self._skill_config
    if skill:
        execution_mode = skill.execution_mode
        engine_handlers = {
            "react": "_handle_react",
            "rewoo": "_handle_rewoo",
            "plan_exec": "_handle_plan_exec",
            "reflexion": "_handle_reflexion",
        }
        plain_handlers = {"direct": "_handle_direct", "custom": "_handle_custom"}
        selected = None
        if execution_mode in engine_handlers:
            # Engine-backed modes fall through to task_mode without an engine.
            if self._react_engine:
                selected = engine_handlers[execution_mode]
        else:
            selected = plain_handlers.get(execution_mode)
        if selected is not None:
            return await getattr(self, selected)(task)

    # Fall back to the classic task_mode dispatch
    task_mode_handlers = {
        "llm_generate": "_handle_llm_generate",
        "tool_call": "_handle_tool_call",
        "custom": "_handle_custom",
    }
    task_mode = self._config.task_mode
    if task_mode not in task_mode_handlers:
        raise ConfigValidationError(
            agent_name=self.name,
            key="task_mode",
            reason=f"Unknown task_mode: {self._config.task_mode}",
        )
    return await getattr(self, task_mode_handlers[task_mode])(task)
# ── Streaming execution (U3) ────────────────────────────────────────
def _build_llm_messages(self, task: TaskMessage) -> tuple[str | None, list[dict[str, str]]]:
"""Build (system_prompt, user_messages) from task + prompt template.
Shared by all _handle_*_stream methods to avoid duplicating the
message-rendering logic that mirrors the sync _handle_* methods.
"""
variables = task.input_data.copy()
variables["task_type"] = task.task_type
if self._prompt_template:
rendered_messages = self._prompt_template.render(variables=variables)
else:
rendered_messages = [{"role": "user", "content": str(task.input_data)}]
system_prompt: str | None = None
user_messages: list[dict[str, str]] = []
for msg in rendered_messages:
if msg["role"] == "system":
system_prompt = msg["content"]
else:
user_messages.append(msg)
if not user_messages:
user_messages.append({"role": "user", "content": str(task.input_data)})
return system_prompt, user_messages
async def execute_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
    """Execute a task as a stream, yielding ``ReActEvent`` objects.

    Mirrors the execute() -> handle_task() dispatch but does not wrap the
    outcome in a TaskResult for the caller: events are yielded directly and
    the caller is responsible for forwarding and accumulating them.

    P2 fix: a CancellationToken is registered in ``_active_tokens`` so that
    cancel_task() can cooperatively cancel a streaming task (this path
    bypasses BaseAgent.execute(), which would otherwise register it).

    KTD-4: the evolution hooks are triggered from ``finally`` for lifecycle
    parity with execute(). They are scheduled fire-and-forget with a
    backpressure cap, and an evolution error must never block or fail the
    stream. Exceptions from sub-engines (PlanExec, Reflexion, ...) propagate to
    this ``finally``, so the hooks fire here only and sub-engines need not
    trigger them again.

    The TaskResult handed to evolution carries the time the stream started
    as ``started_at`` and the time it ended as ``completed_at``.

    Yields:
        Every event produced by ``handle_task_stream``; the last
        ``final_answer`` event provides the output recorded for evolution.

    Raises:
        asyncio.CancelledError: Re-raised after being recorded as CANCELLED.
        Exception: Any handler error, re-raised after being recorded as FAILED.
    """
    token = CancellationToken()
    self._active_tokens[task.task_id] = token
    # Captured on entry so the evolution record reflects the real duration.
    started_at = datetime.now(timezone.utc)
    _stream_output: dict = {}
    _stream_error: BaseException | None = None
    _stream_completed = False
    try:
        await self._register_mcp_tools()
        async for event in self.handle_task_stream(task):
            if event.event_type == "final_answer":
                _raw = event.data.get("output", "")
                # Plain-text answers are wrapped; structured output is kept as is.
                _stream_output = {"content": _raw} if isinstance(_raw, str) else _raw
            yield event
        _stream_completed = True
    except asyncio.CancelledError as ce:
        # Cancellation must propagate, but hooks still fire (U2 edge case).
        _stream_error = ce
        raise
    except Exception as e:
        _stream_error = e
        raise
    finally:
        # An async generator's finally runs whenever the generator is closed
        # (GC / aclose / normal completion).
        self._active_tokens.pop(task.task_id, None)
        # KTD-4: lifecycle parity — fire evolution hooks fire-and-forget.
        try:
            output_data = None
            if _stream_error is not None:
                if isinstance(_stream_error, (asyncio.CancelledError, TaskCancelledError)):
                    status = TaskStatus.CANCELLED
                    err_msg = f"stream cancelled: {_stream_error}"
                else:
                    status = TaskStatus.FAILED
                    err_msg = str(_stream_error)
            elif _stream_completed:
                status = TaskStatus.COMPLETED
                output_data = _stream_output
                err_msg = None
            else:
                # Stream closed before completion (consumer aclose / GC).
                status = TaskStatus.CANCELLED
                err_msg = "stream closed before completion"
            result = TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=status,
                output_data=output_data,
                error_message=err_msg,
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
            )
            self._trigger_evolution_hooks(task, result)
        except Exception:
            logger.debug("evolution hook scheduling failed", exc_info=True)
async def handle_task_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
    """Dispatch a streaming task by execution_mode / task_mode, mirroring handle_task().

    Engine-backed skill modes stream natively when their prerequisite (ReAct
    engine or LLM gateway) is available. Other modes run the non-streaming
    handler and emit its result as a single ``final_answer`` event. When
    nothing matches, handle_task() is awaited and its result is emitted the
    same way.
    """
    stream = None
    skill = self._skill_config
    if skill:
        execution_mode = skill.execution_mode
        if execution_mode == "react" and self._react_engine:
            stream = self._handle_react_stream(task)
        elif execution_mode == "rewoo" and self._llm_gateway:
            stream = self._handle_rewoo_stream(task)
        elif execution_mode == "plan_exec" and self._llm_gateway:
            stream = self._handle_plan_exec_stream(task)
        elif execution_mode == "reflexion" and self._llm_gateway:
            stream = self._handle_reflexion_stream(task)
        elif execution_mode == "direct":
            stream = self._wrap_sync_as_stream(self._handle_direct, task)
        elif execution_mode == "custom":
            stream = self._wrap_sync_as_stream(self._handle_custom, task)

    if stream is None:
        # Fall back to the classic task_mode dispatch
        task_mode = self._config.task_mode
        if task_mode == "llm_generate":
            stream = self._wrap_sync_as_stream(self._handle_llm_generate, task)
        elif task_mode == "tool_call":
            stream = self._wrap_sync_as_stream(self._handle_tool_call, task)
        elif task_mode == "custom":
            stream = self._wrap_sync_as_stream(self._handle_custom, task)

    if stream is not None:
        async for event in stream:
            yield event
        return

    # Unknown mode — wrap the non-streaming result as a single final_answer event
    result = await self.handle_task(task)
    fallback_text = str(result)
    yield ReActEvent(
        event_type="final_answer",
        step=0,
        data={"output": result.get("content", fallback_text)},
    )
async def _handle_react_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
"""ReAct mode streaming: delegate to ReActEngine.execute_stream()."""
if self._evolution_enabled and self._current_module is None:
self._auto_set_current_module()
system_prompt, user_messages = self._build_llm_messages(task)
retrieval_config = self._config.memory.get("retrieval", {}) if self._config.memory else {}
async for event in self._react_engine.execute_stream( # type: ignore[union-attr]
messages=user_messages,
tools=self.get_tools() or None,
model=self._config.llm.get("model", "default") if self._config.llm else "default",
agent_name=self.name,
task_type=task.task_type,
system_prompt=system_prompt,
memory_retriever=self._memory_retriever,
task_id=task.task_id,
retrieval_config=retrieval_config or None,
cancellation_token=self._active_tokens.get(task.task_id),
timeout_seconds=float(task.timeout_seconds) if task.timeout_seconds > 0 else None,
compressor=self._compressor,
):
yield event
async def _handle_rewoo_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
    """ReWOO mode, streaming: forward the events of ReWOOEngine.execute_stream().

    A fresh engine is created per call. The plan-step limit and fallback
    strategies come from the SkillConfig when present (5 steps and no
    fallback otherwise). A non-positive task timeout means no timeout.
    """
    from agentkit.core.rewoo import ReWOOEngine

    system_prompt, user_messages = self._build_llm_messages(task)
    skill = self._skill_config
    max_plan_steps = skill.max_steps if skill else 5
    fallback_strategies = skill.fallback_strategies if skill and skill.fallback_strategies else None
    rewoo_engine = ReWOOEngine(
        llm_gateway=self._llm_gateway,
        max_plan_steps=max_plan_steps,
        default_timeout=300.0,
        fallback_strategies=fallback_strategies,
    )
    llm_cfg = self._config.llm
    engine_kwargs = {
        "messages": user_messages,
        "tools": self.get_tools() or None,
        "model": llm_cfg.get("model", "default") if llm_cfg else "default",
        "agent_name": self.name,
        "task_type": task.task_type,
        "system_prompt": system_prompt,
        "task_id": task.task_id,
        "cancellation_token": self._active_tokens.get(task.task_id),
        "timeout_seconds": float(task.timeout_seconds) if task.timeout_seconds > 0 else None,
    }
    async for event in rewoo_engine.execute_stream(**engine_kwargs):
        yield event
async def _handle_plan_exec_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
    """Plan-Exec mode, streaming: forward the events of PlanExecEngine.execute_stream().

    A fresh engine is created per call with at most 2 replans and a 300 second
    default timeout. A non-positive task timeout means no timeout.
    """
    from agentkit.core.plan_exec_engine import PlanExecEngine

    system_prompt, user_messages = self._build_llm_messages(task)
    plan_exec_engine = PlanExecEngine(
        llm_gateway=self._llm_gateway,
        max_replans=2,
        default_timeout=300.0,
    )
    llm_cfg = self._config.llm
    engine_kwargs = {
        "messages": user_messages,
        "tools": self.get_tools() or None,
        "model": llm_cfg.get("model", "default") if llm_cfg else "default",
        "agent_name": self.name,
        "task_type": task.task_type,
        "system_prompt": system_prompt,
        "task_id": task.task_id,
        "cancellation_token": self._active_tokens.get(task.task_id),
        "timeout_seconds": float(task.timeout_seconds) if task.timeout_seconds > 0 else None,
    }
    async for event in plan_exec_engine.execute_stream(**engine_kwargs):
        yield event
async def _handle_reflexion_stream(self, task: TaskMessage) -> AsyncGenerator[ReActEvent, None]:
    """Reflexion mode, streaming: forward the events of ReflexionEngine.execute_stream().

    A fresh engine is created per call with at most 2 reflections, a quality
    threshold of 0.7 and a 300 second default timeout. The step limit comes
    from the SkillConfig when present (5 otherwise). A non-positive task
    timeout means no timeout.
    """
    from agentkit.core.reflexion import ReflexionEngine

    system_prompt, user_messages = self._build_llm_messages(task)
    skill = self._skill_config
    reflexion_engine = ReflexionEngine(
        llm_gateway=self._llm_gateway,
        max_steps=skill.max_steps if skill else 5,
        max_reflections=2,
        quality_threshold=0.7,
        default_timeout=300.0,
    )
    llm_cfg = self._config.llm
    engine_kwargs = {
        "messages": user_messages,
        "tools": self.get_tools() or None,
        "model": llm_cfg.get("model", "default") if llm_cfg else "default",
        "agent_name": self.name,
        "task_type": task.task_type,
        "system_prompt": system_prompt,
        "task_id": task.task_id,
        "cancellation_token": self._active_tokens.get(task.task_id),
        "timeout_seconds": float(task.timeout_seconds) if task.timeout_seconds > 0 else None,
    }
    async for event in reflexion_engine.execute_stream(**engine_kwargs):
        yield event
async def _wrap_sync_as_stream(
    self,
    handler: Callable[[TaskMessage], Awaitable[dict]],
    task: TaskMessage,
) -> AsyncGenerator[ReActEvent, None]:
    """Run a non-streaming handler and emit its result as one ``final_answer`` event.

    The event output is the result's ``content`` entry, or the stringified
    result when there is no such entry.
    """
    outcome = await handler(task)
    stringified = str(outcome)
    answer = outcome.get("content", stringified)
    yield ReActEvent(event_type="final_answer", step=0, data={"output": answer})
async def _handle_react(self, task: TaskMessage) -> dict:
"""ReAct mode: use ReAct engine for autonomous reasoning"""
# Auto-set _current_module from SkillConfig if evolution is enabled
if self._evolution_enabled and self._current_module is None:
self._auto_set_current_module()
# Build variables for prompt rendering
variables = task.input_data.copy()
variables["task_type"] = task.task_type
# Use PromptTemplate.render() to get full messages (system + user)
if self._prompt_template:
rendered_messages = self._prompt_template.render(variables=variables)
else:
rendered_messages = [{"role": "user", "content": str(task.input_data)}]
# Separate system_prompt from user messages
# PromptTemplate.render() returns [system_msg, user_msg] or [user_msg]
system_prompt = None
user_messages = []
for msg in rendered_messages:
if msg["role"] == "system":
system_prompt = msg["content"]
else:
user_messages.append(msg)
# If no user messages, add a default one
if not user_messages:
user_messages.append({"role": "user", "content": str(task.input_data)})
# Get CancellationToken for this task (set by BaseAgent.execute)
cancellation_token = self._active_tokens.get(task.task_id)
# Determine timeout from task or config
timeout_seconds = float(task.timeout_seconds) if task.timeout_seconds > 0 else None
# Execute ReAct loop
retrieval_config = self._config.memory.get("retrieval", {}) if self._config.memory else {}
result = await self._react_engine.execute(
messages=user_messages,
tools=self.get_tools() or None,
model=self._config.llm.get("model", "default") if self._config.llm else "default",
agent_name=self.name,
task_type=task.task_type,
system_prompt=system_prompt,
memory_retriever=self._memory_retriever,
task_id=task.task_id,
retrieval_config=retrieval_config or None,
cancellation_token=cancellation_token,
timeout_seconds=timeout_seconds,
compressor=self._compressor,
)
# Parse result
return self._parse_llm_response(result.output)
async def _handle_rewoo(self, task: TaskMessage) -> dict:
    """ReWOO mode: plan all tool calls upfront, then execute them in batch.

    A new ``ReWOOEngine`` is built for every call. Its step budget and
    fallback strategies come from the skill config when one is present
    (otherwise 5 steps and no fallback strategies).

    Args:
        task: The task to execute.

    Returns:
        The engine's ``output`` as parsed by ``_parse_llm_response``.
    """
    from agentkit.core.rewoo import ReWOOEngine

    template_vars = task.input_data.copy()
    template_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=template_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # Separate system content from the conversation; the last system
    # message is the one that is kept.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
            continue
        system_prompt = message["content"]

    if not conversation:
        conversation = [{"role": "user", "content": str(task.input_data)}]

    token = self._active_tokens.get(task.task_id)

    # A non-positive task timeout means "no explicit timeout".
    if task.timeout_seconds > 0:
        timeout = float(task.timeout_seconds)
    else:
        timeout = None

    skill = self._skill_config
    step_budget = skill.max_steps if skill else 5
    # An empty/falsy strategy collection is passed on as None.
    fallbacks = (skill.fallback_strategies or None) if skill else None

    engine = ReWOOEngine(
        llm_gateway=self._llm_gateway,
        max_plan_steps=step_budget,
        default_timeout=300.0,
        fallback_strategies=fallbacks,
    )

    if self._config.llm:
        model_name = self._config.llm.get("model", "default")
    else:
        model_name = "default"

    outcome = await engine.execute(
        messages=conversation,
        tools=self.get_tools() or None,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=token,
        timeout_seconds=timeout,
    )
    return self._parse_llm_response(outcome.output)
async def _handle_plan_exec(self, task: TaskMessage) -> dict:
    """Plan-and-Execute mode: run the task through a ``PlanExecEngine``.

    A new engine is built for every call with ``max_replans=2`` and a
    300-second default timeout.

    Args:
        task: The task to execute.

    Returns:
        The engine's ``output`` as parsed by ``_parse_llm_response``.
    """
    from agentkit.core.plan_exec_engine import PlanExecEngine

    template_vars = task.input_data.copy()
    template_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=template_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # Separate system content from the conversation; the last system
    # message is the one that is kept.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
            continue
        system_prompt = message["content"]

    if not conversation:
        conversation = [{"role": "user", "content": str(task.input_data)}]

    token = self._active_tokens.get(task.task_id)

    # A non-positive task timeout means "no explicit timeout".
    if task.timeout_seconds > 0:
        timeout = float(task.timeout_seconds)
    else:
        timeout = None

    engine = PlanExecEngine(
        llm_gateway=self._llm_gateway,
        max_replans=2,
        default_timeout=300.0,
    )

    if self._config.llm:
        model_name = self._config.llm.get("model", "default")
    else:
        model_name = "default"

    outcome = await engine.execute(
        messages=conversation,
        tools=self.get_tools() or None,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=token,
        timeout_seconds=timeout,
    )
    return self._parse_llm_response(outcome.output)
async def _handle_reflexion(self, task: TaskMessage) -> dict:
    """Reflexion mode: run the task through a ``ReflexionEngine``.

    A new engine is built for every call with ``max_reflections=2``,
    ``quality_threshold=0.7`` and a 300-second default timeout. The step
    budget comes from the skill config when one is present, otherwise 5.

    Args:
        task: The task to execute.

    Returns:
        The engine's ``output`` as parsed by ``_parse_llm_response``.
    """
    from agentkit.core.reflexion import ReflexionEngine

    template_vars = task.input_data.copy()
    template_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=template_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # Separate system content from the conversation; the last system
    # message is the one that is kept.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
            continue
        system_prompt = message["content"]

    if not conversation:
        conversation = [{"role": "user", "content": str(task.input_data)}]

    token = self._active_tokens.get(task.task_id)

    # A non-positive task timeout means "no explicit timeout".
    if task.timeout_seconds > 0:
        timeout = float(task.timeout_seconds)
    else:
        timeout = None

    step_budget = self._skill_config.max_steps if self._skill_config else 5
    engine = ReflexionEngine(
        llm_gateway=self._llm_gateway,
        max_steps=step_budget,
        max_reflections=2,
        quality_threshold=0.7,
        default_timeout=300.0,
    )

    if self._config.llm:
        model_name = self._config.llm.get("model", "default")
    else:
        model_name = "default"

    outcome = await engine.execute(
        messages=conversation,
        tools=self.get_tools() or None,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=token,
        timeout_seconds=timeout,
    )
    return self._parse_llm_response(outcome.output)
async def _handle_direct(self, task: TaskMessage) -> dict:
    """Direct mode: one LLM call through the gateway, no reasoning loop.

    The full prompt template is rendered and sent in a single
    ``LLMGateway.chat`` call. When no gateway is available the task is
    delegated to ``_handle_llm_generate`` instead.

    Args:
        task: The task to execute.

    Returns:
        The response content as parsed by ``_parse_llm_response`` (or the
        result of ``_handle_llm_generate`` on the fallback path).
    """
    gateway = self._llm_gateway
    if not gateway:
        return await self._handle_llm_generate(task)

    # Template variables: the task input plus the task type.
    template_vars = task.input_data.copy()
    template_vars["task_type"] = task.task_type

    if self._prompt_template:
        prompt_messages = self._prompt_template.render(variables=template_vars)
    else:
        prompt_messages = [{"role": "user", "content": str(task.input_data)}]

    if self._config.llm:
        model_name = self._config.llm.get("model", "default")
    else:
        model_name = "default"

    reply = await gateway.chat(
        messages=prompt_messages,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
    )
    return self._parse_llm_response(reply.content)
async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
    """Re-run a task with quality feedback added to its input.

    A new ``TaskMessage`` is built that mirrors ``task`` field by field,
    except that its ``input_data`` is a copy carrying the feedback under the
    ``quality_feedback`` key. The original task's input is not modified.

    Args:
        task: The task to re-execute.
        feedback: Feedback text to expose to the next attempt.

    Returns:
        Whatever ``handle_task`` returns for the rebuilt task.
    """
    input_with_feedback = task.input_data.copy()
    input_with_feedback["quality_feedback"] = feedback

    retry_task = TaskMessage(
        task_id=task.task_id,
        agent_name=task.agent_name,
        task_type=task.task_type,
        input_data=input_with_feedback,
        priority=task.priority,
        created_at=task.created_at,
        callback_url=task.callback_url,
        timeout_seconds=task.timeout_seconds,
        conversation_id=task.conversation_id,
    )
    result = await self.handle_task(retry_task)
    return result
def _wrap_llm_client(self, llm_client: object):
    """Wrap a legacy ``llm_client`` in an ``LLMGateway``.

    The client is adapted to the ``LLMProvider`` interface by a local
    ``ClientProvider`` class and registered on a fresh gateway under the
    provider name ``"wrapped"``.

    Args:
        llm_client: Legacy client. It must expose an awaitable ``chat`` or
            ``create`` method, or be an awaitable callable itself.

    Returns:
        A new ``LLMGateway`` with the wrapped client registered.
    """
    from agentkit.llm.gateway import LLMGateway
    from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage

    class ClientProvider(LLMProvider):
        """Adapter: exposes a legacy llm_client as an ``LLMProvider``."""

        def __init__(self, raw_client: object):
            self._raw_client = raw_client

        async def chat(self, request: LLMRequest) -> LLMResponse:
            """Forward ``request`` to the legacy client and normalize the reply.

            Raises:
                ConfigValidationError: If the client has neither a ``chat``
                    nor a ``create`` attribute and is not callable.
            """
            # Start from any extra request options, then apply the
            # standard sampling parameters on top.
            kwargs = dict(request._extra) if hasattr(request, "_extra") else {}
            kwargs["model"] = request.model
            kwargs["temperature"] = request.temperature
            kwargs["max_tokens"] = request.max_tokens

            if hasattr(self._raw_client, "chat"):
                response = await self._raw_client.chat(messages=request.messages, **kwargs)
            elif hasattr(self._raw_client, "create"):
                response = await self._raw_client.create(messages=request.messages, **kwargs)
            elif callable(self._raw_client):
                response = await self._raw_client(messages=request.messages, **kwargs)
            else:
                raise ConfigValidationError(
                    agent_name="",
                    key="llm_client",
                    reason="LLM client must have 'chat'/'create' method or be callable",
                )

            # Normalize the response to its content.
            if isinstance(response, str):
                content = response
            elif isinstance(response, dict):
                if "content" in response:
                    content = response["content"]
                else:
                    # Serialize only when there is no "content" key: doing
                    # it eagerly as a .get() default would raise on dicts
                    # holding non-JSON-serializable values even though
                    # their content is perfectly usable.
                    try:
                        content = json.dumps(response)
                    except (TypeError, ValueError):
                        content = str(response)
            elif hasattr(response, "content"):
                content = response.content
            else:
                content = str(response)

            # The legacy client reports no token counts, so usage is zeroed.
            return LLMResponse(
                content=content,
                model=request.model,
                usage=TokenUsage(prompt_tokens=0, completion_tokens=0),
            )

    gateway = LLMGateway()
    gateway.register_provider("wrapped", ClientProvider(llm_client))
    return gateway
async def _handle_llm_generate(self, task: TaskMessage) -> dict:
    """LLM generation mode: render the prompt, call the LLM, parse the output.

    Args:
        task: The task to execute; its ``input_data`` plus ``task_type``
            are the template variables.

    Returns:
        The parsed LLM response. When no LLM client is configured, a
        degraded result containing the rendered messages is returned
        instead of calling a model.

    Raises:
        ConfigValidationError: If no prompt template is configured.
    """
    if not self._prompt_template:
        raise ConfigValidationError(
            agent_name=self.name,
            key="prompt",
            reason="llm_generate mode requires prompt configuration",
        )

    # Render the prompt.
    variables = task.input_data.copy()
    variables["task_type"] = task.task_type
    messages = self._prompt_template.render(variables=variables)

    # Degraded mode: without an LLM client, hand back the rendered prompt.
    if self._llm_client is None:
        return {
            "mode": "llm_generate_no_client",
            "messages": messages,
            "note": "No LLM client configured, returning rendered prompt",
        }

    # Use the configured LLM parameters. The llm section may be missing or
    # empty (the other handlers guard for this too), in which case
    # _call_llm falls back to its own defaults.
    llm_params = dict(self._config.llm) if self._config.llm else {}
    response = await self._call_llm(messages, **llm_params)
    return self._parse_llm_response(response)
async def _handle_tool_call(self, task: TaskMessage) -> dict:
    """Tool-call mode: run the tool resolved for the task and return its result.

    The tool is chosen by ``_resolve_tool`` and invoked through
    ``safe_execute`` with the task's ``input_data`` expanded as keyword
    arguments.

    Args:
        task: The task to execute.

    Returns:
        The value returned by the tool's ``safe_execute``.

    Raises:
        ConfigValidationError: If no tool is bound to the agent.
    """
    if self._tools:
        # Pick the tool matching the task type (or the first bound one).
        selected = self._resolve_tool(task)
        return await selected.safe_execute(**task.input_data)

    raise ConfigValidationError(
        agent_name=self.name,
        key="tools",
        reason="tool_call mode requires at least one bound tool",
    )
async def _handle_custom(self, task: TaskMessage) -> dict:
    """Custom mode: invoke the configured handler function.

    The handler named by ``custom_handler`` in the config is looked up in
    the handler registry; if it is not registered yet it is imported
    dynamically and cached there for later calls.

    Args:
        task: The task passed to the handler.

    Returns:
        The dict returned by the handler.

    Raises:
        ConfigValidationError: If the handler returns something other than
            a dict.
    """
    handler_key = self._config.custom_handler
    registry = self._custom_handlers

    # Import on first use and remember the handler.
    if handler_key not in registry:
        registry[handler_key] = self._import_handler(handler_key)

    outcome = await registry[handler_key](task)
    if isinstance(outcome, dict):
        return outcome

    raise ConfigValidationError(
        agent_name=self.name,
        key="custom_handler",
        reason=f"Custom handler '{handler_key}' must return dict, got {type(outcome)}",
    )
def _resolve_tool(self, task: TaskMessage) -> Tool:
    """Pick the tool to use for ``task``.

    The first bound tool whose name contains the task type (either as-is or
    with underscores replaced by hyphens) wins; when none matches, the first
    bound tool is used.

    Args:
        task: The task whose ``task_type`` drives the lookup.

    Returns:
        The selected tool.
    """
    matches = (
        candidate
        for candidate in self._tools
        if task.task_type in candidate.name
        or task.task_type.replace("_", "-") in candidate.name
    )
    chosen = next(matches, None)
    if chosen is not None:
        return chosen
    # Nothing matched by name: fall back to the first tool.
    return self._tools[0]
async def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str:
"""调用 LLM 客户端"""
model = kwargs.pop("model", "default")
temperature = kwargs.pop("temperature", 0.7)
max_tokens = kwargs.pop("max_tokens", 2000)
# 标准化 LLM 客户端调用接口
if hasattr(self._llm_client, "chat"):
response = await self._llm_client.chat(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
elif hasattr(self._llm_client, "create"):
response = await self._llm_client.create(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
elif callable(self._llm_client):
response = await self._llm_client(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
else:
raise ConfigValidationError(
agent_name=self.name,
key="llm_client",
reason="LLM client must have 'chat'/'create' method or be callable",
)
# 提取文本内容
if isinstance(response, str):
return response
if isinstance(response, dict):
return response.get("content", str(response))
if hasattr(response, "content"):
return response.content
return str(response)
def _parse_llm_response(self, response: str) -> dict:
"""解析 LLM 响应为 dict"""
# 尝试直接解析 JSON
try:
return json.loads(response)
except (json.JSONDecodeError, TypeError):
pass
# 尝试提取 JSON 块
import re
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group(1))
except (json.JSONDecodeError, TypeError):
pass
# 降级:包装为文本结果
return {"text": response}
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
"""动态导入自定义 handler"""
# Security: validate module prefix to prevent arbitrary code execution
if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES):
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}",
)
try:
module_path, func_name = dotted_path.rsplit(".", 1)
import importlib
module = importlib.import_module(module_path)
handler = getattr(module, func_name)
if not callable(handler):
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"'{dotted_path}' is not callable",
)
return handler
except (ImportError, AttributeError, ValueError) as e:
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"Failed to import custom handler '{dotted_path}': {e}",
)