# geo/backend/tests/test_infrastructure/test_monitoring.py
#
# NOTE: this header replaces GitHub file-viewer residue (original listing:
# 298 lines, 10 KiB, Python; "Raw / Blame / History" controls). The viewer
# warned that the file contains ambiguous Unicode characters that may be
# confused with other characters — presumably from the Chinese comments and
# docstrings; this is believed intentional, but verify before editing them.

"""监控模块TDD测试
测试策略:
- 使用真实Prometheus指标收集
- 测试真实数据库连接检查
- Redis连接检查使用真实或跳过
- LLM提供商检查使用真实调用或跳过
"""
import time
import pytest
import pytest_asyncio
from prometheus_client import REGISTRY
from app.middleware.prometheus_metrics import (
API_REQUESTS_TOTAL,
AGENT_EXECUTIONS_TOTAL,
LLM_REQUESTS_TOTAL,
LLM_TOKENS_TOTAL,
LLM_COST_ESTIMATED,
BRAND_COUNT,
QUERY_COUNT_TOTAL,
CONTENT_GENERATED_TOTAL,
CITATION_DETECTED_TOTAL,
)
from app.services.health_checker import HealthChecker, HealthCheckResult
# ============================================================================
# 指标测试
# ============================================================================
class TestPrometheusMetrics:
    """Prometheus metric tests: each update must change the exported value."""

    @staticmethod
    def _sample_value(family_name, labels=None):
        """Return the current value of one series of a registered metric family.

        ``family_name`` is the name reported by ``REGISTRY.collect()``. For
        Counters prometheus_client strips the ``_total`` suffix from the family
        name but keeps it on the value sample (a ``_created`` sample with the
        same labels may also be present), so the sample name is derived from
        the family type. Returns 0.0 when the family is registered but the
        labelled child has not been created yet; fails the test when the
        family itself is not registered.
        """
        wanted = labels or {}
        for family in REGISTRY.collect():
            if family.name != family_name:
                continue
            if family.type == "counter":
                sample_name = f"{family_name}_total"
            else:
                sample_name = family_name
            for sample in family.samples:
                if sample.name == sample_name and sample.labels == wanted:
                    return sample.value
            return 0.0
        pytest.fail(f"metric family {family_name!r} is not registered")

    def test_api_requests_counter_increment(self):
        """One API request increments its labelled counter by exactly 1."""
        labels = {"method": "GET", "endpoint": "/test", "status": "200"}
        before = self._sample_value("geo_api_requests", labels)
        API_REQUESTS_TOTAL.labels(**labels).inc()
        assert self._sample_value("geo_api_requests", labels) == before + 1

    def test_agent_executions_counter(self):
        """An agent execution increments its labelled counter by exactly 1."""
        labels = {"agent_name": "test_agent", "status": "success"}
        before = self._sample_value("geo_agent_executions", labels)
        AGENT_EXECUTIONS_TOTAL.labels(**labels).inc()
        assert self._sample_value("geo_agent_executions", labels) == before + 1

    def test_llm_tokens_counter(self):
        """Prompt and completion tokens are counted in separate series."""
        prompt = {"provider": "openai", "model": "gpt-4", "token_type": "prompt"}
        completion = {"provider": "openai", "model": "gpt-4", "token_type": "completion"}
        prompt_before = self._sample_value("geo_llm_tokens", prompt)
        completion_before = self._sample_value("geo_llm_tokens", completion)
        LLM_TOKENS_TOTAL.labels(**prompt).inc(100)
        LLM_TOKENS_TOTAL.labels(**completion).inc(50)
        assert self._sample_value("geo_llm_tokens", prompt) == prompt_before + 100
        assert self._sample_value("geo_llm_tokens", completion) == completion_before + 50

    def test_llm_cost_gauge(self):
        """Setting the LLM cost gauge exports exactly the value that was set."""
        labels = {"provider": "openai", "model": "gpt-4"}
        LLM_COST_ESTIMATED.labels(**labels).set(0.25)
        assert self._sample_value("geo_llm_cost_estimated", labels) == 0.25

    def test_brand_count_gauge(self):
        """Two increments raise the brand gauge by exactly 2."""
        before = self._sample_value("geo_brands_total")
        BRAND_COUNT.inc()
        BRAND_COUNT.inc()
        try:
            assert self._sample_value("geo_brands_total") == before + 2
        finally:
            # BRAND_COUNT is a module-level gauge shared by the whole test
            # session; undo this test's increments so later readers are not skewed.
            BRAND_COUNT.dec(2)

    def test_query_count_total(self):
        """Successful and failed queries are counted in separate series."""
        success = {"platform": "kimi", "status": "success"}
        failed = {"platform": "kimi", "status": "failed"}
        success_before = self._sample_value("geo_queries", success)
        failed_before = self._sample_value("geo_queries", failed)
        QUERY_COUNT_TOTAL.labels(**success).inc()
        QUERY_COUNT_TOTAL.labels(**failed).inc()
        assert self._sample_value("geo_queries", success) == success_before + 1
        assert self._sample_value("geo_queries", failed) == failed_before + 1

    def test_content_generated_counter(self):
        """Generating content increments the unlabelled counter by exactly 1."""
        before = self._sample_value("geo_content_generated")
        CONTENT_GENERATED_TOTAL.inc()
        assert self._sample_value("geo_content_generated") == before + 1

    def test_citation_detected_counter(self):
        """Citations are counted per platform."""
        kimi = {"platform": "kimi"}
        wenxin = {"platform": "wenxin"}
        kimi_before = self._sample_value("geo_citations_detected", kimi)
        wenxin_before = self._sample_value("geo_citations_detected", wenxin)
        CITATION_DETECTED_TOTAL.labels(**kimi).inc()
        CITATION_DETECTED_TOTAL.labels(**wenxin).inc()
        assert self._sample_value("geo_citations_detected", kimi) == kimi_before + 1
        assert self._sample_value("geo_citations_detected", wenxin) == wenxin_before + 1
