276 lines
11 KiB
Python
276 lines
11 KiB
Python
"""QualityGate 单元测试"""
|
||
|
||
import asyncio
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.skills.base import QualityGateConfig, Skill, SkillConfig
|
||
from agentkit.quality.gate import QualityCheck, QualityGate, QualityResult
|
||
|
||
|
||
# ── 辅助函数 ───────────────────────────────────────────────
|
||
|
||
|
||
def _make_skill(
    required_fields: list[str] | None = None,
    min_word_count: int = 0,
    max_retries: int = 0,
    custom_validator: str | None = None,
    output_schema: dict | None = None,
) -> Skill:
    """Build a ``Skill`` instance for tests.

    All arguments except ``output_schema`` are placed in the ``quality_gate``
    section of the config dict; ``output_schema`` is a top-level config key.
    A falsy ``required_fields`` (``None`` or empty) is passed as ``[]``.

    Returns:
        A ``Skill`` wrapping the ``SkillConfig`` parsed from that dict.
    """
    gate_settings = {
        "required_fields": required_fields if required_fields else [],
        "min_word_count": min_word_count,
        "max_retries": max_retries,
        "custom_validator": custom_validator,
    }
    raw_config = {
        "name": "test_skill",
        "agent_type": "test",
        "task_mode": "llm_generate",
        "prompt": {"identity": "测试技能"},
        "quality_gate": gate_settings,
        "output_schema": output_schema,
    }
    return Skill(SkillConfig.from_dict(raw_config))
|
||
|
||
|
||
# ── QualityCheck 测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestQualityCheck:
    """Tests for the ``QualityCheck`` data class."""

    def test_passed_check(self):
        """A check created without a message exposes ``message`` as None."""
        passing = QualityCheck(name="required_field:title", passed=True)

        assert passing.name == "required_field:title"
        assert passing.passed is True
        assert passing.message is None

    def test_failed_check_with_message(self):
        """A failed check keeps the message it was constructed with."""
        failing = QualityCheck(
            name="required_field:title",
            passed=False,
            message="Field 'title' is missing",
        )

        assert failing.passed is False
        assert failing.message == "Field 'title' is missing"
|
||
|
||
|
||
# ── QualityResult 测试 ─────────────────────────────────────
|
||
|
||
|
||
class TestQualityResult:
    """Tests for the ``QualityResult`` data class."""

    def test_passed_result(self):
        """A passing result stores ``passed`` and ``can_retry`` as given."""
        only_check = QualityCheck(name="x", passed=True)
        outcome = QualityResult(passed=True, checks=[only_check], can_retry=False)

        assert outcome.passed is True
        assert outcome.can_retry is False

    def test_failed_result_with_retry(self):
        """A failing result can still be flagged as retryable."""
        only_check = QualityCheck(name="x", passed=False, message="fail")
        outcome = QualityResult(passed=False, checks=[only_check], can_retry=True)

        assert outcome.passed is False
        assert outcome.can_retry is True
