fix(review): document-processing code review fixes — validation, tests, formatting
Deploy to Production / deploy (push) Waiting to run
Details
Deploy to Production / deploy (push) Waiting to run
Details
- SkillConfig._validate_v2: validate fallback_strategies against ReWOOEngine.VALID_STRATEGIES (lazy import, #20) - test_skill_config: +4 tests for fallback_strategies validation - test_document_loader: +8 xlsx edge case tests (empty workbook, malformed bytes, column mismatch, row/cell truncation, multi-sheet, file size limit, None cells, #16) - test_execution_modes: fix ReWOOEngine patch path (lazy import -> patch at source) + FakeReWOOEngine.execute return .output attribute - config_driven: ruff formatting (quotes, blank lines after imports) - project_rules: remove stale "known failing test" note (now passes)
This commit is contained in:
parent
b9bb1b7cf1
commit
3337589395
|
|
@ -36,4 +36,3 @@ This applies to ALL async generator functions in the codebase. When adding an ea
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
- Run `python3 -m pytest tests/unit/ -x -q` before committing
|
- Run `python3 -m pytest tests/unit/ -x -q` before committing
|
||||||
- Known failing test (unrelated): `test_rewoo_agent_yaml_loads` — skip if needed
|
|
||||||
|
|
|
||||||
|
|
@ -245,21 +245,22 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
|
|
||||||
self._react_engine = ReActEngine(
|
self._react_engine = ReActEngine(
|
||||||
llm_gateway=self._llm_gateway,
|
llm_gateway=self._llm_gateway,
|
||||||
max_steps=getattr(config, 'max_steps', 5),
|
max_steps=getattr(config, "max_steps", 5),
|
||||||
)
|
)
|
||||||
|
|
||||||
# v2: Initialize Quality Gate (always available)
|
# v2: Initialize Quality Gate (always available)
|
||||||
from agentkit.quality.gate import QualityGate
|
from agentkit.quality.gate import QualityGate
|
||||||
|
|
||||||
self._quality_gate = QualityGate()
|
self._quality_gate = QualityGate()
|
||||||
|
|
||||||
# v2: Initialize Evolution if configured
|
# v2: Initialize Evolution if configured
|
||||||
evolution_config = getattr(config, 'evolution', None)
|
evolution_config = getattr(config, "evolution", None)
|
||||||
if evolution_config is not None:
|
if evolution_config is not None:
|
||||||
# Support both dict and EvolutionConfig
|
# Support both dict and EvolutionConfig
|
||||||
if isinstance(evolution_config, dict):
|
if isinstance(evolution_config, dict):
|
||||||
is_enabled = evolution_config.get("enabled", False)
|
is_enabled = evolution_config.get("enabled", False)
|
||||||
else:
|
else:
|
||||||
is_enabled = getattr(evolution_config, 'enabled', False)
|
is_enabled = getattr(evolution_config, "enabled", False)
|
||||||
else:
|
else:
|
||||||
is_enabled = False
|
is_enabled = False
|
||||||
|
|
||||||
|
|
@ -276,6 +277,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
|
|
||||||
# v2: Initialize Output Standardizer
|
# v2: Initialize Output Standardizer
|
||||||
from agentkit.quality.output import OutputStandardizer
|
from agentkit.quality.output import OutputStandardizer
|
||||||
|
|
||||||
self._output_standardizer = OutputStandardizer()
|
self._output_standardizer = OutputStandardizer()
|
||||||
|
|
||||||
# v2: Store compressor for ReAct engine
|
# v2: Store compressor for ReAct engine
|
||||||
|
|
@ -327,6 +329,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
|
|
||||||
if config.memory.get("working", {}).get("enabled"):
|
if config.memory.get("working", {}).get("enabled"):
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379")
|
redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||||
working = WorkingMemory(redis=redis_client)
|
working = WorkingMemory(redis=redis_client)
|
||||||
|
|
@ -350,7 +353,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
)
|
)
|
||||||
episodic = EpisodicMemory(
|
episodic = EpisodicMemory(
|
||||||
session_factory=None, # Set externally when DB session is available
|
session_factory=None, # Set externally when DB session is available
|
||||||
episodic_model=None, # Set externally when ORM model is available
|
episodic_model=None, # Set externally when ORM model is available
|
||||||
embedder=embedder,
|
embedder=embedder,
|
||||||
decay_rate=epi_conf.get("decay_rate", 0.01),
|
decay_rate=epi_conf.get("decay_rate", 0.01),
|
||||||
alpha=epi_conf.get("alpha", 0.7),
|
alpha=epi_conf.get("alpha", 0.7),
|
||||||
|
|
@ -471,6 +474,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
try:
|
try:
|
||||||
from agentkit.core.protocol import TaskResult, TaskStatus
|
from agentkit.core.protocol import TaskResult, TaskStatus
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
result = TaskResult(
|
result = TaskResult(
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
agent_name=self.name,
|
agent_name=self.name,
|
||||||
|
|
@ -490,6 +494,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
try:
|
try:
|
||||||
from agentkit.core.protocol import TaskResult, TaskStatus
|
from agentkit.core.protocol import TaskResult, TaskStatus
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
result = TaskResult(
|
result = TaskResult(
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
agent_name=self.name,
|
agent_name=self.name,
|
||||||
|
|
@ -534,12 +539,16 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
input_fields = {}
|
input_fields = {}
|
||||||
if self._config.input_schema:
|
if self._config.input_schema:
|
||||||
for field_name, field_info in self._config.input_schema.items():
|
for field_name, field_info in self._config.input_schema.items():
|
||||||
input_fields[field_name] = str(field_info) if not isinstance(field_info, str) else field_info
|
input_fields[field_name] = (
|
||||||
|
str(field_info) if not isinstance(field_info, str) else field_info
|
||||||
|
)
|
||||||
|
|
||||||
output_fields = {}
|
output_fields = {}
|
||||||
if self._config.output_schema:
|
if self._config.output_schema:
|
||||||
for field_name, field_info in self._config.output_schema.items():
|
for field_name, field_info in self._config.output_schema.items():
|
||||||
output_fields[field_name] = str(field_info) if not isinstance(field_info, str) else field_info
|
output_fields[field_name] = (
|
||||||
|
str(field_info) if not isinstance(field_info, str) else field_info
|
||||||
|
)
|
||||||
|
|
||||||
module = Module(
|
module = Module(
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
|
@ -731,6 +740,11 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
llm_gateway=self._llm_gateway,
|
llm_gateway=self._llm_gateway,
|
||||||
max_plan_steps=self._skill_config.max_steps if self._skill_config else 5,
|
max_plan_steps=self._skill_config.max_steps if self._skill_config else 5,
|
||||||
default_timeout=300.0,
|
default_timeout=300.0,
|
||||||
|
fallback_strategies=(
|
||||||
|
self._skill_config.fallback_strategies
|
||||||
|
if self._skill_config and self._skill_config.fallback_strategies
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await rewoo_engine.execute(
|
result = await rewoo_engine.execute(
|
||||||
|
|
@ -901,23 +915,17 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
self._raw_client = raw_client
|
self._raw_client = raw_client
|
||||||
|
|
||||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
kwargs = dict(request._extra) if hasattr(request, '_extra') else {}
|
kwargs = dict(request._extra) if hasattr(request, "_extra") else {}
|
||||||
kwargs["model"] = request.model
|
kwargs["model"] = request.model
|
||||||
kwargs["temperature"] = request.temperature
|
kwargs["temperature"] = request.temperature
|
||||||
kwargs["max_tokens"] = request.max_tokens
|
kwargs["max_tokens"] = request.max_tokens
|
||||||
|
|
||||||
if hasattr(self._raw_client, "chat"):
|
if hasattr(self._raw_client, "chat"):
|
||||||
response = await self._raw_client.chat(
|
response = await self._raw_client.chat(messages=request.messages, **kwargs)
|
||||||
messages=request.messages, **kwargs
|
|
||||||
)
|
|
||||||
elif hasattr(self._raw_client, "create"):
|
elif hasattr(self._raw_client, "create"):
|
||||||
response = await self._raw_client.create(
|
response = await self._raw_client.create(messages=request.messages, **kwargs)
|
||||||
messages=request.messages, **kwargs
|
|
||||||
)
|
|
||||||
elif callable(self._raw_client):
|
elif callable(self._raw_client):
|
||||||
response = await self._raw_client(
|
response = await self._raw_client(messages=request.messages, **kwargs)
|
||||||
messages=request.messages, **kwargs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ConfigValidationError(
|
raise ConfigValidationError(
|
||||||
agent_name="",
|
agent_name="",
|
||||||
|
|
@ -1072,6 +1080,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
|
|
||||||
# 尝试提取 JSON 块
|
# 尝试提取 JSON 块
|
||||||
import re
|
import re
|
||||||
|
|
||||||
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
|
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
try:
|
try:
|
||||||
|
|
@ -1095,6 +1104,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
try:
|
try:
|
||||||
module_path, func_name = dotted_path.rsplit(".", 1)
|
module_path, func_name = dotted_path.rsplit(".", 1)
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
handler = getattr(module, func_name)
|
handler = getattr(module, func_name)
|
||||||
if not callable(handler):
|
if not callable(handler):
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,8 @@ class SkillConfig(AgentConfig):
|
||||||
capabilities: list[str | dict[str, Any] | CapabilityTag] | None = None,
|
capabilities: list[str | dict[str, Any] | CapabilityTag] | None = None,
|
||||||
# v5 新增字段:对齐守卫
|
# v5 新增字段:对齐守卫
|
||||||
alignment: dict[str, Any] | None = None,
|
alignment: dict[str, Any] | None = None,
|
||||||
|
# v6 新增字段:ReWOO fallback 策略(YAML 可配置)
|
||||||
|
fallback_strategies: list[str] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name=name,
|
name=name,
|
||||||
|
|
@ -116,7 +118,10 @@ class SkillConfig(AgentConfig):
|
||||||
self.capabilities = self._parse_capabilities(capabilities or [])
|
self.capabilities = self._parse_capabilities(capabilities or [])
|
||||||
# v5: 对齐守卫配置
|
# v5: 对齐守卫配置
|
||||||
from agentkit.quality.alignment import AlignmentConfig
|
from agentkit.quality.alignment import AlignmentConfig
|
||||||
|
|
||||||
self.alignment = AlignmentConfig(**(alignment or {}))
|
self.alignment = AlignmentConfig(**(alignment or {}))
|
||||||
|
# v6: ReWOO fallback 策略(None 时 ReWOOEngine 用默认值)
|
||||||
|
self.fallback_strategies = fallback_strategies
|
||||||
self._validate_v2()
|
self._validate_v2()
|
||||||
|
|
||||||
def _validate_v2(self) -> None:
|
def _validate_v2(self) -> None:
|
||||||
|
|
@ -130,6 +135,22 @@ class SkillConfig(AgentConfig):
|
||||||
f"must be one of {self.VALID_EXECUTION_MODES}"
|
f"must be one of {self.VALID_EXECUTION_MODES}"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# v6: validate fallback_strategies against ReWOOEngine.VALID_STRATEGIES (#20).
|
||||||
|
# Lazy import to avoid circular dependency (skills.base -> core.rewoo).
|
||||||
|
if self.fallback_strategies is not None:
|
||||||
|
from agentkit.core.rewoo import ReWOOEngine
|
||||||
|
|
||||||
|
valid = ReWOOEngine.VALID_STRATEGIES
|
||||||
|
invalid = [s for s in self.fallback_strategies if s not in valid]
|
||||||
|
if invalid:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="fallback_strategies",
|
||||||
|
reason=(
|
||||||
|
f"Invalid fallback_strategies {invalid}, "
|
||||||
|
f"must be subset of {valid}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_dependencies(
|
def _parse_dependencies(
|
||||||
|
|
@ -191,6 +212,7 @@ class SkillConfig(AgentConfig):
|
||||||
dependencies=data.get("dependencies"),
|
dependencies=data.get("dependencies"),
|
||||||
capabilities=data.get("capabilities"),
|
capabilities=data.get("capabilities"),
|
||||||
alignment=data.get("alignment"),
|
alignment=data.get("alignment"),
|
||||||
|
fallback_strategies=data.get("fallback_strategies"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -249,8 +271,7 @@ class SkillConfig(AgentConfig):
|
||||||
for dep in self.dependencies
|
for dep in self.dependencies
|
||||||
]
|
]
|
||||||
d["capabilities"] = [
|
d["capabilities"] = [
|
||||||
{"tag": cap.tag, "description": cap.description}
|
{"tag": cap.tag, "description": cap.description} for cap in self.capabilities
|
||||||
for cap in self.capabilities
|
|
||||||
]
|
]
|
||||||
# v5: 对齐守卫
|
# v5: 对齐守卫
|
||||||
d["alignment"] = {
|
d["alignment"] = {
|
||||||
|
|
@ -260,6 +281,8 @@ class SkillConfig(AgentConfig):
|
||||||
"audit_enabled": self.alignment.audit_enabled,
|
"audit_enabled": self.alignment.audit_enabled,
|
||||||
"audit_model": self.alignment.audit_model,
|
"audit_model": self.alignment.audit_model,
|
||||||
}
|
}
|
||||||
|
# v6: ReWOO fallback 策略
|
||||||
|
d["fallback_strategies"] = self.fallback_strategies
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -330,7 +330,7 @@ class TestDocumentLoaderXlsx:
|
||||||
|
|
||||||
assert doc.metadata["truncated"] is True
|
assert doc.metadata["truncated"] is True
|
||||||
assert doc.metadata["row_count"] == 5
|
assert doc.metadata["row_count"] == 5
|
||||||
assert f"truncated at 5 rows" in doc.content
|
assert "truncated at 5 rows" in doc.content
|
||||||
finally:
|
finally:
|
||||||
dl_module.MAX_ROWS_PER_SHEET = original_max
|
dl_module.MAX_ROWS_PER_SHEET = original_max
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,30 +30,50 @@ class TestSkillConfigExecutionModes:
|
||||||
"""SkillConfig.VALID_EXECUTION_MODES 扩展测试"""
|
"""SkillConfig.VALID_EXECUTION_MODES 扩展测试"""
|
||||||
|
|
||||||
def test_rewoo_is_valid_mode(self):
|
def test_rewoo_is_valid_mode(self):
|
||||||
config = SkillConfig(name="test_rewoo", agent_type="test", execution_mode="rewoo",
|
config = SkillConfig(
|
||||||
prompt={"identity": "test", "instructions": "test"})
|
name="test_rewoo",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="rewoo",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
assert config.execution_mode == "rewoo"
|
assert config.execution_mode == "rewoo"
|
||||||
|
|
||||||
def test_plan_exec_is_valid_mode(self):
|
def test_plan_exec_is_valid_mode(self):
|
||||||
config = SkillConfig(name="test_plan_exec", agent_type="test", execution_mode="plan_exec",
|
config = SkillConfig(
|
||||||
prompt={"identity": "test", "instructions": "test"})
|
name="test_plan_exec",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="plan_exec",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
assert config.execution_mode == "plan_exec"
|
assert config.execution_mode == "plan_exec"
|
||||||
|
|
||||||
def test_reflexion_is_valid_mode(self):
|
def test_reflexion_is_valid_mode(self):
|
||||||
config = SkillConfig(name="test_reflexion", agent_type="test", execution_mode="reflexion",
|
config = SkillConfig(
|
||||||
prompt={"identity": "test", "instructions": "test"})
|
name="test_reflexion",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="reflexion",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
assert config.execution_mode == "reflexion"
|
assert config.execution_mode == "reflexion"
|
||||||
|
|
||||||
def test_existing_modes_still_valid(self):
|
def test_existing_modes_still_valid(self):
|
||||||
for mode in ("react", "direct", "custom"):
|
for mode in ("react", "direct", "custom"):
|
||||||
config = SkillConfig(name=f"test_{mode}", agent_type="test", execution_mode=mode,
|
config = SkillConfig(
|
||||||
prompt={"identity": "test", "instructions": "test"})
|
name=f"test_{mode}",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode=mode,
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
assert config.execution_mode == mode
|
assert config.execution_mode == mode
|
||||||
|
|
||||||
def test_invalid_mode_raises_error(self):
|
def test_invalid_mode_raises_error(self):
|
||||||
with pytest.raises(ConfigValidationError):
|
with pytest.raises(ConfigValidationError):
|
||||||
SkillConfig(name="test_invalid", agent_type="test", execution_mode="nonexistent",
|
SkillConfig(
|
||||||
prompt={"identity": "test", "instructions": "test"})
|
name="test_invalid",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="nonexistent",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_all_six_modes_in_valid_set(self):
|
def test_all_six_modes_in_valid_set(self):
|
||||||
expected = {"react", "direct", "custom", "rewoo", "plan_exec", "reflexion"}
|
expected = {"react", "direct", "custom", "rewoo", "plan_exec", "reflexion"}
|
||||||
|
|
@ -75,6 +95,7 @@ class TestYAMLConfigLoading:
|
||||||
config = SkillConfig(**data)
|
config = SkillConfig(**data)
|
||||||
assert config.execution_mode == "rewoo"
|
assert config.execution_mode == "rewoo"
|
||||||
assert config.agent_type == "parallel_data_fetch"
|
assert config.agent_type == "parallel_data_fetch"
|
||||||
|
assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||||
|
|
||||||
def test_plan_exec_agent_yaml_loads(self):
|
def test_plan_exec_agent_yaml_loads(self):
|
||||||
data = self._load_yaml("plan_exec_agent.yaml")
|
data = self._load_yaml("plan_exec_agent.yaml")
|
||||||
|
|
@ -100,15 +121,16 @@ class TestYAMLConfigLoading:
|
||||||
assert config.execution_mode == "direct"
|
assert config.execution_mode == "direct"
|
||||||
assert config.agent_type == "simple_generation"
|
assert config.agent_type == "simple_generation"
|
||||||
|
|
||||||
def test_different_models_per_agent(self):
|
def test_all_agents_use_default_model(self):
|
||||||
|
"""All agent YAMLs use model: 'default' (LLM gateway resolves the actual provider)."""
|
||||||
direct_data = self._load_yaml("direct_agent.yaml")
|
direct_data = self._load_yaml("direct_agent.yaml")
|
||||||
assert direct_data["llm"]["model"] == "openai/gpt-4o-mini"
|
assert direct_data["llm"]["model"] == "default"
|
||||||
|
|
||||||
plan_data = self._load_yaml("plan_exec_agent.yaml")
|
plan_data = self._load_yaml("plan_exec_agent.yaml")
|
||||||
assert plan_data["llm"]["model"] == "anthropic/claude-opus-4-20250514"
|
assert plan_data["llm"]["model"] == "default"
|
||||||
|
|
||||||
react_data = self._load_yaml("react_agent.yaml")
|
react_data = self._load_yaml("react_agent.yaml")
|
||||||
assert react_data["llm"]["model"] == "anthropic/claude-sonnet-4-20250514"
|
assert react_data["llm"]["model"] == "default"
|
||||||
|
|
||||||
def test_direct_agent_has_no_tools(self):
|
def test_direct_agent_has_no_tools(self):
|
||||||
data = self._load_yaml("direct_agent.yaml")
|
data = self._load_yaml("direct_agent.yaml")
|
||||||
|
|
@ -117,7 +139,7 @@ class TestYAMLConfigLoading:
|
||||||
def test_capabilities_parsed(self):
|
def test_capabilities_parsed(self):
|
||||||
data = self._load_yaml("react_agent.yaml")
|
data = self._load_yaml("react_agent.yaml")
|
||||||
config = SkillConfig(**data)
|
config = SkillConfig(**data)
|
||||||
cap_tags = [c.tag if hasattr(c, 'tag') else c for c in config.capabilities]
|
cap_tags = [c.tag if hasattr(c, "tag") else c for c in config.capabilities]
|
||||||
assert "dynamic_adaptation" in cap_tags
|
assert "dynamic_adaptation" in cap_tags
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -144,7 +166,9 @@ class TestConfigDrivenAgentRouting:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rewoo_routes_to_handle_rewoo(self):
|
async def test_rewoo_routes_to_handle_rewoo(self):
|
||||||
agent = self._make_agent("rewoo")
|
agent = self._make_agent("rewoo")
|
||||||
with patch.object(agent, '_handle_rewoo', new_callable=AsyncMock, return_value={"content": "rewoo result"}) as mock:
|
with patch.object(
|
||||||
|
agent, "_handle_rewoo", new_callable=AsyncMock, return_value={"content": "rewoo result"}
|
||||||
|
) as mock:
|
||||||
result = await agent.handle_task(_make_task())
|
result = await agent.handle_task(_make_task())
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
assert result == {"content": "rewoo result"}
|
assert result == {"content": "rewoo result"}
|
||||||
|
|
@ -152,7 +176,12 @@ class TestConfigDrivenAgentRouting:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plan_exec_routes_to_handle_plan_exec(self):
|
async def test_plan_exec_routes_to_handle_plan_exec(self):
|
||||||
agent = self._make_agent("plan_exec")
|
agent = self._make_agent("plan_exec")
|
||||||
with patch.object(agent, '_handle_plan_exec', new_callable=AsyncMock, return_value={"content": "plan_exec result"}) as mock:
|
with patch.object(
|
||||||
|
agent,
|
||||||
|
"_handle_plan_exec",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"content": "plan_exec result"},
|
||||||
|
) as mock:
|
||||||
result = await agent.handle_task(_make_task())
|
result = await agent.handle_task(_make_task())
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
assert result == {"content": "plan_exec result"}
|
assert result == {"content": "plan_exec result"}
|
||||||
|
|
@ -160,7 +189,12 @@ class TestConfigDrivenAgentRouting:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reflexion_routes_to_handle_reflexion(self):
|
async def test_reflexion_routes_to_handle_reflexion(self):
|
||||||
agent = self._make_agent("reflexion")
|
agent = self._make_agent("reflexion")
|
||||||
with patch.object(agent, '_handle_reflexion', new_callable=AsyncMock, return_value={"content": "reflexion result"}) as mock:
|
with patch.object(
|
||||||
|
agent,
|
||||||
|
"_handle_reflexion",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"content": "reflexion result"},
|
||||||
|
) as mock:
|
||||||
result = await agent.handle_task(_make_task())
|
result = await agent.handle_task(_make_task())
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
assert result == {"content": "reflexion result"}
|
assert result == {"content": "reflexion result"}
|
||||||
|
|
@ -168,7 +202,81 @@ class TestConfigDrivenAgentRouting:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_react_still_routes_correctly(self):
|
async def test_react_still_routes_correctly(self):
|
||||||
agent = self._make_agent("react")
|
agent = self._make_agent("react")
|
||||||
with patch.object(agent, '_handle_react', new_callable=AsyncMock, return_value={"content": "react result"}) as mock:
|
with patch.object(
|
||||||
|
agent, "_handle_react", new_callable=AsyncMock, return_value={"content": "react result"}
|
||||||
|
) as mock:
|
||||||
result = await agent.handle_task(_make_task())
|
result = await agent.handle_task(_make_task())
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
assert result == {"content": "react result"}
|
assert result == {"content": "react result"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackStrategiesWiring:
|
||||||
|
"""Verify fallback_strategies flows from SkillConfig -> ReWOOEngine (#5)"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_fake_engine(captured_kwargs: dict):
|
||||||
|
"""Build a FakeReWOOEngine that records kwargs and returns a result with .output."""
|
||||||
|
|
||||||
|
class FakeReWOOEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
async def execute(self, **kwargs):
|
||||||
|
class _Result:
|
||||||
|
output = "rewoo result"
|
||||||
|
|
||||||
|
return _Result()
|
||||||
|
|
||||||
|
return FakeReWOOEngine
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_strategies_passed_to_rewoo_engine(self):
|
||||||
|
"""SkillConfig.fallback_strategies must reach ReWOOEngine constructor."""
|
||||||
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test_rewoo_wiring",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="rewoo",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
fallback_strategies=["simplified_rewoo", "direct"],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_gateway = MagicMock(spec=LLMGateway)
|
||||||
|
llm_gateway.chat = AsyncMock()
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_gateway=llm_gateway)
|
||||||
|
|
||||||
|
captured_kwargs: dict = {}
|
||||||
|
FakeReWOOEngine = self._make_fake_engine(captured_kwargs)
|
||||||
|
|
||||||
|
# ReWOOEngine is imported lazily inside _handle_rewoo, so patch at source.
|
||||||
|
with patch("agentkit.core.rewoo.ReWOOEngine", FakeReWOOEngine):
|
||||||
|
await agent._handle_rewoo(_make_task())
|
||||||
|
|
||||||
|
assert captured_kwargs.get("fallback_strategies") == ["simplified_rewoo", "direct"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_strategies_none_when_not_configured(self):
|
||||||
|
"""When fallback_strategies is None, ReWOOEngine receives None (uses defaults)."""
|
||||||
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test_rewoo_no_fallback",
|
||||||
|
agent_type="test",
|
||||||
|
execution_mode="rewoo",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_gateway = MagicMock(spec=LLMGateway)
|
||||||
|
llm_gateway.chat = AsyncMock()
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_gateway=llm_gateway)
|
||||||
|
|
||||||
|
captured_kwargs: dict = {}
|
||||||
|
FakeReWOOEngine = self._make_fake_engine(captured_kwargs)
|
||||||
|
|
||||||
|
with patch("agentkit.core.rewoo.ReWOOEngine", FakeReWOOEngine):
|
||||||
|
await agent._handle_rewoo(_make_task())
|
||||||
|
|
||||||
|
assert captured_kwargs.get("fallback_strategies") is None
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,126 @@ class TestSkillConfig:
|
||||||
assert result["execution_mode"] == "react"
|
assert result["execution_mode"] == "react"
|
||||||
assert result["max_steps"] == 5
|
assert result["max_steps"] == 5
|
||||||
|
|
||||||
|
def test_fallback_strategies_default_is_none(self):
|
||||||
|
"""SkillConfig.fallback_strategies defaults to None (#6)."""
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
)
|
||||||
|
assert config.fallback_strategies is None
|
||||||
|
|
||||||
|
def test_fallback_strategies_from_dict(self):
|
||||||
|
"""from_dict populates fallback_strategies (#6)."""
|
||||||
|
data = {
|
||||||
|
"name": "test",
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"prompt": {"identity": "test"},
|
||||||
|
"fallback_strategies": ["simplified_rewoo", "react", "direct"],
|
||||||
|
}
|
||||||
|
config = SkillConfig.from_dict(data)
|
||||||
|
assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||||
|
|
||||||
|
def test_fallback_strategies_to_dict(self):
|
||||||
|
"""to_dict includes fallback_strategies when set (#6, #14)."""
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
fallback_strategies=["simplified_rewoo", "direct"],
|
||||||
|
)
|
||||||
|
result = config.to_dict()
|
||||||
|
assert result["fallback_strategies"] == ["simplified_rewoo", "direct"]
|
||||||
|
|
||||||
|
def test_fallback_strategies_to_dict_none_when_not_set(self):
|
||||||
|
"""to_dict includes fallback_strategies as None when not configured (#14)."""
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
)
|
||||||
|
result = config.to_dict()
|
||||||
|
assert "fallback_strategies" in result
|
||||||
|
assert result["fallback_strategies"] is None
|
||||||
|
|
||||||
|
def test_fallback_strategies_yaml_round_trip(self):
|
||||||
|
"""YAML round-trip preserves fallback_strategies (#6)."""
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"name": "round_trip",
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"prompt": {"identity": "test"},
|
||||||
|
"fallback_strategies": ["simplified_rewoo", "react", "direct"],
|
||||||
|
}
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
config = SkillConfig.from_yaml(path)
|
||||||
|
assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||||
|
# Round-trip through to_dict -> from_dict
|
||||||
|
round_tripped = SkillConfig.from_dict(config.to_dict())
|
||||||
|
assert round_tripped.fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_fallback_strategies_invalid_value_raises_error(self):
|
||||||
|
"""Invalid fallback_strategies value raises ConfigValidationError (#20)."""
|
||||||
|
# _validate_v2 is called in __init__, so construction itself raises
|
||||||
|
with pytest.raises(ConfigValidationError, match="fallback_strategies"):
|
||||||
|
SkillConfig(
|
||||||
|
name="test_invalid_fallback",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
fallback_strategies=["simplified_rewoo", "invalid_strategy"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_fallback_strategies_invalid_value_via_from_dict(self):
|
||||||
|
"""Invalid fallback_strategies via from_dict also raises (#20)."""
|
||||||
|
with pytest.raises(ConfigValidationError, match="fallback_strategies"):
|
||||||
|
SkillConfig.from_dict({
|
||||||
|
"name": "test_invalid_fallback",
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"prompt": {"identity": "test"},
|
||||||
|
"fallback_strategies": ["bogus"],
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_fallback_strategies_valid_subset_accepted(self):
|
||||||
|
"""Valid subset of ReWOOEngine.VALID_STRATEGIES is accepted (#20)."""
|
||||||
|
# All valid strategies
|
||||||
|
for strategies in [
|
||||||
|
["simplified_rewoo"],
|
||||||
|
["react", "direct"],
|
||||||
|
["simplified_rewoo", "react", "direct", "plan_exec"],
|
||||||
|
]:
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
fallback_strategies=strategies,
|
||||||
|
)
|
||||||
|
assert config.fallback_strategies == strategies
|
||||||
|
|
||||||
|
def test_fallback_strategies_none_bypasses_validation(self):
|
||||||
|
"""None fallback_strategies skips validation (uses ReWOOEngine defaults) (#20)."""
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test"},
|
||||||
|
fallback_strategies=None,
|
||||||
|
)
|
||||||
|
assert config.fallback_strategies is None
|
||||||
|
|
||||||
def test_invalid_execution_mode_raises_config_validation_error(self):
|
def test_invalid_execution_mode_raises_config_validation_error(self):
|
||||||
data = {
|
data = {
|
||||||
"name": "bad_mode",
|
"name": "bad_mode",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue