298 lines
10 KiB
Python
298 lines
10 KiB
Python
"""监控模块TDD测试

测试策略:
- 使用真实Prometheus指标收集
- 测试真实数据库连接检查
- Redis连接检查使用真实或跳过
- LLM提供商检查使用真实调用或跳过
"""
|
||
import time
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from prometheus_client import REGISTRY
|
||
|
||
from app.middleware.prometheus_metrics import (
|
||
API_REQUESTS_TOTAL,
|
||
AGENT_EXECUTIONS_TOTAL,
|
||
LLM_REQUESTS_TOTAL,
|
||
LLM_TOKENS_TOTAL,
|
||
LLM_COST_ESTIMATED,
|
||
BRAND_COUNT,
|
||
QUERY_COUNT_TOTAL,
|
||
CONTENT_GENERATED_TOTAL,
|
||
CITATION_DETECTED_TOTAL,
|
||
)
|
||
from app.services.health_checker import HealthChecker, HealthCheckResult
|
||
|
||
|
||
# ============================================================================
|
||
# 指标测试
|
||
# ============================================================================
|
||
|
||
class TestPrometheusMetrics:
    """Prometheus metric tests.

    Each test records a value through one of the application's metric objects
    and reads it back from the default ``REGISTRY``. It asserts the exact change
    in the matching sample, not merely that the metric family is registered
    (which is already true as soon as the metrics module is imported).
    """

    @staticmethod
    def _sample_value(family_name, **labels):
        """Return the current value of the sample in ``family_name`` matching ``labels``.

        ``family_name`` is the name reported by ``REGISTRY.collect()``; for
        counters prometheus_client strips the ``_total`` suffix there.
        ``*_created`` timestamp samples are ignored. Returns ``0.0`` when no
        child with these labels exists yet, so a before/after delta also works
        the first time a label combination is used. Fails the test if the
        family itself is not registered.
        """
        families = {m.name: m for m in REGISTRY.collect()}
        family = families.get(family_name)
        assert family is not None, f"metric family {family_name!r} is not registered"
        return sum(
            (
                s.value
                for s in family.samples
                if not s.name.endswith("_created")
                and all(s.labels.get(k) == v for k, v in labels.items())
            ),
            0.0,
        )

    def test_api_requests_counter_increment(self):
        """Recording one API request increments its labelled counter by exactly 1."""
        labels = {"method": "GET", "endpoint": "/test", "status": "200"}
        before = self._sample_value("geo_api_requests", **labels)

        API_REQUESTS_TOTAL.labels(**labels).inc()

        assert self._sample_value("geo_api_requests", **labels) == before + 1

    def test_agent_executions_counter(self):
        """Recording one agent execution increments its counter by exactly 1."""
        labels = {"agent_name": "test_agent", "status": "success"}
        before = self._sample_value("geo_agent_executions", **labels)

        AGENT_EXECUTIONS_TOTAL.labels(**labels).inc()

        assert self._sample_value("geo_agent_executions", **labels) == before + 1

    def test_llm_tokens_counter(self):
        """Prompt and completion tokens are counted separately by ``token_type``."""
        prompt = {"provider": "openai", "model": "gpt-4", "token_type": "prompt"}
        completion = {"provider": "openai", "model": "gpt-4", "token_type": "completion"}
        prompt_before = self._sample_value("geo_llm_tokens", **prompt)
        completion_before = self._sample_value("geo_llm_tokens", **completion)

        LLM_TOKENS_TOTAL.labels(**prompt).inc(100)
        LLM_TOKENS_TOTAL.labels(**completion).inc(50)

        assert self._sample_value("geo_llm_tokens", **prompt) == prompt_before + 100
        assert self._sample_value("geo_llm_tokens", **completion) == completion_before + 50

    def test_llm_cost_gauge(self):
        """Setting the estimated-cost gauge stores exactly the given value."""
        labels = {"provider": "openai", "model": "gpt-4"}

        LLM_COST_ESTIMATED.labels(**labels).set(0.25)

        assert self._sample_value("geo_llm_cost_estimated", **labels) == 0.25

    def test_brand_count_gauge(self):
        """Two increments raise the brand-count gauge by exactly 2."""
        # Gauges keep their full name in collect(), hence the "_total" here.
        before = self._sample_value("geo_brands_total")

        BRAND_COUNT.inc()
        BRAND_COUNT.inc()

        assert self._sample_value("geo_brands_total") == before + 2

    def test_query_count_total(self):
        """Successful and failed queries are counted under separate status labels."""
        success = {"platform": "kimi", "status": "success"}
        failed = {"platform": "kimi", "status": "failed"}
        success_before = self._sample_value("geo_queries", **success)
        failed_before = self._sample_value("geo_queries", **failed)

        QUERY_COUNT_TOTAL.labels(**success).inc()
        QUERY_COUNT_TOTAL.labels(**failed).inc()

        assert self._sample_value("geo_queries", **success) == success_before + 1
        assert self._sample_value("geo_queries", **failed) == failed_before + 1

    def test_content_generated_counter(self):
        """Recording one generated content item increments the counter by exactly 1."""
        before = self._sample_value("geo_content_generated")

        CONTENT_GENERATED_TOTAL.inc()

        assert self._sample_value("geo_content_generated") == before + 1

    def test_citation_detected_counter(self):
        """Citations are counted per platform."""
        before_kimi = self._sample_value("geo_citations_detected", platform="kimi")
        before_wenxin = self._sample_value("geo_citations_detected", platform="wenxin")

        CITATION_DETECTED_TOTAL.labels(platform="kimi").inc()
        CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc()

        assert self._sample_value("geo_citations_detected", platform="kimi") == before_kimi + 1
        assert self._sample_value("geo_citations_detected", platform="wenxin") == before_wenxin + 1