|
||
|
||
|
||
# ── QualityGate.validate 测试 ──────────────────────────────
|
||
|
||
|
||
class TestQualityGateValidate:
    """Multi-dimensional quality checks performed by ``QualityGate.validate``.

    NOTE(review): the async tests carry no explicit asyncio marker, so they
    presumably rely on an async pytest plugin configured elsewhere (e.g.
    pytest-asyncio in auto mode) — confirm.
    """

    @pytest.fixture
    def gate(self) -> QualityGate:
        """Provide a fresh ``QualityGate`` for each test."""
        return QualityGate()

    async def test_all_required_fields_present(self, gate: QualityGate):
        """All required fields present → passed=True."""
        skill = _make_skill(required_fields=["title", "content"])
        output = {"title": "Hello", "content": "World"}
        result = await gate.validate(output, skill)
        assert result.passed is True

    async def test_missing_required_field(self, gate: QualityGate):
        """Missing required field → passed=False, with a message."""
        skill = _make_skill(required_fields=["title", "content"])
        output = {"title": "Hello"}  # "content" is missing
        result = await gate.validate(output, skill)
        assert result.passed is False
        field_checks = [c for c in result.checks if c.name == "required_field:content"]
        assert len(field_checks) == 1
        assert field_checks[0].passed is False
        assert "content" in field_checks[0].message

    async def test_required_field_present_but_none(self, gate: QualityGate):
        """Required field present but set to None → treated as missing."""
        skill = _make_skill(required_fields=["title"])
        output = {"title": None}
        result = await gate.validate(output, skill)
        assert result.passed is False

    async def test_min_word_count_sufficient(self, gate: QualityGate):
        """Word count meets the minimum → the word-count check passes."""
        skill = _make_skill(min_word_count=5)
        output = {"content": "one two three four five six"}
        result = await gate.validate(output, skill)
        word_check = [c for c in result.checks if c.name == "min_word_count"]
        assert len(word_check) == 1
        assert word_check[0].passed is True

    async def test_min_word_count_insufficient(self, gate: QualityGate):
        """Word count below the minimum → check fails, with a message."""
        skill = _make_skill(min_word_count=100)
        output = {"content": "short text"}
        result = await gate.validate(output, skill)
        word_check = [c for c in result.checks if c.name == "min_word_count"]
        assert len(word_check) == 1
        assert word_check[0].passed is False
        assert "100" in word_check[0].message

    async def test_min_word_count_with_non_string_content(self, gate: QualityGate):
        """Non-string content is expected to be stringified before counting."""
        skill = _make_skill(min_word_count=1)
        output = {"content": 12345}
        result = await gate.validate(output, skill)
        word_check = [c for c in result.checks if c.name == "min_word_count"]
        assert len(word_check) == 1
        assert word_check[0].passed is True  # str(12345) = "12345" → 1 word

    async def test_json_schema_validation_passes(self, gate: QualityGate):
        """Output conforming to the JSON Schema → the schema check passes."""
        schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
            },
            "required": ["title"],
        }
        skill = _make_skill(output_schema=schema)
        output = {"title": "Hello"}
        result = await gate.validate(output, skill)
        schema_checks = [c for c in result.checks if c.name == "schema"]
        assert len(schema_checks) == 1
        assert schema_checks[0].passed is True

    async def test_json_schema_validation_fails(self, gate: QualityGate):
        """Output violating the JSON Schema → the schema check fails."""
        schema = {
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
            },
            "required": ["count"],
        }
        skill = _make_skill(output_schema=schema)
        output = {"count": "not_an_integer"}
        result = await gate.validate(output, skill)
        schema_checks = [c for c in result.checks if c.name == "schema"]
        assert len(schema_checks) == 1
        assert schema_checks[0].passed is False

    async def test_max_retries_greater_than_zero(self, gate: QualityGate):
        """max_retries > 0 → can_retry=True."""
        skill = _make_skill(max_retries=3)
        result = await gate.validate({}, skill)
        assert result.can_retry is True

    async def test_max_retries_zero(self, gate: QualityGate):
        """max_retries = 0 → can_retry=False."""
        skill = _make_skill(max_retries=0)
        result = await gate.validate({}, skill)
        assert result.can_retry is False

    async def test_custom_validator_returns_true(
        self, gate: QualityGate, monkeypatch: pytest.MonkeyPatch
    ):
        """Custom async validator returning True → the custom check passes."""
        import sys

        mock_module = MagicMock()
        mock_module.check_output = AsyncMock(return_value=True)
        # monkeypatch undoes this at teardown (removing the fake module or
        # restoring a prior entry), even if an assertion below fails.
        monkeypatch.setitem(sys.modules, "agentkit.test_validators", mock_module)

        skill = _make_skill(custom_validator="agentkit.test_validators.check_output")
        result = await gate.validate({"data": "ok"}, skill)
        custom_checks = [c for c in result.checks if c.name == "custom"]
        assert len(custom_checks) == 1
        assert custom_checks[0].passed is True

    async def test_custom_validator_returns_false(
        self, gate: QualityGate, monkeypatch: pytest.MonkeyPatch
    ):
        """Custom async validator returning False → the custom check fails."""
        import sys

        mock_module = MagicMock()
        mock_module.check_quality = AsyncMock(return_value=False)
        # Registered through monkeypatch so sys.modules is restored afterwards.
        monkeypatch.setitem(sys.modules, "agentkit.test_validators2", mock_module)

        skill = _make_skill(custom_validator="agentkit.test_validators2.check_quality")
        result = await gate.validate({"data": "bad"}, skill)
        custom_checks = [c for c in result.checks if c.name == "custom"]
        assert len(custom_checks) == 1
        assert custom_checks[0].passed is False

    async def test_custom_validator_does_not_exist(self, gate: QualityGate):
        """Validator module does not exist → skipped (passed=True, with a message)."""
        # Uses an allow-listed prefix, but the module itself does not exist.
        skill = _make_skill(custom_validator="agentkit.nonexistent_module.validator")
        result = await gate.validate({"data": "ok"}, skill)
        custom_checks = [c for c in result.checks if c.name == "custom"]
        assert len(custom_checks) == 1
        assert custom_checks[0].passed is True
        assert custom_checks[0].message is not None

    async def test_empty_quality_gate_config(self, gate: QualityGate):
        """Empty quality_gate config → the overall result passes."""
        skill = _make_skill()  # default, empty configuration
        output = {"anything": "goes"}
        result = await gate.validate(output, skill)
        assert result.passed is True

    async def test_passed_is_false_when_any_check_fails(self, gate: QualityGate):
        """Any failing check → passed=False."""
        skill = _make_skill(required_fields=["title", "body"])
        output = {"title": "Hello"}  # "body" is missing
        result = await gate.validate(output, skill)
        assert result.passed is False

    async def test_no_output_schema_skips_schema_check(self, gate: QualityGate):
        """No output_schema → no schema check is recorded."""
        skill = _make_skill(output_schema=None)
        output = {"anything": "goes"}
        result = await gate.validate(output, skill)
        schema_checks = [c for c in result.checks if c.name == "schema"]
        assert len(schema_checks) == 0

    async def test_custom_validator_sync_function(
        self, gate: QualityGate, monkeypatch: pytest.MonkeyPatch
    ):
        """Custom validator that is a plain sync callable → custom check passes."""
        import sys

        mock_module = MagicMock()
        mock_module.sync_check = MagicMock(return_value=True)
        # NOTE(review): unlike the other validator tests, this module path has
        # no "agentkit." prefix. If the gate only accepts allow-listed
        # prefixes, the validator may be skipped rather than invoked, and the
        # assertions below would still pass — confirm against QualityGate.
        monkeypatch.setitem(sys.modules, "test_sync_validators", mock_module)

        skill = _make_skill(custom_validator="test_sync_validators.sync_check")
        result = await gate.validate({"data": "ok"}, skill)
        custom_checks = [c for c in result.checks if c.name == "custom"]
        assert len(custom_checks) == 1
        assert custom_checks[0].passed is True
|