147 lines
4.6 KiB
Python
147 lines
4.6 KiB
Python
"""请求指标收集中间件:计时、慢请求告警、响应时间响应头、Prometheus指标收集。
|
||
|
||
合并自原 middleware/metrics.py(MetricsMiddleware)和 monitoring/middleware.py(MonitoringMiddleware)。
|
||
"""
|
||
import time
|
||
import logging
|
||
from typing import Callable
|
||
|
||
from fastapi import Request, Response
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
from app.middleware.prometheus_metrics import (
|
||
API_REQUESTS_TOTAL,
|
||
API_REQUEST_DURATION_SECONDS,
|
||
API_REQUESTS_IN_PROGRESS,
|
||
)
|
||
|
||
logger = logging.getLogger("geo.metrics")
|
||
|
||
# 慢请求阈值(秒)
|
||
SLOW_REQUEST_THRESHOLD = 1.0
|
||
|
||
# 跳过指标收集的路径前缀(健康检查等高频低价值路径)
|
||
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico", "/metrics"}
|
||
|
||
|
||
try:
    # Sentry is optional. Resolve it once at import time rather than running
    # the import machinery on every request (which, when the package is
    # missing, means a fresh failed module search each time).
    import sentry_sdk as _sentry_sdk
except ImportError:
    _sentry_sdk = None


class MetricsMiddleware(BaseHTTPMiddleware):
    """Time every HTTP request and report the elapsed time.

    For each request whose path is not in ``_SKIP_PATHS`` this middleware:

    - writes the elapsed time to the ``X-Response-Time`` response header
      (seconds with three decimals, e.g. ``"0.042s"``);
    - logs ``"Slow request detected"`` at WARNING level, with structured
      ``extra`` fields, when the request took at least
      ``SLOW_REQUEST_THRESHOLD`` seconds, and ``"Request completed"`` at
      DEBUG level otherwise;
    - records the elapsed time as the ``response_time_ms`` Sentry
      measurement when ``sentry_sdk`` is installed (best effort).

    Note: with ``BaseHTTPMiddleware``, ``call_next`` returns once the
    downstream app has started its response. For streaming responses, the
    measured time therefore excludes the time spent streaming the body.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request downstream, then record how long it took.

        Args:
            request: The incoming request.
            call_next: Starlette callback that passes the request to the next
                middleware / route handler and returns its response.

        Returns:
            The downstream response. ``X-Response-Time`` is added unless the
            path is in ``_SKIP_PATHS``, in which case the response is returned
            untouched.
        """
        path = request.url.path
        # Skip health checks and similar endpoints to keep the logs quiet.
        if path in _SKIP_PATHS:
            return await call_next(request)

        start_time = time.perf_counter()
        response = await call_next(request)
        duration = time.perf_counter() - start_time
        duration_ms = round(duration * 1000, 2)

        response.headers["X-Response-Time"] = f"{duration:.3f}s"

        log_extra: dict = {
            "path": path,
            "method": request.method,
            "duration_ms": duration_ms,
            "status_code": response.status_code,
        }
        # request_id is injected into request.state by RequestIdMiddleware;
        # it is simply omitted from the log record when absent or empty.
        request_id = getattr(request.state, "request_id", None)
        if request_id:
            log_extra["request_id"] = request_id

        if duration >= SLOW_REQUEST_THRESHOLD:
            logger.warning("Slow request detected", extra=log_extra)
        else:
            logger.debug("Request completed", extra=log_extra)

        self._record_sentry_measurement(duration_ms)
        return response

    @staticmethod
    def _record_sentry_measurement(duration_ms: float) -> None:
        """Best effort: report *duration_ms* to Sentry as ``response_time_ms``.

        Does nothing when ``sentry_sdk`` is not installed. Any error raised by
        Sentry is logged at DEBUG level and never propagates to the request.
        """
        if _sentry_sdk is None:
            return
        try:
            _sentry_sdk.set_measurement("response_time_ms", duration_ms)
        except Exception:
            # Monitoring must never break a request, but leave a trace for
            # debugging instead of discarding the error silently.
            logger.debug("Failed to record Sentry measurement", exc_info=True)