|
||
|
||
|
||
# ============================================================================
|
||
# 健康检查测试
|
||
# ============================================================================
|
||
|
||
class TestHealthChecker:
    """Health-check service tests for the ``HealthCheckResult`` data structure.

    These tests only construct results and inspect their fields. They await
    nothing, so they are plain synchronous tests and need no event loop or
    ``asyncio`` marker.
    """

    def test_health_check_result_dataclass(self):
        """All fields passed to HealthCheckResult are stored unchanged."""
        result = HealthCheckResult(
            name="test",
            healthy=True,
            latency_ms=10.5,
            message="OK",
            details={"key": "value"},
        )

        assert result.name == "test"
        assert result.healthy is True
        assert result.latency_ms == 10.5
        assert result.message == "OK"
        assert result.details["key"] == "value"

    def test_health_check_result_optional_fields(self):
        """latency_ms, message and details default to None when omitted."""
        result = HealthCheckResult(
            name="test",
            healthy=False,
        )

        assert result.latency_ms is None
        assert result.message is None
        assert result.details is None
|
||
|
||
|
||
# ============================================================================
|
||
# 健康检查器测试(需要真实数据库会话)
|
||
# ============================================================================
|
||
|
||
class TestHealthCheckerWithDatabase:
    """Health-checker integration tests.

    These use the ``test_session`` fixture as the database session. That fixture
    is presumably provided by conftest (not visible here). Redis at
    ``redis://localhost:6379`` and the storage location may be unavailable.
    For those checks the tests only require that the check completes and
    returns a result, rather than raising.
    """

    @pytest_asyncio.fixture
    async def health_checker(self, test_session):
        """Create a HealthChecker bound to the test DB session and a local Redis URL."""
        return HealthChecker(db=test_session, redis_url="redis://localhost:6379")

    @pytest.mark.asyncio
    async def test_check_database_healthy(self, health_checker):
        """The database check reports healthy, a non-negative latency and an OK message."""
        result = await health_checker.check_database()

        assert result.name == "database"
        assert result.healthy is True, f"database reported unhealthy: {result.message!r}"
        assert result.latency_ms is not None
        assert result.latency_ms >= 0
        # message is optional on HealthCheckResult. Check for None first so a
        # missing message fails as an assertion instead of a TypeError from `in`.
        assert result.message is not None
        assert "OK" in result.message

    @pytest.mark.asyncio
    async def test_check_redis_connection(self, health_checker):
        """The Redis check always completes; a failure must carry an explanatory message."""
        result = await health_checker.check_redis()

        assert result.name == "redis"
        # Redis may be unavailable, but the check itself should still finish.
        assert result.latency_ms is not None
        # If the connection fails, healthy is False but no exception is raised.
        if not result.healthy:
            assert result.message is not None, "unhealthy Redis result has no message"
            assert "Connection failed" in result.message or "Error" in result.message

    @pytest.mark.asyncio
    async def test_check_storage(self, health_checker):
        """The storage check completes and reports a latency even if the path is missing."""
        result = await health_checker.check_storage()

        assert result.name == "storage"
        # The storage path may not exist, but the check should still finish.
        assert result.latency_ms is not None
