202 lines
7.3 KiB
Python
202 lines
7.3 KiB
Python
"""ABTester - A/B testing framework.

Supports configurable traffic-split ratios, automatic collection of outcome
metrics, and statistical significance testing.
"""
|
||
|
||
import hashlib
import logging
import math
from dataclasses import dataclass, fields, replace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from agentkit.evolution.evolution_store import InMemoryEvolutionStore
|
||
|
||
# Logger scoped to this module's import name.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
@dataclass
class ABTestConfig:
    """Configuration of a single A/B test.

    Attributes:
        test_id: Unique identifier of the test.
        agent_name: Name of the agent the test is created for.
        change_type: Kind of change under test: "prompt", "strategy" or
            "pipeline".
        control_ratio: Intended share of assignments that go to the control
            group (hash-based split; the default 0.5 means 50/50).
        min_samples: Minimum sample size required before a result is judged.
        confidence_level: Confidence level of the significance test.
        status: Lifecycle state: "running", "completed" or "rolled_back".
    """

    # Identity of the test.
    test_id: str
    agent_name: str
    change_type: str

    # Experiment parameters.
    control_ratio: float = 0.5
    min_samples: int = 10
    confidence_level: float = 0.95

    # Lifecycle.
    status: str = "running"
|
||
|
||
|
||
@dataclass
|
||
class ABTestResult:
|
||
"""A/B 测试结果"""
|
||
test_id: str
|
||
control_metric: float
|
||
experiment_metric: float
|
||
control_samples: int
|
||
experiment_samples: int
|
||
is_significant: bool
|
||
winner: str | None # control / experiment / None
|
||
p_value: float | None = None
|
||
|
||
|
||
class ABTester:
    """A/B testing framework.

    Tasks are split between a "control" and an "experiment" group by a
    deterministic SHA-256 hash of (test_id, task key). A given task therefore
    lands in the same group for a given test, across processes and restarts.
    Per-group metric samples are collected in memory, compared with a
    two-sample significance test, and can optionally be persisted to an
    EvolutionStore.
    """

    def __init__(
        self,
        evolution_store: "InMemoryEvolutionStore | None" = None,
        min_samples: int = 10,
    ):
        """Create a tester.

        Args:
            evolution_store: Optional store that ``persist_results`` writes
                per-group aggregates to; persistence is skipped when ``None``.
            min_samples: Tester-wide minimum per-group sample size, applied to
                configs whose ``min_samples`` is left at the ``ABTestConfig``
                field default.
        """
        self._tests: dict[str, ABTestConfig] = {}
        # test_id -> [(group, metric), ...] in recording order.
        self._results: dict[str, list[tuple[str, float]]] = {}
        self._evolution_store = evolution_store
        self._default_min_samples = min_samples

    def create_test(self, config: ABTestConfig) -> None:
        """Register an A/B test and start it with an empty result list.

        If ``config.min_samples`` equals the ``ABTestConfig`` field default, it
        is treated as "not specified" and replaced by the tester-wide default
        given to ``__init__``. An explicit value equal to the field default is
        indistinguishable from "unset" and is treated the same way. The
        caller's ``config`` object is never mutated: when the default is
        substituted, a modified copy is stored instead.

        Args:
            config: Test configuration. Re-registering an existing ``test_id``
                replaces its config and discards previously recorded results.
        """
        # Compare against the dataclass's declared default instead of a
        # hard-coded number, so the two cannot drift apart.
        declared_default = next(
            f.default for f in fields(ABTestConfig) if f.name == "min_samples"
        )
        if (
            config.min_samples == declared_default
            and self._default_min_samples != declared_default
        ):
            config = replace(config, min_samples=self._default_min_samples)

        self._tests[config.test_id] = config
        self._results[config.test_id] = []
        logger.info(f"A/B test '{config.test_id}' created for agent '{config.agent_name}'")

    def assign_group(self, test_id: str, task_id: str = "") -> str:
        """Assign a task to a test group using a deterministic hash-based split.

        The (test_id, key) pair is hashed with SHA-256 and mapped to a value in
        [0, 1). Values below ``config.control_ratio`` go to control.

        Unlike the built-in ``hash()``, which is randomized per process for
        ``str``, this mapping is stable across processes. Salting with
        ``test_id`` keeps the assignments of different tests independent.

        Args:
            test_id: Test ID.
            task_id: Task ID used as the split key. When empty, ``test_id`` is
                used as the key, so every such call for this test returns the
                same group.

        Returns:
            "control" or "experiment". Unknown tests always get "control".
        """
        config = self._tests.get(test_id)
        if config is None:
            return "control"

        key = task_id or test_id
        if self._bucket(test_id, key) < config.control_ratio:
            return "control"
        return "experiment"

    @staticmethod
    def _bucket(test_id: str, key: str) -> float:
        """Map (test_id, key) to a stable, roughly uniform float in [0, 1)."""
        digest = hashlib.sha256(f"{test_id}:{key}".encode("utf-8")).digest()
        # Keep the top 53 bits so the quotient is exact and strictly below 1.0.
        return (int.from_bytes(digest[:8], "big") >> 11) / (1 << 53)

    def record_result(self, test_id: str, group: str, metric: float) -> None:
        """Record one metric sample for ``group`` of ``test_id``.

        Samples for tests that were never created are still stored. Groups
        other than "control"/"experiment" are stored too, but ``evaluate`` and
        ``persist_results`` ignore them, so a warning is logged.
        """
        if group not in ("control", "experiment"):
            logger.warning(
                "Recording result for unknown group %r in A/B test %r; it will be ignored",
                group,
                test_id,
            )
        self._results.setdefault(test_id, []).append((group, metric))

    async def persist_results(self, test_id: str) -> None:
        """Persist per-group aggregates of a test to the EvolutionStore.

        For "control" and then "experiment", records the mean metric (0.0 for
        an empty group) and the sample count via ``record_ab_test_result``.
        Does nothing when no store is configured or the test has no recorded
        results. Persistence is best-effort: store failures are logged with a
        traceback and are not re-raised.
        """
        if self._evolution_store is None:
            logger.debug("No evolution store configured, skipping persistence")
            return

        results = self._results.get(test_id, [])
        if not results:
            return

        control, experiment = self._split_by_group(results)

        try:
            for variant, metrics in (("control", control), ("experiment", experiment)):
                await self._evolution_store.record_ab_test_result(
                    test_id=test_id,
                    variant=variant,
                    score=self._mean(metrics),
                    sample_count=len(metrics),
                )
            logger.info(f"A/B test results persisted for test '{test_id}'")
        except Exception as e:
            # Best-effort boundary: keep running, but preserve the traceback.
            logger.exception(f"Failed to persist A/B test results: {e}")

    async def evaluate(self, test_id: str) -> ABTestResult | None:
        """Evaluate an A/B test.

        The test compares the group means with a two-sided test on
        ``(mean_exp - mean_ctrl) / SE``. SE is the unpooled (Welch-style)
        standard error. The p-value uses a standard-normal approximation to
        the t distribution, which gives somewhat too-small p-values for small
        samples.

        Returns:
            ``None`` if ``test_id`` was never created. Otherwise an
            ``ABTestResult``. While either group has fewer than
            ``max(config.min_samples, 2)`` samples, the result carries the
            current means with ``is_significant=False``, ``winner=None`` and
            ``p_value=None``.
        """
        config = self._tests.get(test_id)
        if config is None:
            return None

        control, experiment = self._split_by_group(self._results.get(test_id, []))
        control_mean = self._mean(control)
        experiment_mean = self._mean(experiment)

        # The sample variance needs at least two observations per group, so
        # never evaluate below that even if min_samples is configured lower
        # (otherwise a division by zero would occur).
        required = max(config.min_samples, 2)
        if len(control) < required or len(experiment) < required:
            return ABTestResult(
                test_id=test_id,
                control_metric=control_mean,
                experiment_metric=experiment_mean,
                control_samples=len(control),
                experiment_samples=len(experiment),
                is_significant=False,
                winner=None,
            )

        diff = experiment_mean - control_mean
        std_err = math.sqrt(
            self._sample_variance(control, control_mean) / len(control)
            + self._sample_variance(experiment, experiment_mean) / len(experiment)
        )

        if std_err == 0:
            # Both groups are constant: any real difference in means is decisive.
            if abs(diff) > 1e-10:
                is_significant = True
                p_value = 0.0
            else:
                is_significant = False
                p_value = 1.0
        else:
            t_stat = diff / std_err
            # Two-sided p-value 2 * (1 - Phi(|t|)), computed as erfc(|t| / sqrt(2))
            # to avoid cancellation when |t| is large.
            p_value = math.erfc(abs(t_stat) / math.sqrt(2))
            is_significant = p_value < (1 - config.confidence_level)

        winner = None
        if is_significant:
            winner = "experiment" if diff > 0 else "control"

        return ABTestResult(
            test_id=test_id,
            control_metric=control_mean,
            experiment_metric=experiment_mean,
            control_samples=len(control),
            experiment_samples=len(experiment),
            is_significant=is_significant,
            winner=winner,
            p_value=p_value,
        )

    @staticmethod
    def _split_by_group(
        results: list[tuple[str, float]],
    ) -> tuple[list[float], list[float]]:
        """Split (group, metric) pairs into control and experiment metric lists.

        Pairs whose group is neither "control" nor "experiment" are dropped.
        """
        control = [m for g, m in results if g == "control"]
        experiment = [m for g, m in results if g == "experiment"]
        return control, experiment

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Arithmetic mean of ``values``; 0.0 for an empty list."""
        return math.fsum(values) / len(values) if values else 0.0

    @staticmethod
    def _sample_variance(values: list[float], mean: float) -> float:
        """Unbiased (n - 1) sample variance; requires ``len(values) >= 2``."""
        return math.fsum((v - mean) ** 2 for v in values) / (len(values) - 1)

    @staticmethod
    def _normal_cdf(x: float) -> float:
        """Standard normal CDF, Phi(x), computed via ``math.erf``."""
        return 0.5 * (1 + math.erf(x / math.sqrt(2)))
|