"""TDD tests for the monitoring module.

Test strategy:
- Use real Prometheus metric collection (the process-wide default ``REGISTRY``).
- Exercise real database connection checks.
- Redis connection checks run against a real server when one is available;
  an unreachable Redis must yield an unhealthy result rather than raise.
- LLM provider checks use real calls or are skipped.
"""

import time

import pytest
import pytest_asyncio
from prometheus_client import REGISTRY

from app.middleware.prometheus_metrics import (
    API_REQUESTS_TOTAL,
    AGENT_EXECUTIONS_TOTAL,
    LLM_REQUESTS_TOTAL,
    LLM_TOKENS_TOTAL,
    LLM_COST_ESTIMATED,
    BRAND_COUNT,
    QUERY_COUNT_TOTAL,
    CONTENT_GENERATED_TOTAL,
    CITATION_DETECTED_TOTAL,
)
from app.services.health_checker import HealthChecker, HealthCheckResult


# ============================================================================
# Helpers
# ============================================================================


def _collect_metrics():
    """Return ``{family_name: metric_family}`` for the default registry.

    prometheus_client reports Counter family names without the ``_total``
    suffix (e.g. ``geo_api_requests``); the individual samples keep it.
    """
    return {family.name: family for family in REGISTRY.collect()}


def _sample_value(family_name, **labels):
    """Return the value of the series in *family_name* whose labels equal *labels*.

    Counter families also expose a ``*_created`` sample (a creation timestamp)
    carrying the same labels. It is skipped so the counter/gauge value itself
    is returned. Returns ``None`` when the family or the series does not exist.
    """
    family = _collect_metrics().get(family_name)
    if family is None:
        return None
    for sample in family.samples:
        if sample.name.endswith("_created"):
            continue
        if sample.labels == labels:
            return sample.value
    return None


def _assert_increased_by(family_name, before, amount, **labels):
    """Assert that the labelled series exists and grew by at least *amount* since *before*.

    ``>=`` rather than ``==`` because the registry is process-global, so other
    code running in the same process may also record into the series.
    """
    after = _sample_value(family_name, **labels)
    assert after is not None, f"no sample for {family_name} with labels {labels}"
    assert after >= before + amount


# ============================================================================
# Metric tests
# ============================================================================


class TestPrometheusMetrics:
    """Prometheus metric tests: every recording must be visible in the registry."""

    def test_api_requests_counter_increment(self):
        """Incrementing the API request counter shows up under its labels."""
        labels = {"method": "GET", "endpoint": "/test", "status": "200"}
        before = _sample_value("geo_api_requests", **labels) or 0.0

        API_REQUESTS_TOTAL.labels(**labels).inc()

        _assert_increased_by("geo_api_requests", before, 1, **labels)

    def test_agent_executions_counter(self):
        """Agent execution counter records per agent name and status."""
        labels = {"agent_name": "test_agent", "status": "success"}
        before = _sample_value("geo_agent_executions", **labels) or 0.0

        AGENT_EXECUTIONS_TOTAL.labels(**labels).inc()

        _assert_increased_by("geo_agent_executions", before, 1, **labels)

    def test_llm_tokens_counter(self):
        """LLM token counter tracks prompt and completion tokens separately."""
        prompt = {"provider": "openai", "model": "gpt-4", "token_type": "prompt"}
        completion = {"provider": "openai", "model": "gpt-4", "token_type": "completion"}
        prompt_before = _sample_value("geo_llm_tokens", **prompt) or 0.0
        completion_before = _sample_value("geo_llm_tokens", **completion) or 0.0

        LLM_TOKENS_TOTAL.labels(**prompt).inc(100)
        LLM_TOKENS_TOTAL.labels(**completion).inc(50)

        _assert_increased_by("geo_llm_tokens", prompt_before, 100, **prompt)
        _assert_increased_by("geo_llm_tokens", completion_before, 50, **completion)

    def test_llm_cost_gauge(self):
        """Setting the LLM cost estimate gauge stores the exact value."""
        LLM_COST_ESTIMATED.labels(provider="openai", model="gpt-4").set(0.25)

        value = _sample_value("geo_llm_cost_estimated", provider="openai", model="gpt-4")
        assert value is not None
        assert value == 0.25

    def test_brand_count_gauge(self):
        """Brand count gauge increments (unlabelled metric)."""
        before = _sample_value("geo_brands_total") or 0.0

        BRAND_COUNT.inc()
        BRAND_COUNT.inc()

        _assert_increased_by("geo_brands_total", before, 2)

    def test_query_count_total(self):
        """Query counter records successes and failures as separate series."""
        success = {"platform": "kimi", "status": "success"}
        failed = {"platform": "kimi", "status": "failed"}
        success_before = _sample_value("geo_queries", **success) or 0.0
        failed_before = _sample_value("geo_queries", **failed) or 0.0

        QUERY_COUNT_TOTAL.labels(**success).inc()
        QUERY_COUNT_TOTAL.labels(**failed).inc()

        _assert_increased_by("geo_queries", success_before, 1, **success)
        _assert_increased_by("geo_queries", failed_before, 1, **failed)

    def test_content_generated_counter(self):
        """Content generation counter increments (unlabelled metric)."""
        before = _sample_value("geo_content_generated") or 0.0

        CONTENT_GENERATED_TOTAL.inc()

        _assert_increased_by("geo_content_generated", before, 1)

    def test_citation_detected_counter(self):
        """Citation counter records one series per platform."""
        kimi_before = _sample_value("geo_citations_detected", platform="kimi") or 0.0
        wenxin_before = _sample_value("geo_citations_detected", platform="wenxin") or 0.0

        CITATION_DETECTED_TOTAL.labels(platform="kimi").inc()
        CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc()

        _assert_increased_by("geo_citations_detected", kimi_before, 1, platform="kimi")
        _assert_increased_by("geo_citations_detected", wenxin_before, 1, platform="wenxin")


