"""Request metrics middleware: timing, slow-request warnings, response-time header, Prometheus metrics.

Merged from the former middleware/metrics.py (MetricsMiddleware) and
monitoring/middleware.py (MonitoringMiddleware).
"""

import logging
import re
import time
from typing import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.middleware.prometheus_metrics import (
    API_REQUESTS_TOTAL,
    API_REQUEST_DURATION_SECONDS,
    API_REQUESTS_IN_PROGRESS,
)

logger = logging.getLogger("geo.metrics")

# Slow-request threshold in seconds; requests taking at least this long are
# logged at WARNING level by MetricsMiddleware.
SLOW_REQUEST_THRESHOLD = 1.0

# Exact request paths (matched with `in`, NOT as prefixes) that both middlewares
# pass straight through without timing, logging or metrics: high-frequency,
# low-value endpoints such as health checks and API docs.
# e.g. "/docs/oauth2-redirect" is not in this set and is therefore measured.
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico", "/metrics"}

# Canonical 8-4-4-4-12 hex UUID. Path segments matching it are treated as IDs
# so that each distinct UUID does not become its own Prometheus label value.
_UUID_RE = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)


class MetricsMiddleware(BaseHTTPMiddleware):
    """Time each HTTP request (except those in ``_SKIP_PATHS``) and report it.

    For every timed request this middleware:

    - writes the elapsed time to the ``X-Response-Time`` response header,
      formatted as seconds with three decimals (e.g. ``"0.042s"``);
    - logs ``"Slow request detected"`` at WARNING when the duration is at least
      ``SLOW_REQUEST_THRESHOLD``, otherwise ``"Request completed"`` at DEBUG,
      with structured ``extra`` fields ``path``, ``method``, ``duration_ms``,
      ``status_code`` and, when available, ``request_id``;
    - reserves a hook for Sentry integration (see the TODO below).

    The duration covers only the time until ``call_next`` returns. With
    Starlette's ``BaseHTTPMiddleware`` that is presumably when the response has
    started, so body-streaming time is likely excluded — confirm if it matters.
    If ``call_next`` raises, the exception propagates without a header or log
    line from this middleware.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip health checks and similar low-value paths to avoid log noise.
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        start_time = time.perf_counter()
        response = await call_next(request)
        duration = time.perf_counter() - start_time
        duration_ms = round(duration * 1000, 2)

        response.headers["X-Response-Time"] = f"{duration:.3f}s"

        # request_id is expected to be injected into request.state by
        # RequestIdMiddleware; it is attached to the log only when truthy.
        request_id = getattr(request.state, "request_id", None)

        log_extra: dict = {
            "path": request.url.path,
            "method": request.method,
            "duration_ms": duration_ms,
            "status_code": response.status_code,
        }
        if request_id:
            log_extra["request_id"] = request_id

        if duration >= SLOW_REQUEST_THRESHOLD:
            logger.warning("Slow request detected", extra=log_extra)
        else:
            logger.debug("Request completed", extra=log_extra)

        # TODO: integrate Sentry performance monitoring, e.g.
        # if sentry_sdk: sentry_sdk.set_measurement("response_time_ms", duration_ms)
        return response


class MonitoringMiddleware(BaseHTTPMiddleware):
    """Collect Prometheus metrics for each HTTP request (except ``_SKIP_PATHS``).

    Per request it:

    - increments ``API_REQUESTS_IN_PROGRESS`` before calling the app and
      decrements it afterwards (also when the app raises or is cancelled);
    - increments ``API_REQUESTS_TOTAL`` labelled with method, endpoint and
      status (``"500"`` when no response was produced);
    - observes the elapsed seconds in ``API_REQUEST_DURATION_SECONDS``.

    The endpoint label is normalised by ``_get_endpoint_label`` so that ID
    path segments do not create one label value per resource instance.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip excluded paths entirely: no metrics are recorded for them.
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        method = request.method
        endpoint = self._get_endpoint_label(request)

        # Fetch the labelled gauge child once and reuse it for inc/dec
        # (assumes prometheus_client-style metrics, where labels() with the
        # same values always addresses the same series).
        in_progress = API_REQUESTS_IN_PROGRESS.labels(method=method, endpoint=endpoint)
        in_progress.inc()

        # Default to 500 so that `finally` can always label the counter, even
        # when call_next raises a BaseException that is not an Exception
        # (e.g. asyncio.CancelledError). Previously that case left status_code
        # unbound, raised UnboundLocalError inside `finally` (masking the real
        # exception) and skipped the gauge decrement.
        status_code = 500
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
            status_code = response.status_code
        finally:
            duration = time.perf_counter() - start_time
            API_REQUESTS_TOTAL.labels(
                method=method, endpoint=endpoint, status=str(status_code)
            ).inc()
            API_REQUEST_DURATION_SECONDS.labels(
                method=method, endpoint=endpoint
            ).observe(duration)
            in_progress.dec()

        return response

    def _get_endpoint_label(self, request: Request) -> str:
        """Map the request path to a low-cardinality endpoint label.

        Paths of the form ``/api/<version>/<resource>[/<segment>/...]`` map to:

        - ``"<resource>_list"`` when nothing follows the resource;
        - ``"<resource>_detail"`` when the next segment looks like an ID
          (all digits, or a canonical UUID);
        - ``"<resource>_<segment>"`` otherwise.

        Only the first four path segments are inspected. An empty resource
        segment becomes ``"unknown"``. Any other path maps to ``"other"``.

        NOTE(review): non-numeric, non-UUID identifiers (slugs, hashes) still
        produce distinct label values; extend the ID check if such routes exist.
        """
        parts = request.url.path.strip("/").split("/")

        # Pattern: /api/<version>/<resource>/{id|action}
        # The original `len(parts) >= 4` guard made the "list" case unreachable.
        if len(parts) < 3 or parts[0] != "api":
            return "other"

        resource = parts[2] or "unknown"
        if len(parts) == 3:
            return f"{resource}_list"

        action = parts[3]
        if action.isdigit() or _UUID_RE.fullmatch(action):
            return f"{resource}_detail"
        return f"{resource}_{action}"