|
||
|
||
|
||
# ============================================================================
|
||
# 指标收集测试
|
||
# ============================================================================
|
||
|
||
class TestMetricsCollection:
    """Metric-collection tests: registration and label acceptance."""

    def test_registry_collects_all_metrics(self):
        """Every key application metric family is exposed by the default registry."""
        # prometheus_client reports Counter families without their "_total"
        # suffix; gauges keep their full declared name.
        expected_families = (
            "geo_api_requests",
            "geo_agent_executions",
            "geo_llm_tokens",
            "geo_llm_cost_estimated",
            "geo_brands_total",
            "geo_queries",
            "geo_content_generated",
            "geo_citations_detected",
        )
        registered = {family.name for family in REGISTRY.collect()}

        for family_name in expected_families:
            assert family_name in registered

    def test_metric_labels_are_valid(self):
        """API_REQUESTS_TOTAL accepts the method/endpoint/status label set."""
        # Passes as long as .labels() accepts these names; prometheus_client
        # raises on a label-name mismatch.
        API_REQUESTS_TOTAL.labels(method="GET", endpoint="/test", status="200").inc()

    def test_agent_metric_labels(self):
        """AGENT_EXECUTIONS_TOTAL accepts the agent_name/status label set."""
        AGENT_EXECUTIONS_TOTAL.labels(agent_name="test", status="success").inc()

    def test_llm_metric_labels(self):
        """LLM_REQUESTS_TOTAL accepts the provider/model/status label set."""
        LLM_REQUESTS_TOTAL.labels(provider="openai", model="gpt-4", status="success").inc()
|
||
|
||
|
||
# ============================================================================
|
||
# 指标历史记录测试
|
||
# ============================================================================
|
||
|
||
class TestMetricsHistory:
    """Tests that counter values accumulate and gauge values are overwritten."""

    @staticmethod
    def _sample_value(family_name, **labels):
        """Return the current value of the sample in ``family_name`` matching ``labels``.

        ``family_name`` is the name reported by ``REGISTRY.collect()``; for
        counters prometheus_client strips the ``_total`` suffix there.
        ``*_created`` timestamp samples are ignored. Returns ``0.0`` when no
        child with these labels exists yet. Fails the test, rather than raising
        AttributeError, if the family itself is not registered.
        """
        families = {m.name: m for m in REGISTRY.collect()}
        family = families.get(family_name)
        assert family is not None, f"metric family {family_name!r} is not registered"
        return sum(
            (
                s.value
                for s in family.samples
                if not s.name.endswith("_created")
                and all(s.labels.get(k) == v for k, v in labels.items())
            ),
            0.0,
        )

    def test_increment_counter_multiple_times(self):
        """Three increments raise the labelled counter by exactly 3."""
        test_endpoint = "/test-increment"
        labels = {"method": "GET", "endpoint": test_endpoint, "status": "200"}
        # 0.0 if this label combination has never been used.
        before = self._sample_value("geo_api_requests", **labels)

        child = API_REQUESTS_TOTAL.labels(**labels)
        for _ in range(3):
            child.inc()

        # Always asserts. The previous loop-based check passed silently when the
        # sample could not be found after incrementing.
        assert self._sample_value("geo_api_requests", **labels) == before + 3

    def test_set_gauge_value(self):
        """Setting a gauge stores exactly the given value under its labels."""
        test_value = 42.5
        labels = {"provider": "test", "model": "test"}

        LLM_COST_ESTIMATED.labels(**labels).set(test_value)

        # A missing sample reads as 0.0, which also fails this check.
        assert self._sample_value("geo_llm_cost_estimated", **labels) == test_value, (
            "Should find the gauge with test labels"
        )
|