diff --git a/.trae/rules/project_rules.md b/.trae/rules/project_rules.md index b749419..dd9b228 100644 --- a/.trae/rules/project_rules.md +++ b/.trae/rules/project_rules.md @@ -36,4 +36,3 @@ This applies to ALL async generator functions in the codebase. When adding an ea ## Testing - Run `python3 -m pytest tests/unit/ -x -q` before committing -- Known failing test (unrelated): `test_rewoo_agent_yaml_loads` — skip if needed diff --git a/src/agentkit/core/config_driven.py b/src/agentkit/core/config_driven.py index 7c6075d..dc9a546 100644 --- a/src/agentkit/core/config_driven.py +++ b/src/agentkit/core/config_driven.py @@ -245,21 +245,22 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): self._react_engine = ReActEngine( llm_gateway=self._llm_gateway, - max_steps=getattr(config, 'max_steps', 5), + max_steps=getattr(config, "max_steps", 5), ) # v2: Initialize Quality Gate (always available) from agentkit.quality.gate import QualityGate + self._quality_gate = QualityGate() # v2: Initialize Evolution if configured - evolution_config = getattr(config, 'evolution', None) + evolution_config = getattr(config, "evolution", None) if evolution_config is not None: # Support both dict and EvolutionConfig if isinstance(evolution_config, dict): is_enabled = evolution_config.get("enabled", False) else: - is_enabled = getattr(evolution_config, 'enabled', False) + is_enabled = getattr(evolution_config, "enabled", False) else: is_enabled = False @@ -276,6 +277,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): # v2: Initialize Output Standardizer from agentkit.quality.output import OutputStandardizer + self._output_standardizer = OutputStandardizer() # v2: Store compressor for ReAct engine @@ -327,6 +329,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): if config.memory.get("working", {}).get("enabled"): import redis.asyncio as aioredis + redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379") redis_client = aioredis.from_url(redis_url, decode_responses=True) working = WorkingMemory(redis=redis_client) @@ -350,7 +353,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): ) episodic = EpisodicMemory( session_factory=None, # Set externally when DB session is available - episodic_model=None, # Set externally when ORM model is available + episodic_model=None, # Set externally when ORM model is available embedder=embedder, decay_rate=epi_conf.get("decay_rate", 0.01), alpha=epi_conf.get("alpha", 0.7), @@ -471,6 +474,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): try: from agentkit.core.protocol import TaskResult, TaskStatus from datetime import datetime, timezone + result = TaskResult( task_id=task.task_id, agent_name=self.name, @@ -490,6 +494,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): try: from agentkit.core.protocol import TaskResult, TaskStatus from datetime import datetime, timezone + result = TaskResult( task_id=task.task_id, agent_name=self.name, @@ -534,12 +539,16 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): input_fields = {} if self._config.input_schema: for field_name, field_info in self._config.input_schema.items(): - input_fields[field_name] = str(field_info) if not isinstance(field_info, str) else field_info + input_fields[field_name] = ( + str(field_info) if not isinstance(field_info, str) else field_info + ) output_fields = {} if self._config.output_schema: for field_name, field_info in self._config.output_schema.items(): - output_fields[field_name] = str(field_info) if not isinstance(field_info, str) else field_info + output_fields[field_name] = ( + str(field_info) if not isinstance(field_info, str) else field_info + ) module = Module( name=self.name, @@ -731,6 +740,11 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): llm_gateway=self._llm_gateway, max_plan_steps=self._skill_config.max_steps if self._skill_config else 5, default_timeout=300.0, + fallback_strategies=( + self._skill_config.fallback_strategies + if self._skill_config and self._skill_config.fallback_strategies + else None + ), ) result = await rewoo_engine.execute( @@ -901,23 +915,17 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): self._raw_client = raw_client async def chat(self, request: LLMRequest) -> LLMResponse: - kwargs = dict(request._extra) if hasattr(request, '_extra') else {} + kwargs = dict(request._extra) if hasattr(request, "_extra") else {} kwargs["model"] = request.model kwargs["temperature"] = request.temperature kwargs["max_tokens"] = request.max_tokens if hasattr(self._raw_client, "chat"): - response = await self._raw_client.chat( - messages=request.messages, **kwargs - ) + response = await self._raw_client.chat(messages=request.messages, **kwargs) elif hasattr(self._raw_client, "create"): - response = await self._raw_client.create( - messages=request.messages, **kwargs - ) + response = await self._raw_client.create(messages=request.messages, **kwargs) elif callable(self._raw_client): - response = await self._raw_client( - messages=request.messages, **kwargs - ) + response = await self._raw_client(messages=request.messages, **kwargs) else: raise ConfigValidationError( agent_name="", @@ -1072,6 +1080,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): # 尝试提取 JSON 块 import re + json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL) if json_match: try: @@ -1095,6 +1104,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): try: module_path, func_name = dotted_path.rsplit(".", 1) import importlib + module = importlib.import_module(module_path) handler = getattr(module, func_name) if not callable(handler): diff --git a/src/agentkit/skills/base.py b/src/agentkit/skills/base.py index 9c70b6d..5832b45 100644 --- a/src/agentkit/skills/base.py +++ b/src/agentkit/skills/base.py @@ -87,6 +87,8 @@ class SkillConfig(AgentConfig): capabilities: list[str | dict[str, Any] | CapabilityTag] | None = None, # v5 新增字段:对齐守卫 alignment: dict[str, Any] | None = None, + # v6 新增字段:ReWOO fallback 策略(YAML 可配置) + fallback_strategies: list[str] | None = None, ): super().__init__( name=name, @@ -116,7 +118,10 @@ class SkillConfig(AgentConfig): self.capabilities = self._parse_capabilities(capabilities or []) # v5: 对齐守卫配置 from agentkit.quality.alignment import AlignmentConfig + self.alignment = AlignmentConfig(**(alignment or {})) + # v6: ReWOO fallback 策略(None 时 ReWOOEngine 用默认值) + self.fallback_strategies = fallback_strategies self._validate_v2() def _validate_v2(self) -> None: @@ -130,6 +135,22 @@ class SkillConfig(AgentConfig): f"must be one of {self.VALID_EXECUTION_MODES}" ), ) + # v6: validate fallback_strategies against ReWOOEngine.VALID_STRATEGIES (#20). + # Lazy import to avoid circular dependency (skills.base -> core.rewoo). + if self.fallback_strategies is not None: + from agentkit.core.rewoo import ReWOOEngine + + valid = ReWOOEngine.VALID_STRATEGIES + invalid = [s for s in self.fallback_strategies if s not in valid] + if invalid: + raise ConfigValidationError( + agent_name=self.name, + key="fallback_strategies", + reason=( + f"Invalid fallback_strategies {invalid}, " + f"must be subset of {valid}" + ), + ) @staticmethod def _parse_dependencies( @@ -191,6 +212,7 @@ class SkillConfig(AgentConfig): dependencies=data.get("dependencies"), capabilities=data.get("capabilities"), alignment=data.get("alignment"), + fallback_strategies=data.get("fallback_strategies"), ) @classmethod @@ -249,8 +271,7 @@ class SkillConfig(AgentConfig): for dep in self.dependencies ] d["capabilities"] = [ - {"tag": cap.tag, "description": cap.description} - for cap in self.capabilities + {"tag": cap.tag, "description": cap.description} for cap in self.capabilities ] # v5: 对齐守卫 d["alignment"] = { @@ -260,6 +281,8 @@ class SkillConfig(AgentConfig): "audit_enabled": self.alignment.audit_enabled, "audit_model": self.alignment.audit_model, } + # v6: ReWOO fallback 策略 + d["fallback_strategies"] = self.fallback_strategies return d diff --git a/tests/unit/memory/test_document_loader.py b/tests/unit/memory/test_document_loader.py index 73964a9..77c065b 100644 --- a/tests/unit/memory/test_document_loader.py +++ b/tests/unit/memory/test_document_loader.py @@ -330,7 +330,7 @@ class TestDocumentLoaderXlsx: assert doc.metadata["truncated"] is True assert doc.metadata["row_count"] == 5 - assert f"truncated at 5 rows" in doc.content + assert "truncated at 5 rows" in doc.content finally: dl_module.MAX_ROWS_PER_SHEET = original_max diff --git a/tests/unit/test_execution_modes.py b/tests/unit/test_execution_modes.py index 8801223..d0244a7 100644 --- a/tests/unit/test_execution_modes.py +++ b/tests/unit/test_execution_modes.py @@ -30,30 +30,50 @@ class TestSkillConfigExecutionModes: """SkillConfig.VALID_EXECUTION_MODES 扩展测试""" def test_rewoo_is_valid_mode(self): - config = SkillConfig(name="test_rewoo", agent_type="test", execution_mode="rewoo", - prompt={"identity": "test", "instructions": "test"}) + config = SkillConfig( + name="test_rewoo", + agent_type="test", + execution_mode="rewoo", + prompt={"identity": "test", "instructions": "test"}, + ) assert config.execution_mode == "rewoo" def test_plan_exec_is_valid_mode(self): - config = SkillConfig(name="test_plan_exec", agent_type="test", execution_mode="plan_exec", - prompt={"identity": "test", "instructions": "test"}) + config = SkillConfig( + name="test_plan_exec", + agent_type="test", + execution_mode="plan_exec", + prompt={"identity": "test", "instructions": "test"}, + ) assert config.execution_mode == "plan_exec" def test_reflexion_is_valid_mode(self): - config = SkillConfig(name="test_reflexion", agent_type="test", execution_mode="reflexion", - prompt={"identity": "test", "instructions": "test"}) + config = SkillConfig( + name="test_reflexion", + agent_type="test", + execution_mode="reflexion", + prompt={"identity": "test", "instructions": "test"}, + ) assert config.execution_mode == "reflexion" def test_existing_modes_still_valid(self): for mode in ("react", "direct", "custom"): - config = SkillConfig(name=f"test_{mode}", agent_type="test", execution_mode=mode, - prompt={"identity": "test", "instructions": "test"}) + config = SkillConfig( + name=f"test_{mode}", + agent_type="test", + execution_mode=mode, + prompt={"identity": "test", "instructions": "test"}, + ) assert config.execution_mode == mode def test_invalid_mode_raises_error(self): with pytest.raises(ConfigValidationError): - SkillConfig(name="test_invalid", agent_type="test", execution_mode="nonexistent", - prompt={"identity": "test", "instructions": "test"}) + SkillConfig( + name="test_invalid", + agent_type="test", + execution_mode="nonexistent", + prompt={"identity": "test", "instructions": "test"}, + ) def test_all_six_modes_in_valid_set(self): expected = {"react", "direct", "custom", "rewoo", "plan_exec", "reflexion"} @@ -75,6 +95,7 @@ class TestYAMLConfigLoading: config = SkillConfig(**data) assert config.execution_mode == "rewoo" assert config.agent_type == "parallel_data_fetch" + assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"] def test_plan_exec_agent_yaml_loads(self): data = self._load_yaml("plan_exec_agent.yaml") @@ -100,15 +121,16 @@ class TestYAMLConfigLoading: assert config.execution_mode == "direct" assert config.agent_type == "simple_generation" - def test_different_models_per_agent(self): + def test_all_agents_use_default_model(self): + """All agent YAMLs use model: 'default' (LLM gateway resolves the actual provider).""" direct_data = self._load_yaml("direct_agent.yaml") - assert direct_data["llm"]["model"] == "openai/gpt-4o-mini" + assert direct_data["llm"]["model"] == "default" plan_data = self._load_yaml("plan_exec_agent.yaml") - assert plan_data["llm"]["model"] == "anthropic/claude-opus-4-20250514" + assert plan_data["llm"]["model"] == "default" react_data = self._load_yaml("react_agent.yaml") - assert react_data["llm"]["model"] == "anthropic/claude-sonnet-4-20250514" + assert react_data["llm"]["model"] == "default" def test_direct_agent_has_no_tools(self): data = self._load_yaml("direct_agent.yaml") @@ -117,7 +139,7 @@ class TestYAMLConfigLoading: def test_capabilities_parsed(self): data = self._load_yaml("react_agent.yaml") config = SkillConfig(**data) - cap_tags = [c.tag if hasattr(c, 'tag') else c for c in config.capabilities] + cap_tags = [c.tag if hasattr(c, "tag") else c for c in config.capabilities] assert "dynamic_adaptation" in cap_tags @@ -144,7 +166,9 @@ class TestConfigDrivenAgentRouting: @pytest.mark.asyncio async def test_rewoo_routes_to_handle_rewoo(self): agent = self._make_agent("rewoo") - with patch.object(agent, '_handle_rewoo', new_callable=AsyncMock, return_value={"content": "rewoo result"}) as mock: + with patch.object( + agent, "_handle_rewoo", new_callable=AsyncMock, return_value={"content": "rewoo result"} + ) as mock: result = await agent.handle_task(_make_task()) mock.assert_called_once() assert result == {"content": "rewoo result"} @@ -152,7 +176,12 @@ class TestConfigDrivenAgentRouting: @pytest.mark.asyncio async def test_plan_exec_routes_to_handle_plan_exec(self): agent = self._make_agent("plan_exec") - with patch.object(agent, '_handle_plan_exec', new_callable=AsyncMock, return_value={"content": "plan_exec result"}) as mock: + with patch.object( + agent, + "_handle_plan_exec", + new_callable=AsyncMock, + return_value={"content": "plan_exec result"}, + ) as mock: result = await agent.handle_task(_make_task()) mock.assert_called_once() assert result == {"content": "plan_exec result"} @@ -160,7 +189,12 @@ class TestConfigDrivenAgentRouting: @pytest.mark.asyncio async def test_reflexion_routes_to_handle_reflexion(self): agent = self._make_agent("reflexion") - with patch.object(agent, '_handle_reflexion', new_callable=AsyncMock, return_value={"content": "reflexion result"}) as mock: + with patch.object( + agent, + "_handle_reflexion", + new_callable=AsyncMock, + return_value={"content": "reflexion result"}, + ) as mock: result = await agent.handle_task(_make_task()) mock.assert_called_once() assert result == {"content": "reflexion result"} @@ -168,7 +202,81 @@ class TestConfigDrivenAgentRouting: @pytest.mark.asyncio async def test_react_still_routes_correctly(self): agent = self._make_agent("react") - with patch.object(agent, '_handle_react', new_callable=AsyncMock, return_value={"content": "react result"}) as mock: + with patch.object( + agent, "_handle_react", new_callable=AsyncMock, return_value={"content": "react result"} + ) as mock: result = await agent.handle_task(_make_task()) mock.assert_called_once() assert result == {"content": "react result"} + + +class TestFallbackStrategiesWiring: + """Verify fallback_strategies flows from SkillConfig -> ReWOOEngine (#5)""" + + @staticmethod + def _make_fake_engine(captured_kwargs: dict): + """Build a FakeReWOOEngine that records kwargs and returns a result with .output.""" + + class FakeReWOOEngine: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + async def execute(self, **kwargs): + class _Result: + output = "rewoo result" + + return _Result() + + return FakeReWOOEngine + + @pytest.mark.asyncio + async def test_fallback_strategies_passed_to_rewoo_engine(self): + """SkillConfig.fallback_strategies must reach ReWOOEngine constructor.""" + from agentkit.core.config_driven import ConfigDrivenAgent + from agentkit.llm.gateway import LLMGateway + + config = SkillConfig( + name="test_rewoo_wiring", + agent_type="test", + execution_mode="rewoo", + prompt={"identity": "test", "instructions": "test"}, + fallback_strategies=["simplified_rewoo", "direct"], + ) + + llm_gateway = MagicMock(spec=LLMGateway) + llm_gateway.chat = AsyncMock() + agent = ConfigDrivenAgent(config=config, llm_gateway=llm_gateway) + + captured_kwargs: dict = {} + FakeReWOOEngine = self._make_fake_engine(captured_kwargs) + + # ReWOOEngine is imported lazily inside _handle_rewoo, so patch at source. + with patch("agentkit.core.rewoo.ReWOOEngine", FakeReWOOEngine): + await agent._handle_rewoo(_make_task()) + + assert captured_kwargs.get("fallback_strategies") == ["simplified_rewoo", "direct"] + + @pytest.mark.asyncio + async def test_fallback_strategies_none_when_not_configured(self): + """When fallback_strategies is None, ReWOOEngine receives None (uses defaults).""" + from agentkit.core.config_driven import ConfigDrivenAgent + from agentkit.llm.gateway import LLMGateway + + config = SkillConfig( + name="test_rewoo_no_fallback", + agent_type="test", + execution_mode="rewoo", + prompt={"identity": "test", "instructions": "test"}, + ) + + llm_gateway = MagicMock(spec=LLMGateway) + llm_gateway.chat = AsyncMock() + agent = ConfigDrivenAgent(config=config, llm_gateway=llm_gateway) + + captured_kwargs: dict = {} + FakeReWOOEngine = self._make_fake_engine(captured_kwargs) + + with patch("agentkit.core.rewoo.ReWOOEngine", FakeReWOOEngine): + await agent._handle_rewoo(_make_task()) + + assert captured_kwargs.get("fallback_strategies") is None diff --git a/tests/unit/test_skill_config.py b/tests/unit/test_skill_config.py index 28784be..86b97d3 100644 --- a/tests/unit/test_skill_config.py +++ b/tests/unit/test_skill_config.py @@ -227,6 +227,126 @@ class TestSkillConfig: assert result["execution_mode"] == "react" assert result["max_steps"] == 5 + def test_fallback_strategies_default_is_none(self): + """SkillConfig.fallback_strategies defaults to None (#6).""" + config = SkillConfig( + name="test", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + ) + assert config.fallback_strategies is None + + def test_fallback_strategies_from_dict(self): + """from_dict populates fallback_strategies (#6).""" + data = { + "name": "test", + "agent_type": "test", + "task_mode": "llm_generate", + "prompt": {"identity": "test"}, + "fallback_strategies": ["simplified_rewoo", "react", "direct"], + } + config = SkillConfig.from_dict(data) + assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"] + + def test_fallback_strategies_to_dict(self): + """to_dict includes fallback_strategies when set (#6, #14).""" + config = SkillConfig( + name="test", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + fallback_strategies=["simplified_rewoo", "direct"], + ) + result = config.to_dict() + assert result["fallback_strategies"] == ["simplified_rewoo", "direct"] + + def test_fallback_strategies_to_dict_none_when_not_set(self): + """to_dict includes fallback_strategies as None when not configured (#14).""" + config = SkillConfig( + name="test", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + ) + result = config.to_dict() + assert "fallback_strategies" in result + assert result["fallback_strategies"] is None + + def test_fallback_strategies_yaml_round_trip(self): + """YAML round-trip preserves fallback_strategies (#6).""" + import tempfile + + data = { + "name": "round_trip", + "agent_type": "test", + "task_mode": "llm_generate", + "prompt": {"identity": "test"}, + "fallback_strategies": ["simplified_rewoo", "react", "direct"], + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(data, f) + path = f.name + try: + config = SkillConfig.from_yaml(path) + assert config.fallback_strategies == ["simplified_rewoo", "react", "direct"] + # Round-trip through to_dict -> from_dict + round_tripped = SkillConfig.from_dict(config.to_dict()) + assert round_tripped.fallback_strategies == ["simplified_rewoo", "react", "direct"] + finally: + os.unlink(path) + + def test_fallback_strategies_invalid_value_raises_error(self): + """Invalid fallback_strategies value raises ConfigValidationError (#20).""" + # _validate_v2 is called in __init__, so construction itself raises + with pytest.raises(ConfigValidationError, match="fallback_strategies"): + SkillConfig( + name="test_invalid_fallback", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + fallback_strategies=["simplified_rewoo", "invalid_strategy"], + ) + + def test_fallback_strategies_invalid_value_via_from_dict(self): + """Invalid fallback_strategies via from_dict also raises (#20).""" + with pytest.raises(ConfigValidationError, match="fallback_strategies"): + SkillConfig.from_dict({ + "name": "test_invalid_fallback", + "agent_type": "test", + "task_mode": "llm_generate", + "prompt": {"identity": "test"}, + "fallback_strategies": ["bogus"], + }) + + def test_fallback_strategies_valid_subset_accepted(self): + """Valid subset of ReWOOEngine.VALID_STRATEGIES is accepted (#20).""" + # All valid strategies + for strategies in [ + ["simplified_rewoo"], + ["react", "direct"], + ["simplified_rewoo", "react", "direct", "plan_exec"], + ]: + config = SkillConfig( + name="test", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + fallback_strategies=strategies, + ) + assert config.fallback_strategies == strategies + + def test_fallback_strategies_none_bypasses_validation(self): + """None fallback_strategies skips validation (uses ReWOOEngine defaults) (#20).""" + config = SkillConfig( + name="test", + agent_type="test", + task_mode="llm_generate", + prompt={"identity": "test"}, + fallback_strategies=None, + ) + assert config.fallback_strategies is None + def test_invalid_execution_mode_raises_config_validation_error(self): data = { "name": "bad_mode",