fischer-agentkit/tests/unit/test_output_standardizer.py

247 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OutputStandardizer 单元测试"""
from datetime import datetime, timezone
import pytest
from agentkit.quality.gate import QualityCheck, QualityResult
from agentkit.quality.output import OutputMetadata, OutputStandardizer, StandardOutput
from agentkit.skills.base import Skill, SkillConfig
# ── 辅助函数 ───────────────────────────────────────────────
def _make_skill(
    name: str = "test_skill",
    output_schema: dict | None = None,
) -> Skill:
    """Build a minimal ``Skill`` instance for tests.

    Args:
        name: Value stored under the config's ``"name"`` key.
        output_schema: Passed through unchanged as the config's
            ``"output_schema"``; ``None`` means the skill has no schema.

    Returns:
        A ``Skill`` wrapping the ``SkillConfig`` produced by
        ``SkillConfig.from_dict``.
    """
    raw_config = {
        "name": name,
        "agent_type": "test",
        "task_mode": "llm_generate",
        "prompt": {"identity": "测试技能"},
        "output_schema": output_schema,
    }
    skill_config = SkillConfig.from_dict(raw_config)
    return Skill(skill_config)
def _make_quality_result(passed: bool, check_count: int = 1) -> QualityResult:
    """Build a ``QualityResult`` whose checks all share one outcome.

    Args:
        passed: Outcome applied to every generated check and to the overall
            result.
        check_count: How many checks to create; they are named
            ``check_0``, ``check_1``, ...

    Returns:
        A ``QualityResult`` with ``can_retry=False``.
    """
    uniform_checks = [
        QualityCheck(name=f"check_{idx}", passed=passed)
        for idx in range(check_count)
    ]
    return QualityResult(passed=passed, checks=uniform_checks, can_retry=False)
def _make_mixed_quality_result(passed_count: int, failed_count: int) -> QualityResult:
    """Build a ``QualityResult`` mixing passing and failing checks.

    Passing checks (``pass_0``, ``pass_1``, ...) come first, followed by
    failing checks (``fail_0``, ...) that each carry a ``"fail {i}"``
    message. The overall result is marked passed only when
    ``failed_count == 0``.

    Returns:
        A ``QualityResult`` with ``can_retry=False``.
    """
    passing_checks = [
        QualityCheck(name=f"pass_{i}", passed=True)
        for i in range(passed_count)
    ]
    failing_checks = [
        QualityCheck(name=f"fail_{i}", passed=False, message=f"fail {i}")
        for i in range(failed_count)
    ]
    all_passed = failed_count == 0
    return QualityResult(
        passed=all_passed,
        checks=[*passing_checks, *failing_checks],
        can_retry=False,
    )
# ── OutputMetadata 测试 ────────────────────────────────────
class TestOutputMetadata:
    """Tests for the ``OutputMetadata`` data class."""

    def test_fields(self):
        """Constructor arguments are exposed unchanged as attributes."""
        timestamp = datetime.now(timezone.utc)
        metadata = OutputMetadata(
            version="1.0.0",
            produced_at=timestamp,
            quality_score=0.8,
        )

        assert metadata.version == "1.0.0"
        assert metadata.produced_at == timestamp
        assert metadata.quality_score == 0.8
# ── StandardOutput 测试 ────────────────────────────────────
class TestStandardOutput:
    """Tests for the ``StandardOutput`` data class."""

    def test_fields(self):
        """Constructor arguments are exposed as attributes; metadata is kept by identity."""
        metadata = OutputMetadata(
            version="1.0.0",
            produced_at=datetime.now(timezone.utc),
            quality_score=1.0,
        )
        payload = {"key": "value"}
        standard_output = StandardOutput(
            skill_name="my_skill",
            data=payload,
            metadata=metadata,
        )

        assert standard_output.skill_name == "my_skill"
        assert standard_output.data == {"key": "value"}
        assert standard_output.metadata is metadata
# ── OutputStandardizer.standardize 测试 ─────────────────────
class TestOutputStandardizer:
    """Tests for ``OutputStandardizer.standardize``."""

    # NOTE(review): every test below is ``async def`` with no explicit asyncio
    # marker, so they presumably rely on an asyncio pytest plugin running in
    # "auto" mode for this project — confirm against the pytest configuration.

    @pytest.fixture
    def standardizer(self) -> OutputStandardizer:
        """Provide a fresh ``OutputStandardizer`` for each test."""
        return OutputStandardizer()

    @staticmethod
    def _object_schema(field: str, field_type: str) -> dict:
        """Return an object schema declaring a single property ``field`` of ``field_type``."""
        return {"type": "object", "properties": {field: {"type": field_type}}}

    async def test_standardized_output_contains_skill_name_and_metadata(
        self, standardizer: OutputStandardizer
    ):
        """The result is a ``StandardOutput`` carrying the skill name and metadata."""
        skill = _make_skill(name="content_gen")
        raw_output = {"title": "Hello", "content": "World"}

        out = await standardizer.standardize(raw_output, skill)

        assert isinstance(out, StandardOutput)
        assert out.skill_name == "content_gen"
        assert isinstance(out.metadata, OutputMetadata)

    async def test_metadata_contains_version_and_produced_at(
        self, standardizer: OutputStandardizer
    ):
        """Metadata carries the skill config version and a timezone-aware timestamp."""
        skill = _make_skill()

        out = await standardizer.standardize({"data": "test"}, skill)

        assert out.metadata.version == skill.config.version
        stamp = out.metadata.produced_at
        assert isinstance(stamp, datetime)
        assert stamp.tzinfo is not None

    async def test_produced_at_uses_utc_timezone(self, standardizer: OutputStandardizer):
        """``produced_at`` is expressed in UTC."""
        out = await standardizer.standardize({"data": "test"}, _make_skill())

        assert out.metadata.produced_at.tzinfo == timezone.utc

    async def test_field_type_normalization_string_to_integer(
        self, standardizer: OutputStandardizer
    ):
        """A numeric string in an ``integer`` field is coerced to ``int``."""
        skill = _make_skill(output_schema=self._object_schema("count", "integer"))

        out = await standardizer.standardize({"count": "42"}, skill)

        coerced = out.data["count"]
        assert coerced == 42
        assert isinstance(coerced, int)

    async def test_field_type_normalization_string_to_number(
        self, standardizer: OutputStandardizer
    ):
        """A numeric string in a ``number`` field is coerced to ``float``."""
        skill = _make_skill(output_schema=self._object_schema("score", "number"))

        out = await standardizer.standardize({"score": "3.14"}, skill)

        coerced = out.data["score"]
        assert coerced == 3.14
        assert isinstance(coerced, float)

    async def test_field_type_normalization_string_to_boolean(
        self, standardizer: OutputStandardizer
    ):
        """The string ``"true"`` in a ``boolean`` field is coerced to ``True``."""
        skill = _make_skill(output_schema=self._object_schema("active", "boolean"))

        out = await standardizer.standardize({"active": "true"}, skill)

        assert out.data["active"] is True

    async def test_empty_output_schema_no_schema_validation(
        self, standardizer: OutputStandardizer
    ):
        """Without an ``output_schema`` the data is passed through unvalidated."""
        skill = _make_skill(output_schema=None)
        raw_output = {"anything": "goes", "number": 42}

        out = await standardizer.standardize(raw_output, skill)

        assert out.data == raw_output

    async def test_quality_score_calculated_from_quality_result(
        self, standardizer: OutputStandardizer
    ):
        """``quality_score`` is the fraction of passing checks in the QualityResult."""
        skill = _make_skill()
        quality = _make_mixed_quality_result(passed_count=3, failed_count=1)

        out = await standardizer.standardize({"data": "test"}, skill, quality)

        # 3 of 4 checks pass -> 3 / 4 = 0.75
        assert out.metadata.quality_score == 0.75

    async def test_quality_score_is_one_when_no_quality_result(
        self, standardizer: OutputStandardizer
    ):
        """Without a QualityResult the score defaults to 1.0."""
        out = await standardizer.standardize({"data": "test"}, _make_skill())

        assert out.metadata.quality_score == 1.0

    async def test_quality_score_all_passed(self, standardizer: OutputStandardizer):
        """All checks passing yields a score of 1.0."""
        skill = _make_skill()
        quality = _make_quality_result(passed=True, check_count=5)

        out = await standardizer.standardize({"data": "test"}, skill, quality)

        assert out.metadata.quality_score == 1.0

    async def test_quality_score_all_failed(self, standardizer: OutputStandardizer):
        """All checks failing yields a score of 0.0."""
        skill = _make_skill()
        quality = _make_quality_result(passed=False, check_count=3)

        out = await standardizer.standardize({"data": "test"}, skill, quality)

        assert out.metadata.quality_score == 0.0

    async def test_standard_output_data_matches_raw_when_no_normalization(
        self, standardizer: OutputStandardizer
    ):
        """When nothing needs normalizing, ``StandardOutput.data`` equals the raw output."""
        skill = _make_skill()
        raw_output = {"title": "Hello", "count": 42, "active": True}

        out = await standardizer.standardize(raw_output, skill)

        assert out.data == raw_output

    async def test_type_normalization_invalid_value_kept_as_is(
        self, standardizer: OutputStandardizer
    ):
        """A value that cannot be coerced to the schema type is left unchanged."""
        skill = _make_skill(output_schema=self._object_schema("count", "integer"))

        out = await standardizer.standardize({"count": "not_a_number"}, skill)

        # Conversion is impossible, so the original string must survive.
        assert out.data["count"] == "not_a_number"

    async def test_quality_score_with_empty_checks(self, standardizer: OutputStandardizer):
        """An empty checks list yields a score of 1.0."""
        skill = _make_skill()
        quality = QualityResult(passed=True, checks=[], can_retry=False)

        out = await standardizer.standardize({"data": "test"}, skill, quality)

        assert out.metadata.quality_score == 1.0