142 lines
4.6 KiB
Python
142 lines
4.6 KiB
Python
"""QualityGate - 产出质量管理
|
||
|
||
多维度质量检查:必填字段、字数、JSON Schema、自定义验证器。
|
||
"""
|
||
|
||
import importlib
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from typing import Any, Callable
|
||
|
||
from agentkit.skills.base import Skill
|
||
|
||
# Module-level logger, named after this module per the standard logging idiom.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class QualityCheck:
|
||
"""单条质量检查结果"""
|
||
|
||
name: str
|
||
passed: bool
|
||
message: str | None = None
|
||
|
||
|
||
@dataclass()
class QualityResult:
    """Aggregate outcome of running every quality check over one output."""

    # Overall verdict for the output.
    passed: bool
    # The individual check outcomes that were recorded, in execution order.
    checks: list[QualityCheck]
    # Whether a retry is allowed (QualityGate derives this from the
    # configured max_retries).
    can_retry: bool
|
||
|
||
|
||
class QualityGate:
    """Output quality gate: runs several independent checks on a skill's output.

    The checks are driven by ``skill.config.quality_gate`` (required fields,
    minimum word count, optional custom validator) and by
    ``skill.config.output_schema`` (optional JSON Schema).
    """

    # Whitelist of module prefixes from which custom validators may be imported.
    _ALLOWED_VALIDATOR_PREFIXES = (
        "agentkit.",
        "app.agent_framework.",
    )

    async def validate(
        self,
        output: dict[str, Any],
        skill: Skill,
    ) -> QualityResult:
        """Run every configured quality check against ``output``.

        Check dimensions, in order:
          1. Required fields: each name in ``qg.required_fields`` must be
             present in ``output`` with a non-None value.
          2. Minimum word count (only when ``qg.min_word_count > 0``): the
             whitespace-separated tokens of ``output["content"]`` must reach
             the minimum. A missing or None content counts as zero words.
          3. JSON Schema validation, when ``skill.config.output_schema`` is
             set. It is skipped with a logged warning if ``jsonschema`` is not
             installed.
          4. Custom validator, when ``qg.custom_validator`` is set. It may be
             a sync or async callable. It is skipped with a logged warning,
             and recorded as passed, if it cannot be imported or raises.

        Args:
            output: The skill output to check.
            skill: The skill whose ``config`` supplies the gate settings.

        Returns:
            A ``QualityResult``. ``passed`` is True only if every recorded
            check passed. ``can_retry`` is True when ``qg.max_retries > 0``.

        Raises:
            jsonschema.SchemaError: If the configured output schema is itself
                invalid. This is not caught, as in the original behavior.
        """
        qg = skill.config.quality_gate
        checks: list[QualityCheck] = []

        # 1. Required fields
        checks.extend(self._check_required_fields(output, qg.required_fields))

        # 2. Minimum word count
        if qg.min_word_count > 0:
            checks.append(self._check_min_word_count(output, qg.min_word_count))

        # 3. JSON Schema validation (None means the check was skipped)
        if skill.config.output_schema:
            schema_check = self._check_schema(output, skill.config.output_schema)
            if schema_check is not None:
                checks.append(schema_check)

        # 4. Custom validator
        if qg.custom_validator:
            checks.append(await self._check_custom(output, qg.custom_validator))

        return QualityResult(
            passed=all(c.passed for c in checks),
            checks=checks,
            can_retry=qg.max_retries > 0,
        )

    @staticmethod
    def _check_required_fields(output: dict[str, Any], required_fields: Any) -> list[QualityCheck]:
        """Return one check per required field name (present and not None)."""
        checks: list[QualityCheck] = []
        for field_name in required_fields:
            present = field_name in output and output[field_name] is not None
            checks.append(QualityCheck(
                name=f"required_field:{field_name}",
                passed=present,
                message=f"Field '{field_name}' is missing" if not present else None,
            ))
        return checks

    @staticmethod
    def _check_min_word_count(output: dict[str, Any], min_word_count: int) -> QualityCheck:
        """Check that ``output["content"]`` has at least ``min_word_count`` words."""
        content = output.get("content")
        # Treat a missing or None content as empty, so str(None) is not
        # counted as the single word "None".
        if content is None:
            text = ""
        elif isinstance(content, str):
            text = content
        else:
            text = str(content)
        # NOTE(review): str.split() counts whitespace-separated tokens, so
        # unsegmented CJK text counts as very few "words". Confirm intended.
        word_count = len(text.split())
        passed = word_count >= min_word_count
        return QualityCheck(
            name="min_word_count",
            passed=passed,
            message=(
                f"Word count {word_count} < minimum {min_word_count}"
                if not passed
                else None
            ),
        )

    @staticmethod
    def _check_schema(output: dict[str, Any], schema: Any) -> QualityCheck | None:
        """Validate ``output`` against ``schema``; return None if jsonschema is absent."""
        # Import separately: if the import sat in the same try as
        # ``except jsonschema.ValidationError``, a failed import would make
        # that clause raise UnboundLocalError before ``except ImportError``
        # could be reached.
        try:
            import jsonschema
        except ImportError:
            logger.warning("jsonschema is not installed; skipping output schema validation")
            return None

        try:
            jsonschema.validate(output, schema)
        except jsonschema.ValidationError as e:
            return QualityCheck(name="schema", passed=False, message=str(e))
        return QualityCheck(name="schema", passed=True)

    async def _check_custom(self, output: dict[str, Any], dotted_path: str) -> QualityCheck:
        """Run the custom validator at ``dotted_path`` (sync or async) on ``output``."""
        try:
            validator = self._import_validator(dotted_path)
            result = validator(output)
            # Support async validators: await anything awaitable.
            if hasattr(result, "__await__"):
                result = await result
            passed = bool(result)
        except Exception as e:
            # Deliberate best-effort: an unimportable or crashing validator is
            # skipped (recorded as passed) rather than failing the output.
            # NOTE(review): this also lets a non-whitelisted validator path
            # pass silently. Consider failing closed for that case.
            logger.warning(
                "Custom validator '%s' skipped: %s", dotted_path, e, exc_info=True
            )
            return QualityCheck(
                name="custom",
                passed=True,
                message=f"Validator skipped: {e}",
            )
        return QualityCheck(
            name="custom",
            passed=passed,
            message=(
                None if passed
                else f"Custom validator '{dotted_path}' rejected the output"
            ),
        )

    def _import_validator(self, dotted_path: str) -> Callable:
        """Import a custom validator callable from a dotted ``module.func`` path.

        For safety, only modules under ``_ALLOWED_VALIDATOR_PREFIXES`` may be
        imported.

        Raises:
            ImportError: If the path is outside the whitelist, the module or
                attribute cannot be found, or the attribute is not callable.
        """
        # Security check: only allow whitelisted module prefixes.
        if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_VALIDATOR_PREFIXES):
            raise ImportError(
                f"Validator '{dotted_path}' is not in allowed module prefixes: "
                f"{self._ALLOWED_VALIDATOR_PREFIXES}"
            )
        try:
            module_path, func_name = dotted_path.rsplit(".", 1)
            module = importlib.import_module(module_path)
            handler = getattr(module, func_name)
            if not callable(handler):
                raise ValueError(f"'{dotted_path}' is not callable")
            return handler
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Failed to import validator '{dotted_path}': {e}") from e
|