# ============================================================================
# 健康检查测试
# ============================================================================
class TestHealthChecker:
    """Tests for the HealthCheckResult value object returned by health checks."""

    # These tests construct a plain data object and await nothing, so they are
    # ordinary synchronous tests; no asyncio marker or event loop is needed.

    def test_health_check_result_dataclass(self):
        """All explicitly supplied fields are stored unchanged."""
        result = HealthCheckResult(
            name="test",
            healthy=True,
            latency_ms=10.5,
            message="OK",
            details={"key": "value"},
        )
        assert result.name == "test"
        assert result.healthy is True
        assert result.latency_ms == 10.5
        assert result.message == "OK"
        assert result.details["key"] == "value"

    def test_health_check_result_optional_fields(self):
        """latency_ms, message and details default to None when omitted."""
        result = HealthCheckResult(
            name="test",
            healthy=False,
        )
        assert result.latency_ms is None
        assert result.message is None
        assert result.details is None
# ============================================================================
# 健康检查器测试(需要真实数据库会话)
# ============================================================================
class TestHealthCheckerWithDatabase:
    """HealthChecker integration tests backed by the real test DB session."""

    @pytest_asyncio.fixture
    async def health_checker(self, test_session):
        """Build a HealthChecker on the test DB session and a local Redis URL."""
        return HealthChecker(db=test_session, redis_url="redis://localhost:6379")

    @pytest.mark.asyncio
    async def test_check_database_healthy(self, health_checker):
        """The database check reports healthy with a non-negative latency."""
        result = await health_checker.check_database()
        assert result.name == "database"
        assert result.healthy is True
        assert result.latency_ms is not None
        assert result.latency_ms >= 0
        # Guard first so a missing message fails as an assertion, not a TypeError.
        assert result.message is not None
        assert "OK" in result.message

    @pytest.mark.asyncio
    async def test_check_redis_connection(self, health_checker):
        """The Redis check completes even when Redis is unreachable."""
        result = await health_checker.check_redis()
        assert result.name == "redis"
        # Redis may be unavailable, but the check must still finish and time itself.
        assert result.latency_ms is not None
        # A failed connection is reported as unhealthy rather than raised.
        if not result.healthy:
            assert result.message is not None, "unhealthy Redis result has no message"
            assert "Connection failed" in result.message or "Error" in result.message

    @pytest.mark.asyncio
    async def test_check_storage(self, health_checker):
        """The storage check completes even if the storage path is missing."""
        result = await health_checker.check_storage()
        assert result.name == "storage"
        # The storage path may not exist, but the check must still complete.
        assert result.latency_ms is not None
