202 lines
7.9 KiB
Python
202 lines
7.9 KiB
Python
"""RiskGuardLearner 单元测试"""
|
||
|
||
import json
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.evolution.experience_schema import TaskExperience
|
||
from agentkit.evolution.risk_guard_learner import RiskGuardLearner
|
||
|
||
|
||
def _make_experience(
    experience_id: str = "exp1",
    task_type: str = "code_reviewer",
    goal: str = "review code",
    outcome: str = "failure",
    failure_reasons=None,
    optimization_tips=None,
) -> TaskExperience:
    """Build a ``TaskExperience`` fixture for the RiskGuardLearner tests.

    Args:
        experience_id: Identifier stored on the experience.
        task_type: Task type (used by the tests as the skill name).
        goal: Goal text of the trajectory.
        outcome: Outcome label, e.g. ``"failure"``, ``"success"``, ``"partial"``.
        failure_reasons: Failure reasons to record. ``None`` selects the
            default ``["no code provided"]``; an explicit list (including an
            empty one) is used as given.
        optimization_tips: Optimization tips to record. ``None`` selects the
            default ``["require code input"]``; an explicit list (including an
            empty one) is used as given.

    Returns:
        A ``TaskExperience`` whose ``steps_summary`` is a fixed placeholder.
    """
    # Use ``is None`` rather than ``or`` so callers can pass ``[]`` explicitly.
    if failure_reasons is None:
        failure_reasons = ["no code provided"]
    if optimization_tips is None:
        optimization_tips = ["require code input"]
    return TaskExperience(
        experience_id=experience_id,
        task_type=task_type,
        goal=goal,
        steps_summary="loaded skill; ran review",
        outcome=outcome,
        failure_reasons=failure_reasons,
        optimization_tips=optimization_tips,
    )
|
||
|
||
|
||
def _make_llm_response(content: str):
|
||
return SimpleNamespace(content=content)
|
||
|
||
|
||
class TestRiskGuardLearner:
    """Unit tests for ``RiskGuardLearner.learn`` using a mocked store and LLM."""

    @pytest.mark.asyncio
    async def test_learn_happy_path(self):
        """Three failure trajectories + valid JSON -> suggestions are returned."""
        store = AsyncMock()
        store.search.return_value = [
            _make_experience("e1", "code_reviewer", "review A"),
            _make_experience("e2", "code_reviewer", "review B"),
            _make_experience("e3", "code_reviewer", "review C"),
        ]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {
                        "skill_name": "code_reviewer",
                        "precondition": "输入必须包含待审查的代码片段",
                        "reason": "多次因输入为空导致审查失败",
                        "confidence": 0.85,
                    },
                    {
                        "skill_name": "code_reviewer",
                        "precondition": "代码片段长度 >= 10 字符",
                        "reason": "过短输入无法有效审查",
                        "confidence": 0.6,
                    },
                ]
            )
        )
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert len(suggestions) == 2
        assert suggestions[0].skill_name == "code_reviewer"
        assert suggestions[0].precondition == "输入必须包含待审查的代码片段"
        assert suggestions[0].confidence == 0.85
        assert set(suggestions[0].source_experience_ids) == {"e1", "e2", "e3"}
        # Verify the second item is parsed as well, not only the first.
        assert suggestions[1].skill_name == "code_reviewer"
        assert suggestions[1].precondition == "代码片段长度 >= 10 字符"
        assert suggestions[1].confidence == 0.6

    @pytest.mark.asyncio
    async def test_learn_skill_name_filter(self):
        """``skill_name`` is forwarded to ``search`` as ``task_type``."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1", "code_reviewer")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("[]")
        learner = RiskGuardLearner(store, llm)

        await learner.learn(skill_name="code_reviewer")

        store.search.assert_called_once_with(query="failure", top_k=20, task_type="code_reviewer")

    @pytest.mark.asyncio
    async def test_learn_llm_exception_returns_empty(self):
        """An exception from the LLM call yields an empty list instead of propagating."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.side_effect = RuntimeError("LLM down")
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert suggestions == []

    @pytest.mark.asyncio
    async def test_learn_invalid_json_returns_empty(self):
        """Invalid JSON from the LLM yields an empty list."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("not json at all")
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert suggestions == []

    @pytest.mark.asyncio
    async def test_learn_no_failures_returns_empty(self):
        """An empty ExperienceStore result yields an empty list and no LLM call."""
        store = AsyncMock()
        store.search.return_value = []
        llm = AsyncMock()
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert suggestions == []
        llm.chat.assert_not_called()

    @pytest.mark.asyncio
    async def test_learn_filters_non_failure_outcomes(self):
        """Only trajectories whose outcome == 'failure' reach the prompt."""
        store = AsyncMock()
        store.search.return_value = [
            _make_experience("e1", goal="failure-goal", outcome="failure"),
            _make_experience("e2", goal="success-goal", outcome="success"),
            _make_experience("e3", goal="partial-goal", outcome="partial"),
        ]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response("[]")
        learner = RiskGuardLearner(store, llm)

        await learner.learn()

        # Fail with a clear message (not an AttributeError on None) if the
        # LLM was never invoked.
        llm.chat.assert_called()
        # Only e1 is a failure: the prompt must contain its goal and must not
        # contain the goals of the success/partial trajectories.
        prompt = llm.chat.call_args.kwargs["messages"][1]["content"]
        assert "failure-goal" in prompt
        assert "success-goal" not in prompt
        assert "partial-goal" not in prompt

    @pytest.mark.asyncio
    async def test_confidence_clamped(self):
        """``confidence`` is clamped to the range [0.0, 1.0]."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {"skill_name": "s", "precondition": "p1", "reason": "r", "confidence": 1.5},
                    {"skill_name": "s", "precondition": "p2", "reason": "r", "confidence": -0.3},
                    {"skill_name": "s", "precondition": "p3", "reason": "r", "confidence": 0.5},
                ]
            )
        )
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert len(suggestions) == 3
        assert suggestions[0].confidence == 1.0
        assert suggestions[1].confidence == 0.0
        assert suggestions[2].confidence == 0.5

    @pytest.mark.asyncio
    async def test_learn_json_in_markdown_codeblock(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            '```json\n[{"skill_name":"s","precondition":"p","reason":"r","confidence":0.7}]\n```'
        )
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert len(suggestions) == 1
        assert suggestions[0].precondition == "p"

    @pytest.mark.asyncio
    async def test_learn_skips_items_missing_fields(self):
        """Items with an empty ``precondition`` or ``skill_name`` are skipped."""
        store = AsyncMock()
        store.search.return_value = [_make_experience("e1")]
        llm = AsyncMock()
        llm.chat.return_value = _make_llm_response(
            json.dumps(
                [
                    {"skill_name": "s", "precondition": "", "reason": "r", "confidence": 0.5},
                    {"skill_name": "", "precondition": "p", "reason": "r", "confidence": 0.5},
                    {"skill_name": "s", "precondition": "valid", "reason": "r", "confidence": 0.5},
                ]
            )
        )
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert len(suggestions) == 1
        assert suggestions[0].precondition == "valid"

    @pytest.mark.asyncio
    async def test_learn_search_exception_returns_empty(self):
        """An exception from ``ExperienceStore.search`` yields an empty list."""
        store = AsyncMock()
        store.search.side_effect = RuntimeError("DB down")
        llm = AsyncMock()
        learner = RiskGuardLearner(store, llm)

        suggestions = await learner.learn()

        assert suggestions == []
        # Without any experiences there is nothing to send to the LLM.
        llm.chat.assert_not_called()
|