# fischer-agentkit/tests/unit/test_risk_guard_learner.py
#
# Note: this file intentionally contains non-ASCII Unicode text (e.g. Chinese
# string literals). Some repository viewers flag such characters as
# "ambiguous" because they can be confused with similar-looking characters;
# the warning can be safely ignored.

"""RiskGuardLearner 单元测试"""
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from agentkit.evolution.experience_schema import TaskExperience
from agentkit.evolution.risk_guard_learner import RiskGuardLearner
def _make_experience(
    experience_id="exp1",
    task_type="code_reviewer",
    goal="review code",
    outcome="failure",
    failure_reasons=None,
    optimization_tips=None,
) -> TaskExperience:
    """Build a ``TaskExperience`` fixture with test-friendly defaults.

    Args:
        experience_id: Identifier stored on the experience.
        task_type: Task type the experience is attributed to.
        goal: Goal text; tests give experiences distinct goals so they can
            check which ones end up in the LLM prompt.
        outcome: Outcome label, e.g. ``"failure"``, ``"success"`` or
            ``"partial"``.
        failure_reasons: Failure reasons to record. ``None`` selects a single
            default reason; an explicit empty list is kept as-is.
        optimization_tips: Optimization tips to record. ``None`` selects a
            single default tip; an explicit empty list is kept as-is.

    Returns:
        A new ``TaskExperience`` with a fixed ``steps_summary``.
    """
    # Use ``is None`` rather than ``or`` so callers can deliberately pass an
    # empty list (``[] or default`` would silently substitute the default).
    if failure_reasons is None:
        failure_reasons = ["no code provided"]
    if optimization_tips is None:
        optimization_tips = ["require code input"]
    return TaskExperience(
        experience_id=experience_id,
        task_type=task_type,
        goal=goal,
        steps_summary="loaded skill; ran review",
        outcome=outcome,
        failure_reasons=failure_reasons,
        optimization_tips=optimization_tips,
    )