# ============================================================================
# 指标收集测试
# ============================================================================
class TestMetricsCollection:
    """Registry-level checks: key metrics are registered and accept their labels."""

    def test_registry_collects_all_metrics(self):
        """Every key metric family is exposed by the default registry."""
        # prometheus_client reports Counter families without their "_total"
        # suffix; gauges keep whatever name they were declared with.
        required = {
            "geo_api_requests",
            "geo_agent_executions",
            "geo_llm_tokens",
            "geo_llm_cost_estimated",
            "geo_brands_total",
            "geo_queries",
            "geo_content_generated",
            "geo_citations_detected",
        }
        exported = {family.name for family in REGISTRY.collect()}
        absent = required - exported
        assert not absent, f"missing metric families: {sorted(absent)}"

    def test_metric_labels_are_valid(self):
        """API_REQUESTS_TOTAL accepts the method/endpoint/status label set."""
        # The call itself is the check: it must not raise.
        child = API_REQUESTS_TOTAL.labels(method="GET", endpoint="/test", status="200")
        child.inc()

    def test_agent_metric_labels(self):
        """AGENT_EXECUTIONS_TOTAL accepts the agent_name/status label set."""
        child = AGENT_EXECUTIONS_TOTAL.labels(agent_name="test", status="success")
        child.inc()

    def test_llm_metric_labels(self):
        """LLM_REQUESTS_TOTAL accepts the provider/model/status label set."""
        child = LLM_REQUESTS_TOTAL.labels(provider="openai", model="gpt-4", status="success")
        child.inc()
# ============================================================================
# 指标历史记录测试
# ============================================================================
class TestMetricsHistory:
    """Metric values accumulate (counters) or are overwritten (gauges) as expected."""

    def test_increment_counter_multiple_times(self):
        """Three increments of one labelled API counter raise it by exactly 3."""
        labels = {"method": "GET", "endpoint": "/test-increment", "status": "200"}
        # Counter value samples keep the "_total" suffix even though the family
        # name reported by collect() drops it. Matching the full label set and
        # exact sample name avoids picking up the "_created" timestamp sample.
        # None means the labelled child does not exist yet, i.e. a start of 0.
        initial = REGISTRY.get_sample_value("geo_api_requests_total", labels) or 0.0

        for _ in range(3):
            API_REQUESTS_TOTAL.labels(**labels).inc()

        final = REGISTRY.get_sample_value("geo_api_requests_total", labels)
        # Fail loudly instead of passing silently when the sample is missing.
        assert final is not None, "labelled API request counter sample not found"
        assert final == initial + 3

    def test_set_gauge_value(self):
        """Setting a labelled gauge exports exactly the value that was set."""
        test_value = 42.5
        labels = {"provider": "test", "model": "test"}
        LLM_COST_ESTIMATED.labels(**labels).set(test_value)

        value = REGISTRY.get_sample_value("geo_llm_cost_estimated", labels)
        assert value is not None, "Should find the gauge with test labels"
        assert value == test_value