fischer-agentkit/src/agentkit/evolution/ab_tester.py

202 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ABTester - A/B 测试框架
支持配置分流比例,自动收集效果指标,统计显著性检验。
"""
import hashlib
import logging
import math
from dataclasses import dataclass, fields, replace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from agentkit.evolution.evolution_store import InMemoryEvolutionStore
logger = logging.getLogger(__name__)
@dataclass
class ABTestConfig:
    """Configuration of a single A/B test.

    Attributes:
        test_id: Identifier of the test.
        agent_name: Name of the agent the test applies to.
        change_type: Kind of change under test: prompt / strategy / pipeline.
        control_ratio: Proportion of traffic intended for the control group
            (hash-based split); the default 0.5 means a 50/50 split.
        min_samples: Minimum number of samples.
        confidence_level: Confidence level used for the significance check.
        status: Lifecycle state: running / completed / rolled_back.
    """

    test_id: str
    agent_name: str
    change_type: str
    control_ratio: float = 0.5
    min_samples: int = 10
    confidence_level: float = 0.95
    status: str = "running"
@dataclass
class ABTestResult:
"""A/B 测试结果"""
test_id: str
control_metric: float
experiment_metric: float
control_samples: int
experiment_samples: int
is_significant: bool
winner: str | None # control / experiment / None
p_value: float | None = None
class ABTester:
    """A/B testing framework.

    Group assignment hashes a key with SHA-256, so the same key always lands
    in the same group, in any process. Recorded results are kept in memory and
    can optionally be persisted to an evolution store.
    """

    def __init__(
        self,
        evolution_store: "InMemoryEvolutionStore | None" = None,
        min_samples: int = 10,
    ):
        """Create a tester.

        Args:
            evolution_store: Optional store used by ``persist_results``.
            min_samples: Minimum sample count applied to configs that leave
                ``min_samples`` at the ``ABTestConfig`` default.
        """
        self._tests: dict[str, ABTestConfig] = {}
        # test_id -> [(group, metric)]
        self._results: dict[str, list[tuple[str, float]]] = {}
        self._evolution_store = evolution_store
        self._default_min_samples = min_samples

    def create_test(self, config: "ABTestConfig") -> None:
        """Register an A/B test.

        Registering an id that already exists replaces its config and clears
        its recorded results. The caller's ``config`` object is never mutated.

        Args:
            config: Test configuration. When its ``min_samples`` equals the
                ``ABTestConfig`` default, the tester-level default is used
                instead.
        """
        # Read the dataclass default instead of hard-coding it, so this check
        # cannot drift out of sync with ABTestConfig again.
        declared_default = next(
            f.default for f in fields(ABTestConfig) if f.name == "min_samples"
        )
        if (
            config.min_samples == declared_default
            and self._default_min_samples != declared_default
        ):
            config = replace(config, min_samples=self._default_min_samples)

        self._tests[config.test_id] = config
        self._results[config.test_id] = []
        logger.info(
            "A/B test '%s' created for agent '%s'",
            config.test_id,
            config.agent_name,
        )

    def assign_group(self, test_id: str, task_id: str = "") -> str:
        """Assign a task to a test group deterministically.

        Args:
            test_id: Test identifier.
            task_id: Task identifier used as the hash key. When empty, the
                test id is hashed instead.

        Returns:
            "control" or "experiment". Unknown tests always yield "control".
        """
        config = self._tests.get(test_id)
        if config is None:
            return "control"

        key = task_id or test_id
        # The built-in hash() is salted per process for str, so it cannot give
        # reproducible assignments; a cryptographic digest is stable.
        digest = hashlib.sha256(str(key).encode("utf-8")).digest()
        # Keep 53 bits so the quotient is an exactly representable float in
        # [0, 1) and can never round up to 1.0.
        bucket = (int.from_bytes(digest[:8], "big") >> 11) / 2**53
        return "control" if bucket < config.control_ratio else "experiment"

    def record_result(self, test_id: str, group: str, metric: float) -> None:
        """Append one observation for a test.

        Args:
            test_id: Test identifier; a result list is created if missing.
            group: Group label of the observation. Only "control" and
                "experiment" are read back by evaluation and persistence.
            metric: Observed metric value.
        """
        self._results.setdefault(test_id, []).append((group, metric))

    async def persist_results(self, test_id: str) -> None:
        """Persist per-group averages of a test to the evolution store.

        Does nothing when no store is configured or no results were recorded.
        Otherwise calls ``record_ab_test_result`` once for "control" and once
        for "experiment" with the group's mean (0.0 for an empty group) and
        sample count. Store failures are logged, not raised.

        Args:
            test_id: Test identifier.
        """
        if self._evolution_store is None:
            logger.debug("No evolution store configured, skipping persistence")
            return

        results = self._results.get(test_id, [])
        if not results:
            return

        control, experiment = self._split_by_group(results)
        try:
            await self._evolution_store.record_ab_test_result(
                test_id=test_id,
                variant="control",
                score=self._mean(control),
                sample_count=len(control),
            )
            await self._evolution_store.record_ab_test_result(
                test_id=test_id,
                variant="experiment",
                score=self._mean(experiment),
                sample_count=len(experiment),
            )
            logger.info("A/B test results persisted for test '%s'", test_id)
        except Exception as e:
            # Persistence is best-effort: report the failure and carry on.
            logger.error("Failed to persist A/B test results: %s", e)

    async def evaluate(self, test_id: str) -> "ABTestResult | None":
        """Evaluate the recorded results of a test.

        The group means are compared using an unpooled standard error and a
        two-sided p-value taken from the standard normal distribution (an
        approximation of a t-test).

        Args:
            test_id: Test identifier.

        Returns:
            None for an unknown test. Otherwise an ``ABTestResult``; when
            either group has too few samples it is non-significant with no
            winner and no p-value.
        """
        config = self._tests.get(test_id)
        if config is None:
            return None

        control, experiment = self._split_by_group(self._results.get(test_id, []))
        n_control, n_experiment = len(control), len(experiment)
        control_mean = self._mean(control)
        experiment_mean = self._mean(experiment)

        # The sample variance divides by n - 1, so each group needs at least
        # two observations even if min_samples is configured lower.
        required = max(config.min_samples, 2)
        if n_control < required or n_experiment < required:
            return ABTestResult(
                test_id=test_id,
                control_metric=control_mean,
                experiment_metric=experiment_mean,
                control_samples=n_control,
                experiment_samples=n_experiment,
                is_significant=False,
                winner=None,
            )

        control_var = sum((m - control_mean) ** 2 for m in control) / (n_control - 1)
        experiment_var = sum((m - experiment_mean) ** 2 for m in experiment) / (
            n_experiment - 1
        )
        std_error = math.sqrt(control_var / n_control + experiment_var / n_experiment)

        if std_error == 0:
            # Zero variance in both groups: any difference in means is treated
            # as significant, identical means as not significant.
            is_significant = abs(experiment_mean - control_mean) > 1e-10
            p_value = 0.0 if is_significant else 1.0
        else:
            statistic = (experiment_mean - control_mean) / std_error
            # Two-sided p-value under the normal approximation.
            p_value = 2 * (1 - self._normal_cdf(abs(statistic)))
            is_significant = p_value < (1 - config.confidence_level)

        winner = None
        if is_significant:
            winner = "experiment" if experiment_mean > control_mean else "control"

        return ABTestResult(
            test_id=test_id,
            control_metric=control_mean,
            experiment_metric=experiment_mean,
            control_samples=n_control,
            experiment_samples=n_experiment,
            is_significant=is_significant,
            winner=winner,
            p_value=p_value,
        )

    @staticmethod
    def _split_by_group(
        results: list[tuple[str, float]],
    ) -> tuple[list[float], list[float]]:
        """Return the (control, experiment) metric lists from recorded results."""
        control = [metric for group, metric in results if group == "control"]
        experiment = [metric for group, metric in results if group == "experiment"]
        return control, experiment

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Return the arithmetic mean of ``values``, or 0.0 when empty."""
        return sum(values) / len(values) if values else 0.0

    @staticmethod
    def _normal_cdf(x: float) -> float:
        """Return the standard normal CDF at ``x``, computed via ``math.erf``."""
        return 0.5 * (1 + math.erf(x / math.sqrt(2)))