def _make_llm_response(content: str):
return SimpleNamespace(content=content)
class TestRiskGuardLearner:
    """Unit tests for ``RiskGuardLearner.learn``.

    The experience store and the LLM client are ``AsyncMock`` objects. Each
    test configures ``store.search`` and ``llm.chat`` (return value or side
    effect) to drive one code path of the learner.
    """

    @pytest.mark.asyncio
    async def test_learn_happy_path(self):
        """Three failure trajectories plus valid JSON yield suggestions."""
        store = AsyncMock()
        store.search.return_value = [
            _make_experience("e1", "code_reviewer", "review A"),
            _make_experience("e2", "code_reviewer", "review B"),
            _make_experience("e3", "code_reviewer", "review C"),
        ]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {
                        "skill_name": "code_reviewer",
                        "precondition": "输入必须包含待审查的代码片段",
                        "reason": "多次因输入为空导致审查失败",
                        "confidence": 0.85,
                    },
                    {
                        "skill_name": "code_reviewer",
                        "precondition": "代码片段长度 >= 10 字符",
                        "reason": "过短输入无法有效审查",
                        "confidence": 0.6,
                    },
                ]
            )
        )

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert len(suggestions) == 2
        assert suggestions[0].skill_name == "code_reviewer"
        assert suggestions[0].precondition == "输入必须包含待审查的代码片段"
        assert suggestions[0].confidence == 0.85
        # The suggestion cites all three failure experiences as its sources.
        assert set(suggestions[0].source_experience_ids) == {"e1", "e2", "e3"}
        # The second item is parsed as well, in the order the LLM returned it.
        assert suggestions[1].precondition == "代码片段长度 >= 10 字符"
        assert suggestions[1].confidence == 0.6

    @pytest.mark.asyncio
    async def test_learn_skill_name_filter(self):
        """``skill_name`` is forwarded to ``search`` as ``task_type``."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1", "code_reviewer")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("[]")

        learner = RiskGuardLearner(store, llm)
        await learner.learn(skill_name="code_reviewer")

        store.search.assert_called_once_with(
            query="failure", top_k=20, task_type="code_reviewer"
        )

    @pytest.mark.asyncio
    async def test_learn_llm_exception_returns_empty(self):
        """An exception from the LLM call yields ``[]`` instead of propagating."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.side_effect = RuntimeError("LLM down")

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert suggestions == []
        # Make sure the empty result comes from the LLM failure path and not
        # from an earlier short-circuit.
        llm.chat.assert_called()

    @pytest.mark.asyncio
    async def test_learn_invalid_json_returns_empty(self):
        """A non-JSON LLM reply yields ``[]``."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("not json at all")

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert suggestions == []
        # The reply was actually requested, so parsing is what produced [].
        llm.chat.assert_called()

    @pytest.mark.asyncio
    async def test_learn_no_failures_returns_empty(self):
        """An empty store result yields ``[]`` without calling the LLM."""
        store = AsyncMock()
        store.search.return_value = []
        llm = AsyncMock()

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert suggestions == []
        llm.chat.assert_not_called()

    @pytest.mark.asyncio
    async def test_learn_filters_non_failure_outcomes(self):
        """Only trajectories with ``outcome == 'failure'`` reach the prompt."""
        store = AsyncMock()
        store.search.return_value = [
            _make_experience("e1", goal="failure-goal", outcome="failure"),
            _make_experience("e2", goal="success-goal", outcome="success"),
            _make_experience("e3", goal="partial-goal", outcome="partial"),
        ]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("[]")

        learner = RiskGuardLearner(store, llm)
        await learner.learn()

        # Only e1 is a failure, so the prompt must contain "failure-goal" and
        # neither the success nor the partial goal.
        # NOTE(review): messages[1] is assumed to be the user turn that carries
        # the trajectories; confirm against RiskGuardLearner's prompt layout.
        call_args = llm.chat.call_args
        prompt = call_args.kwargs["messages"][1]["content"]
        assert "failure-goal" in prompt
        assert "success-goal" not in prompt
        assert "partial-goal" not in prompt

    @pytest.mark.asyncio
    async def test_confidence_clamped(self):
        """``confidence`` is clamped to the range [0.0, 1.0]."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {"skill_name": "s", "precondition": "p1", "reason": "r", "confidence": 1.5},
                    {"skill_name": "s", "precondition": "p2", "reason": "r", "confidence": -0.3},
                    {"skill_name": "s", "precondition": "p3", "reason": "r", "confidence": 0.5},
                ]
            )
        )

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert len(suggestions) == 3
        assert suggestions[0].confidence == 1.0  # above range -> upper bound
        assert suggestions[1].confidence == 0.0  # below range -> lower bound
        assert suggestions[2].confidence == 0.5  # in range -> unchanged

    @pytest.mark.asyncio
    async def test_learn_json_in_markdown_codeblock(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            '```json\n[{"skill_name":"s","precondition":"p","reason":"r","confidence":0.7}]\n```'
        )

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert len(suggestions) == 1
        assert suggestions[0].precondition == "p"

    @pytest.mark.asyncio
    async def test_learn_skips_items_missing_fields(self):
        """Items with an empty ``precondition`` or ``skill_name`` are skipped."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {"skill_name": "s", "precondition": "", "reason": "r", "confidence": 0.5},
                    {"skill_name": "", "precondition": "p", "reason": "r", "confidence": 0.5},
                    {"skill_name": "s", "precondition": "valid", "reason": "r", "confidence": 0.5},
                ]
            )
        )

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert len(suggestions) == 1
        assert suggestions[0].precondition == "valid"

    @pytest.mark.asyncio
    async def test_learn_search_exception_returns_empty(self):
        """An exception from ``ExperienceStore.search`` yields ``[]``."""
        store = AsyncMock()
        store.search.side_effect = RuntimeError("DB down")
        llm = AsyncMock()

        learner = RiskGuardLearner(store, llm)
        suggestions = await learner.learn()

        assert suggestions == []
        # With no experiences available there is nothing to send to the LLM.
        llm.chat.assert_not_called()