# ============================================================================
# Health check tests
# ============================================================================


class TestHealthChecker:
    """HealthCheckResult structure tests (no I/O, so plain synchronous tests)."""

    def test_health_check_result_dataclass(self):
        """All HealthCheckResult fields round-trip through the constructor."""
        result = HealthCheckResult(
            name="test",
            healthy=True,
            latency_ms=10.5,
            message="OK",
            details={"key": "value"},
        )

        assert result.name == "test"
        assert result.healthy is True
        assert result.latency_ms == 10.5
        assert result.message == "OK"
        assert result.details["key"] == "value"

    def test_health_check_result_optional_fields(self):
        """Optional HealthCheckResult fields default to None."""
        result = HealthCheckResult(
            name="test",
            healthy=False,
        )

        assert result.latency_ms is None
        assert result.message is None
        assert result.details is None


# ============================================================================
# Health checker integration tests (require a real database session)
# ============================================================================


class TestHealthCheckerWithDatabase:
    """HealthChecker integration tests.

    Depend on a ``test_session`` fixture defined outside this module
    (presumably in conftest.py) that supplies a real database session.
    """

    @pytest_asyncio.fixture
    async def health_checker(self, test_session):
        """Build a HealthChecker bound to the test DB session and a local Redis URL."""
        return HealthChecker(db=test_session, redis_url="redis://localhost:6379")

    @pytest.mark.asyncio
    async def test_check_database_healthy(self, health_checker):
        """The database check reports healthy with a non-negative latency."""
        result = await health_checker.check_database()

        assert result.name == "database"
        assert result.healthy is True
        assert result.latency_ms is not None
        assert result.latency_ms >= 0
        # Guard against None so a missing message fails as an assertion, not a TypeError.
        assert result.message is not None
        assert "OK" in result.message

    @pytest.mark.asyncio
    async def test_check_redis_connection(self, health_checker):
        """The Redis check completes whether or not Redis is reachable."""
        result = await health_checker.check_redis()

        assert result.name == "redis"
        # Redis may be unavailable, but the check must still complete and record latency.
        assert result.latency_ms is not None
        # On failure, healthy is False with an explanatory message (no exception raised).
        if not result.healthy:
            assert result.message is not None
            assert "Connection failed" in result.message or "Error" in result.message

    @pytest.mark.asyncio
    async def test_check_storage(self, health_checker):
        """The storage check completes even if the storage path is missing."""
        result = await health_checker.check_storage()

        assert result.name == "storage"
        # The storage path may not exist, but the check must still complete.
        assert result.latency_ms is not None


# ============================================================================
# Metric collection tests
# ============================================================================


class TestMetricsCollection:
    """Registry-level metric collection tests."""

    def test_registry_collects_all_metrics(self):
        """Every key metric family is registered in the default registry."""
        # Counter family names appear without their _total suffix.
        expected = {
            "geo_api_requests",
            "geo_agent_executions",
            "geo_llm_tokens",
            "geo_llm_cost_estimated",
            "geo_brands_total",
            "geo_queries",
            "geo_content_generated",
            "geo_citations_detected",
        }

        missing = expected.difference(_collect_metrics())

        assert not missing, f"metric families missing from registry: {sorted(missing)}"

    def test_metric_labels_are_valid(self):
        """API_REQUESTS_TOTAL accepts the method, endpoint and status labels."""
        # labels() raises if the label names do not match the metric's definition.
        API_REQUESTS_TOTAL.labels(method="GET", endpoint="/test", status="200").inc()

    def test_agent_metric_labels(self):
        """AGENT_EXECUTIONS_TOTAL accepts the agent_name and status labels."""
        AGENT_EXECUTIONS_TOTAL.labels(agent_name="test", status="success").inc()

    def test_llm_metric_labels(self):
        """LLM_REQUESTS_TOTAL accepts the provider, model and status labels."""
        LLM_REQUESTS_TOTAL.labels(provider="openai", model="gpt-4", status="success").inc()


# ============================================================================
# Metric history tests
# ============================================================================


class TestMetricsHistory:
    """Metric values accumulate and are readable back from the registry."""

    def test_increment_counter_multiple_times(self):
        """Three increments raise the counter by at least three."""
        labels = {"method": "GET", "endpoint": "/test-increment", "status": "200"}
        before = _sample_value("geo_api_requests", **labels) or 0.0

        child = API_REQUESTS_TOTAL.labels(**labels)
        for _ in range(3):
            child.inc()

        # Fails (instead of silently passing) when the series is missing.
        _assert_increased_by("geo_api_requests", before, 3, **labels)

    def test_set_gauge_value(self):
        """A gauge set to a value reads back exactly that value."""
        test_value = 42.5
        LLM_COST_ESTIMATED.labels(provider="test", model="test").set(test_value)

        value = _sample_value("geo_llm_cost_estimated", provider="test", model="test")

        assert value is not None, "Should find the gauge with test labels"
        assert value == test_value