|
||
|
||
|
||
class MonitoringMiddleware(BaseHTTPMiddleware):
    """API monitoring middleware that updates the Prometheus metrics.

    For every request whose path is not in ``_SKIP_PATHS`` it updates:

    - ``API_REQUESTS_TOTAL`` (labels: method, endpoint, status), incremented
      once per request;
    - ``API_REQUEST_DURATION_SECONDS`` (labels: method, endpoint), given the
      elapsed seconds via ``observe``;
    - ``API_REQUESTS_IN_PROGRESS`` (labels: method, endpoint), incremented
      before the request is handled and decremented afterwards, even on error.

    The ``endpoint`` label is a normalized name derived from the URL path, so
    that resource IDs do not create one time series per ID
    (see ``_get_endpoint_label``).
    """

    # Characters allowed in the hex groups of a canonical UUID path segment.
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    # Group lengths of the canonical 8-4-4-4-12 UUID text form.
    _UUID_GROUP_LENGTHS = (8, 4, 4, 4, 12)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request downstream while recording Prometheus metrics.

        Args:
            request: The incoming request.
            call_next: Starlette callback that passes the request to the next
                middleware / route handler and returns its response.

        Returns:
            The downstream response, unchanged.

        Raises:
            Whatever ``call_next`` raises. Such a request is counted with
            ``status="500"``.
        """
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        method = request.method
        endpoint = self._get_endpoint_label(request)

        API_REQUESTS_IN_PROGRESS.labels(method=method, endpoint=endpoint).inc()

        # Pre-set the status so the finally block always has a value, even
        # when call_next raises something that is not an Exception subclass
        # (e.g. asyncio.CancelledError when the client disconnects).
        status_code = 500
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration = time.perf_counter() - start_time
            try:
                API_REQUESTS_TOTAL.labels(
                    method=method,
                    endpoint=endpoint,
                    status=str(status_code),
                ).inc()
                API_REQUEST_DURATION_SECONDS.labels(
                    method=method,
                    endpoint=endpoint,
                ).observe(duration)
            finally:
                # Always release the in-progress slot, even if recording the
                # other metrics fails, so the value cannot drift upward.
                API_REQUESTS_IN_PROGRESS.labels(
                    method=method,
                    endpoint=endpoint,
                ).dec()

    def _get_endpoint_label(self, request: Request) -> str:
        """Return a low-cardinality endpoint label for *request*.

        Paths of the form ``/api/<version>/<resource>[/<action>[/...]]`` map
        to:

        - ``"<resource>_list"`` for the collection path
          ``/api/<version>/<resource>``;
        - ``"<resource>_detail"`` when ``<action>`` looks like an ID (all
          digits, or a canonical 8-4-4-4-12 UUID);
        - ``"<resource>_<action>"`` otherwise.

        Segments after ``<action>`` are ignored. Any other path maps to
        ``"other"``.
        """
        parts = request.url.path.strip("/").split("/")

        # Need at least /api/<version>/<resource>. Exactly three segments is
        # the collection path itself, labelled "<resource>_list".
        if len(parts) < 3 or parts[0] != "api":
            return "other"

        resource = parts[2]
        action = parts[3] if len(parts) > 3 else "list"

        # Collapse IDs so each resource contributes one "detail" series
        # instead of one series per ID.
        if self._is_id_segment(action):
            return f"{resource}_detail"
        return f"{resource}_{action}"

    @classmethod
    def _is_id_segment(cls, segment: str) -> bool:
        """Return True if *segment* looks like a resource ID (digits or UUID)."""
        if segment.isdigit():
            return True
        groups = segment.split("-")
        return tuple(len(group) for group in groups) == cls._UUID_GROUP_LENGTHS and all(
            ch in cls._HEX_DIGITS for group in groups for ch in group
        )
|