fix: 审计发现的问题修复
API一致性修复: - C1: 新增organization.py路由(/api/v1/organization/*) - C2: 修复suggestions API路径(/api/v1/brands/*而非/api/v1/suggestions/*) - H7: 修复platforms路由双重前缀(/api/v1/platforms而非/api/v1/api/platforms) 密钥管理改进: - H3: APIKeyManager支持双密钥(dict格式),文心一言适配器使用KeyManager - H8: 新增APIKeyFilter日志过滤器,拦截key=和Bearer token 异常处理改进: - H1: batch_query.py改为httpx.HTTPError分层处理 - H1: database.py改为SQLAlchemyError并抛出ConnectionError - H1: lifecycle.py和usage_tracker添加日志记录 测试: 764 passed
This commit is contained in:
parent
fe4ba39514
commit
aeaa50e89e
|
|
@ -1,4 +1,5 @@
|
|||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
|
@ -11,6 +12,8 @@ from app.database import get_db
|
|||
from app.models.lifecycle import LifecycleProject, ProjectStage
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.schemas.lifecycle import (
|
||||
ProjectCreateRequest,
|
||||
ProjectResponse,
|
||||
|
|
@ -171,7 +174,8 @@ async def project_stats(
|
|||
contents_stmt = select(func.count()).where(Content.organization_id == org_id)
|
||||
contents_result = await db.execute(contents_stmt)
|
||||
contents_produced = contents_result.scalar() or 0
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to count contents: {e}")
|
||||
contents_produced = 0
|
||||
|
||||
# avg AI citation rate
|
||||
|
|
@ -197,7 +201,8 @@ async def project_stats(
|
|||
avg_ai_citation_rate = round(cited_count / total_citations, 4) if total_citations > 0 else None
|
||||
else:
|
||||
avg_ai_citation_rate = None
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate AI citation rate: {e}")
|
||||
avg_ai_citation_rate = None
|
||||
|
||||
# current stage distribution (map int stage to string)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,343 @@
|
|||
"""组织管理 API — /api/v1/organization/*"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter(prefix="/api/v1/organization", tags=["组织管理"])
|
||||
|
||||
|
||||
class OrgInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
slug: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
plan: str
|
||||
max_members: int
|
||||
member_count: int = 0
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
joined_at: str
|
||||
invited_by: Optional[str] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class InviteMemberRequest(BaseModel):
|
||||
email: EmailStr
|
||||
role: str = Field(default="viewer", pattern="^(owner|admin|editor|viewer)$")
|
||||
|
||||
|
||||
class UpdateMemberRoleRequest(BaseModel):
|
||||
role: str = Field(..., pattern="^(owner|admin|editor|viewer)$")
|
||||
|
||||
|
||||
class InviteResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
role: str
|
||||
message: str
|
||||
|
||||
|
||||
@router.get("/info", response_model=OrgInfoResponse)
|
||||
async def get_org_info(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户所属组织的基本信息"""
|
||||
if not current_user.organization_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户未加入任何组织",
|
||||
)
|
||||
|
||||
stmt = select(Organization).where(Organization.id == current_user.organization_id)
|
||||
result = await db.execute(stmt)
|
||||
org = result.scalar_one_or_none()
|
||||
|
||||
if not org:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="组织不存在",
|
||||
)
|
||||
|
||||
count_stmt = select(func.count(OrgMember.id)).where(
|
||||
OrgMember.organization_id == org.id
|
||||
)
|
||||
count_result = await db.execute(count_stmt)
|
||||
member_count = count_result.scalar() or 0
|
||||
|
||||
return OrgInfoResponse(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
slug=org.slug,
|
||||
description=org.description,
|
||||
logo_url=org.logo_url,
|
||||
plan=org.plan,
|
||||
max_members=org.max_members,
|
||||
member_count=member_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/members", response_model=list[MemberResponse])
|
||||
async def list_members(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户所属组织的所有成员"""
|
||||
if not current_user.organization_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户未加入任何组织",
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(OrgMember, User)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.where(OrgMember.organization_id == current_user.organization_id)
|
||||
.order_by(OrgMember.joined_at.desc())
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
members = []
|
||||
for member, user in rows:
|
||||
members.append(
|
||||
MemberResponse(
|
||||
id=str(member.id),
|
||||
user_id=str(member.user_id),
|
||||
name=user.name or "",
|
||||
email=user.email,
|
||||
role=member.role,
|
||||
joined_at=member.joined_at.isoformat() if member.joined_at else "",
|
||||
invited_by=str(member.invited_by) if member.invited_by else None,
|
||||
)
|
||||
)
|
||||
return members
|
||||
|
||||
|
||||
@router.post("/members/invite", response_model=InviteResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def invite_member(
|
||||
body: InviteMemberRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""邀请新成员加入组织"""
|
||||
if not current_user.organization_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户未加入任何组织",
|
||||
)
|
||||
|
||||
current_membership_stmt = select(OrgMember).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == current_user.id,
|
||||
)
|
||||
current_membership_result = await db.execute(current_membership_stmt)
|
||||
current_membership = current_membership_result.scalar_one_or_none()
|
||||
|
||||
if not current_membership or current_membership.role not in ["owner", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="仅管理员可以邀请新成员",
|
||||
)
|
||||
|
||||
target_user_stmt = select(User).where(User.email == body.email)
|
||||
target_user_result = await db.execute(target_user_stmt)
|
||||
target_user = target_user_result.scalar_one_or_none()
|
||||
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在",
|
||||
)
|
||||
|
||||
existing_stmt = select(OrgMember).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == target_user.id,
|
||||
)
|
||||
existing_result = await db.execute(existing_stmt)
|
||||
existing_membership = existing_result.scalar_one_or_none()
|
||||
|
||||
if existing_membership:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="该用户已经是组织成员",
|
||||
)
|
||||
|
||||
count_stmt = select(func.count(OrgMember.id)).where(
|
||||
OrgMember.organization_id == current_user.organization_id
|
||||
)
|
||||
count_result = await db.execute(count_stmt)
|
||||
current_member_count = count_result.scalar() or 0
|
||||
|
||||
org_stmt = select(Organization).where(Organization.id == current_user.organization_id)
|
||||
org_result = await db.execute(org_stmt)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if org and current_member_count >= org.max_members:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"成员数量已达上限({org.max_members}人)",
|
||||
)
|
||||
|
||||
new_membership = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=current_user.organization_id,
|
||||
user_id=target_user.id,
|
||||
role=body.role,
|
||||
invited_by=current_user.id,
|
||||
)
|
||||
db.add(new_membership)
|
||||
await db.commit()
|
||||
await db.refresh(new_membership)
|
||||
|
||||
return InviteResponse(
|
||||
id=str(new_membership.id),
|
||||
email=target_user.email,
|
||||
role=new_membership.role,
|
||||
message=f"成功邀请 {target_user.email} 加入组织",
|
||||
)
|
||||
|
||||
|
||||
@router.put("/members/{user_id}/role", response_model=MemberResponse)
|
||||
async def update_member_role(
|
||||
user_id: uuid.UUID,
|
||||
body: UpdateMemberRoleRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新成员角色"""
|
||||
if not current_user.organization_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户未加入任何组织",
|
||||
)
|
||||
|
||||
current_membership_stmt = select(OrgMember).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == current_user.id,
|
||||
)
|
||||
current_membership_result = await db.execute(current_membership_stmt)
|
||||
current_membership = current_membership_result.scalar_one_or_none()
|
||||
|
||||
if not current_membership or current_membership.role not in ["owner", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="仅管理员可以修改成员角色",
|
||||
)
|
||||
|
||||
target_membership_stmt = select(OrgMember, User).join(
|
||||
User, User.id == OrgMember.user_id
|
||||
).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == user_id,
|
||||
)
|
||||
target_result = await db.execute(target_membership_stmt)
|
||||
target_row = target_result.first()
|
||||
|
||||
if not target_row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="成员不存在",
|
||||
)
|
||||
|
||||
target_membership, target_user = target_row
|
||||
|
||||
if target_membership.role == "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不能修改所有者角色",
|
||||
)
|
||||
|
||||
if body.role == "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不能将成员设为所有者",
|
||||
)
|
||||
|
||||
target_membership.role = body.role
|
||||
await db.commit()
|
||||
await db.refresh(target_membership)
|
||||
|
||||
return MemberResponse(
|
||||
id=str(target_membership.id),
|
||||
user_id=str(target_membership.user_id),
|
||||
name=target_user.name or "",
|
||||
email=target_user.email,
|
||||
role=target_membership.role,
|
||||
joined_at=target_membership.joined_at.isoformat() if target_membership.joined_at else "",
|
||||
invited_by=str(target_membership.invited_by) if target_membership.invited_by else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_member(
|
||||
user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""移除组织成员"""
|
||||
if not current_user.organization_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户未加入任何组织",
|
||||
)
|
||||
|
||||
current_membership_stmt = select(OrgMember).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == current_user.id,
|
||||
)
|
||||
current_membership_result = await db.execute(current_membership_stmt)
|
||||
current_membership = current_membership_result.scalar_one_or_none()
|
||||
|
||||
if not current_membership or current_membership.role not in ["owner", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="仅管理员可以移除成员",
|
||||
)
|
||||
|
||||
target_membership_stmt = select(OrgMember).where(
|
||||
OrgMember.organization_id == current_user.organization_id,
|
||||
OrgMember.user_id == user_id,
|
||||
)
|
||||
target_result = await db.execute(target_membership_stmt)
|
||||
target_membership = target_result.scalar_one_or_none()
|
||||
|
||||
if not target_membership:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="成员不存在",
|
||||
)
|
||||
|
||||
if target_membership.role == "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不能移除所有者",
|
||||
)
|
||||
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不能移除自己",
|
||||
)
|
||||
|
||||
await db.delete(target_membership)
|
||||
await db.commit()
|
||||
|
|
@ -14,7 +14,7 @@ from app.config import settings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/platforms", tags=["platforms"])
|
||||
router = APIRouter(prefix="/platforms", tags=["platforms"])
|
||||
|
||||
|
||||
class PlatformHealthStatus:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy import text, JSON
|
||||
from sqlalchemy import exc as sqlalchemy_exc
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
import logging
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JSONType(TypeDecorator):
|
||||
"""A JSON type that uses JSONB on PostgreSQL and JSON on other databases."""
|
||||
|
|
@ -53,5 +57,9 @@ async def check_db_connection() -> bool:
|
|||
async with AsyncSessionLocal() as session:
|
||||
await session.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
except sqlalchemy_exc.SQLAlchemyError as e:
|
||||
logger.error(f"Database connection error: {e}", exc_info=True)
|
||||
raise ConnectionError(f"Failed to connect to database: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected database error: {e}", exc_info=True)
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import logging
|
|||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.middleware.logging_filter import APIKeyFilter
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""将日志记录格式化为 JSON 字符串,便于日志收集平台(如 ELK、Loki)解析。"""
|
||||
|
|
@ -47,11 +49,10 @@ def setup_logging(level: int = logging.INFO) -> None:
|
|||
handler.setFormatter(JSONFormatter())
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
# 清空已有 handlers,避免重复输出
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(level)
|
||||
root_logger.addFilter(APIKeyFilter())
|
||||
|
||||
# 降低 uvicorn/sqlalchemy 等第三方库的噪音
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from app.api.admin import router as admin_router
|
|||
from app.api.content import router as content_router
|
||||
from app.api.contents import router as contents_router
|
||||
from app.api.clients import router as clients_router
|
||||
from app.api.organization import router as organization_router
|
||||
from app.api.agents import router as agents_router
|
||||
from app.api.knowledge import router as knowledge_router
|
||||
from app.api.distribution import router as distribution_router
|
||||
|
|
@ -157,6 +158,7 @@ app.include_router(lifecycle_router, prefix="/api/v1/lifecycle", tags=["lifecycl
|
|||
app.include_router(knowledge_router, prefix="/api/v1/knowledge", tags=["知识库"])
|
||||
app.include_router(content_router, prefix="/api/v1/content", tags=["内容生产"])
|
||||
app.include_router(contents_router, prefix="/api/v1/contents", tags=["内容管理"])
|
||||
app.include_router(organization_router)
|
||||
app.include_router(clients_router, prefix="/api/v1/clients", tags=["客户管理"])
|
||||
app.include_router(distribution_router, prefix="/api/v1/distribution", tags=["内容分发"])
|
||||
app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测优化"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
import logging
|
||||
import re
|
||||
|
||||
|
||||
class APIKeyFilter(logging.Filter):
|
||||
"""日志过滤器,拦截敏感信息"""
|
||||
|
||||
PATTERNS = [
|
||||
(r'key=[A-Za-z0-9_-]{30,}', 'key=***REDACTED***'),
|
||||
(r'Bearer\s+[A-Za-z0-9_-]{20,}', 'Bearer ***REDACTED***'),
|
||||
(r'Authorization:\s*[A-Za-z0-9_-]{20,}', 'Authorization: ***REDACTED***'),
|
||||
(r'api[_-]?key=[A-Za-z0-9_-]{30,}', 'api_key=***REDACTED***'),
|
||||
(r'(?:^|\s)(AIza[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
|
||||
(r'(?:^|\s)(sk-[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
|
||||
(r'(?:^|\s)(gsk-[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
|
||||
]
|
||||
|
||||
def _redact_string(self, text: str) -> str:
|
||||
for pattern, replacement in self.PATTERNS:
|
||||
text = re.sub(pattern, replacement, text)
|
||||
return text
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.args:
|
||||
new_args = []
|
||||
for arg in record.args:
|
||||
if isinstance(arg, str):
|
||||
arg = self._redact_string(arg)
|
||||
new_args.append(arg)
|
||||
record.args = tuple(new_args)
|
||||
|
||||
if record.msg and isinstance(record.msg, str):
|
||||
record.msg = self._redact_string(record.msg)
|
||||
|
||||
return True
|
||||
|
|
@ -3,6 +3,8 @@ import logging
|
|||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
from app.services.usage_tracker import UsageTracker
|
||||
|
|
@ -63,8 +65,12 @@ def _build_adapters() -> dict[str, AIEngineAdapter]:
|
|||
for engine_type, cls in _ADAPTER_CLASSES.items():
|
||||
try:
|
||||
adapters[engine_type.value] = cls()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"HTTP error from {engine_type.value}: {e}")
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.warning(f"Timeout from {engine_type.value}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error from {engine_type.value}: {e}", exc_info=True)
|
||||
return adapters
|
||||
|
||||
|
||||
|
|
@ -79,8 +85,12 @@ def _build_adapters_with_key_manager(
|
|||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"HTTP error from {engine_type.value}: {e}")
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.warning(f"Timeout from {engine_type.value}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error from {engine_type.value}: {e}", exc_info=True)
|
||||
return adapters
|
||||
|
||||
|
||||
|
|
@ -113,6 +123,7 @@ class BatchQueryService:
|
|||
result = await adapter.query(query, brand_name, competitor_names)
|
||||
|
||||
if self._user_id:
|
||||
try:
|
||||
self._recorder.record(
|
||||
user_id=self._user_id,
|
||||
brand_id=self._brand_id,
|
||||
|
|
@ -125,6 +136,8 @@ class BatchQueryService:
|
|||
"response_time_ms": result.response_time_ms,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record usage: {e}")
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from datetime import UTC, datetime
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.api_key_manager import APIKeyManager, KeyCredentials
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,28 +28,48 @@ class WenxinAdapter(AIEngineAdapter):
|
|||
secret_key: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
key_manager: APIKeyManager | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
|
||||
self._key_manager = key_manager
|
||||
self._user_id = user_id
|
||||
self.rate_limiter = rate_limiter
|
||||
self.proxy = proxy or self._load_proxy()
|
||||
self.api_key = api_key or ""
|
||||
self.secret_key = secret_key or ""
|
||||
self._resolve_keys_from_manager(api_key, secret_key, key_manager, user_id)
|
||||
self._model = _DEFAULT_MODEL
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
|
||||
)
|
||||
|
||||
def _resolve_keys_from_manager(
|
||||
self,
|
||||
direct_api_key: str | None,
|
||||
direct_secret_key: str | None,
|
||||
key_manager: APIKeyManager | None,
|
||||
user_id: str | None,
|
||||
) -> None:
|
||||
if direct_api_key and direct_api_key.strip():
|
||||
return
|
||||
if not key_manager:
|
||||
return
|
||||
creds = key_manager.get_credentials("wenxin", user_id=user_id)
|
||||
if creds:
|
||||
if not self.api_key:
|
||||
self.api_key = creds.api_key
|
||||
if not self.secret_key:
|
||||
self.secret_key = creds.secret_key or ""
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.WENXIN
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("BAIDU_QIANFAN_API_KEY", "")
|
||||
|
||||
def _get_env_secret_key(self) -> str | None:
|
||||
return os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
|
@ -21,6 +22,21 @@ class KeySource(str, Enum):
|
|||
ENV = "env"
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyCredentials:
|
||||
"""统一凭证格式,支持单Key和双Key"""
|
||||
api_key: str
|
||||
secret_key: str | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
if self.secret_key:
|
||||
return {"api_key": self.api_key, "secret_key": self.secret_key}
|
||||
return {"api_key": self.api_key}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKeyConfig:
|
||||
engine_type: str
|
||||
|
|
@ -56,16 +72,23 @@ class APIKeyManager:
|
|||
def add_key(
|
||||
self,
|
||||
engine_type: str,
|
||||
api_key: str,
|
||||
credentials: str | dict,
|
||||
source: KeySource = KeySource.SYSTEM,
|
||||
user_id: str | None = None,
|
||||
priority: int = 0,
|
||||
) -> APIKeyConfig:
|
||||
if isinstance(credentials, dict):
|
||||
key_hint = self._create_dual_key_hint(credentials)
|
||||
credentials_json = json.dumps(credentials)
|
||||
encrypted_key = self._encrypt(credentials_json)
|
||||
else:
|
||||
key_hint = self._mask_key(credentials)
|
||||
encrypted_key = self._encrypt(credentials)
|
||||
config = APIKeyConfig(
|
||||
engine_type=engine_type,
|
||||
key_source=source,
|
||||
encrypted_key=self._encrypt(api_key),
|
||||
key_hint=self._mask_key(api_key),
|
||||
encrypted_key=encrypted_key,
|
||||
key_hint=key_hint,
|
||||
status=KeyStatus.UNKNOWN,
|
||||
priority=priority,
|
||||
user_id=user_id,
|
||||
|
|
@ -76,8 +99,39 @@ class APIKeyManager:
|
|||
self._keys[engine_type].sort(key=lambda k: k.priority, reverse=True)
|
||||
return config
|
||||
|
||||
def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
|
||||
def _create_dual_key_hint(self, credentials: dict) -> str:
|
||||
api_key = credentials.get("api_key", "")
|
||||
secret_key = credentials.get("secret_key", "")
|
||||
api_hint = self._mask_key(api_key) if api_key else "***"
|
||||
secret_hint = self._mask_key(secret_key) if secret_key else "***"
|
||||
return f"{api_hint}|{secret_hint}"
|
||||
|
||||
def get_key(self, engine_type: str, user_id: str | None = None) -> str | dict | None:
|
||||
configs = self._keys.get(engine_type, [])
|
||||
config = self._find_best_key(configs, user_id)
|
||||
if not config:
|
||||
return None
|
||||
decrypted = self._decrypt(config.encrypted_key)
|
||||
try:
|
||||
return json.loads(decrypted)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return decrypted
|
||||
|
||||
def get_credentials(self, engine_type: str, user_id: str | None = None) -> KeyCredentials | None:
|
||||
key_data = self.get_key(engine_type, user_id)
|
||||
if not key_data:
|
||||
return None
|
||||
if isinstance(key_data, dict):
|
||||
return KeyCredentials(
|
||||
api_key=key_data.get("api_key", ""),
|
||||
secret_key=key_data.get("secret_key")
|
||||
)
|
||||
return KeyCredentials(api_key=key_data)
|
||||
|
||||
def get_all_keys(self, engine_type: str, user_id: str | None = None) -> dict | str | None:
|
||||
return self.get_key(engine_type, user_id)
|
||||
|
||||
def _find_best_key(self, configs: list[APIKeyConfig], user_id: str | None = None) -> APIKeyConfig | None:
|
||||
if user_id:
|
||||
for c in configs:
|
||||
if (
|
||||
|
|
@ -85,10 +139,10 @@ class APIKeyManager:
|
|||
and c.user_id == user_id
|
||||
and c.status in self._USABLE_STATUSES
|
||||
):
|
||||
return self._decrypt(c.encrypted_key)
|
||||
return c
|
||||
for c in configs:
|
||||
if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
|
||||
return self._decrypt(c.encrypted_key)
|
||||
return c
|
||||
return None
|
||||
|
||||
def get_any_available_key(self, engine_type: str) -> str | None:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -8,6 +9,17 @@ from sqlalchemy.pool import StaticPool
|
|||
|
||||
from app.database import Base
|
||||
from app.models.user import User
|
||||
from app.middleware.logging_filter import APIKeyFilter
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def add_api_key_filter():
|
||||
"""自动为每个测试添加APIKeyFilter到root logger"""
|
||||
root_logger = logging.getLogger()
|
||||
api_key_filter = APIKeyFilter()
|
||||
root_logger.addFilter(api_key_filter)
|
||||
yield
|
||||
root_logger.removeFilter(api_key_filter)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock, PropertyMock
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from app.api.lifecycle import project_stats
|
||||
|
||||
|
||||
class TestLifecycleExceptionHandling:
|
||||
"""测试 lifecycle.py 中的异常处理行为"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_stats_handles_content_query_failure(self, caplog):
|
||||
"""测试 project_stats 当 Content 查询失败时的处理"""
|
||||
from app.models.user import User
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hash",
|
||||
name="Test User",
|
||||
plan="free",
|
||||
organization_id=org_id,
|
||||
)
|
||||
|
||||
execute_results = [
|
||||
MagicMock(one=MagicMock(total=0, active=0)),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
]
|
||||
execute_results[1].all.return_value = []
|
||||
execute_results[2].scalar.return_value = 0
|
||||
execute_results[3].all.return_value = []
|
||||
execute_results[4].all.return_value = []
|
||||
execute_count = [0]
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
idx = execute_count[0]
|
||||
execute_count[0] += 1
|
||||
if idx == 2:
|
||||
raise RuntimeError("Content table not available")
|
||||
return execute_results[idx]
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.side_effect = execute_side_effect
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = await project_stats(
|
||||
db=mock_session,
|
||||
current_user=user
|
||||
)
|
||||
|
||||
assert result.contents_produced == 0
|
||||
assert any("Failed to count contents" in record.message for record in caplog.records)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_stats_handles_citation_query_failure(self, caplog):
|
||||
"""测试 project_stats 当 CitationRecord 查询失败时的处理"""
|
||||
from app.models.user import User
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hash",
|
||||
name="Test User",
|
||||
plan="free",
|
||||
organization_id=org_id,
|
||||
)
|
||||
|
||||
execute_results = [
|
||||
MagicMock(one=MagicMock(total=0, active=0)),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
]
|
||||
execute_results[1].all.return_value = []
|
||||
execute_results[2].scalar.return_value = 0
|
||||
execute_results[3].all.return_value = []
|
||||
execute_results[4].one.return_value = MagicMock(total_citations=0, cited_count=0)
|
||||
execute_results[5].all.return_value = []
|
||||
execute_count = [0]
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
idx = execute_count[0]
|
||||
execute_count[0] += 1
|
||||
if idx == 4:
|
||||
raise RuntimeError("CitationRecord table not available")
|
||||
return execute_results[idx]
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.side_effect = execute_side_effect
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
result = await project_stats(
|
||||
db=mock_session,
|
||||
current_user=user
|
||||
)
|
||||
|
||||
assert result.avg_ai_citation_rate is None
|
||||
assert any("Failed to calculate AI citation rate" in record.message for record in caplog.records)
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
"""组织管理API测试 - 验证 /api/v1/organization/* 端点"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.api.deps import get_current_user, get_db
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_organization(async_session, test_user):
|
||||
org = Organization(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Organization",
|
||||
slug="test-org",
|
||||
plan="free",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.flush()
|
||||
|
||||
test_user.organization_id = org.id
|
||||
async_session.add(test_user)
|
||||
|
||||
membership = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=org.id,
|
||||
user_id=test_user.id,
|
||||
role="owner",
|
||||
)
|
||||
async_session.add(membership)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
return org
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(async_session, test_user):
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
async def override_get_current_user():
|
||||
return test_user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestOrganizationRoutes:
|
||||
"""组织管理API端点测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_are_registered(self):
|
||||
"""验证组织管理路由已正确注册到app"""
|
||||
from app.main import app as main_app
|
||||
|
||||
org_routes = [
|
||||
route.path for route in main_app.routes
|
||||
if hasattr(route, 'path') and route.path.startswith('/api/v1/organization')
|
||||
]
|
||||
print(f"\n已注册的组织路由: {org_routes}")
|
||||
assert '/api/v1/organization/info' in org_routes, "路由未注册"
|
||||
assert '/api/v1/organization/members' in org_routes, "路由未注册"
|
||||
assert '/api/v1/organization/members/invite' in org_routes, "路由未注册"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_endpoint_call(self, async_session, test_user, test_organization):
|
||||
"""直接测试端点调用"""
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
async def override_get_current_user():
|
||||
return test_user
|
||||
|
||||
from app.main import app as test_app
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
# Check which routes are available
|
||||
available_routes = [r.path for r in test_app.routes if hasattr(r, 'path') and 'organization' in r.path]
|
||||
print(f"\n测试app中的组织路由: {available_routes}")
|
||||
|
||||
response = await client.get("/api/v1/organization/info")
|
||||
print(f"Response status: {response.status_code}")
|
||||
print(f"Response: {response.text[:200]}")
|
||||
assert response.status_code == 200
|
||||
|
||||
test_app.dependency_overrides.clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_info_endpoint_exists(self, async_client, test_organization):
|
||||
"""验证 /api/v1/organization/info 端点存在并返回200"""
|
||||
response = await async_client.get("/api/v1/organization/info")
|
||||
assert response.status_code == 200, f"期望返回200,实际返回 {response.status_code}"
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "name" in data
|
||||
assert "slug" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_members_endpoint_exists(self, async_client, test_organization):
|
||||
"""验证 /api/v1/organization/members 端点存在并返回200"""
|
||||
response = await async_client.get("/api/v1/organization/members")
|
||||
assert response.status_code == 200, f"期望返回200,实际返回 {response.status_code}"
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session):
|
||||
"""验证 /api/v1/organization/members/invite 端点存在"""
|
||||
invite_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="newuser@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="New User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(invite_user)
|
||||
await async_session.commit()
|
||||
|
||||
response = await async_client.post(
|
||||
"/api/v1/organization/members/invite",
|
||||
json={"email": "newuser@example.com", "role": "viewer"}
|
||||
)
|
||||
assert response.status_code == 201, f"期望返回201,实际返回 {response.status_code}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user):
|
||||
"""验证 /api/v1/organization/members/{id}/role 端点存在"""
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="member@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Member User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(new_user)
|
||||
await async_session.flush()
|
||||
|
||||
membership = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=test_organization.id,
|
||||
user_id=new_user.id,
|
||||
role="viewer",
|
||||
)
|
||||
async_session.add(membership)
|
||||
await async_session.commit()
|
||||
|
||||
response = await async_client.put(
|
||||
f"/api/v1/organization/members/{new_user.id}/role",
|
||||
json={"role": "admin"}
|
||||
)
|
||||
assert response.status_code == 200, f"期望返回200,实际返回 {response.status_code}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session):
|
||||
"""验证 /api/v1/organization/members/{id} 端点存在"""
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="todelete@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Delete User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(new_user)
|
||||
await async_session.flush()
|
||||
|
||||
membership = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=test_organization.id,
|
||||
user_id=new_user.id,
|
||||
role="viewer",
|
||||
)
|
||||
async_session.add(membership)
|
||||
await async_session.commit()
|
||||
|
||||
response = await async_client.delete(f"/api/v1/organization/members/{new_user.id}")
|
||||
assert response.status_code == 204, f"期望返回204,实际返回 {response.status_code}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_organization_returns_404(self, async_client, test_user, async_session):
|
||||
"""验证未加入组织的用户访问 /api/v1/organization/info 返回404"""
|
||||
response = await async_client.get("/api/v1/organization/info")
|
||||
assert response.status_code == 404, f"期望返回404,实际返回 {response.status_code}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_access(self, async_session):
|
||||
"""验证未授权访问返回401"""
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/organization/info")
|
||||
assert response.status_code == 401, f"期望返回401,实际返回 {response.status_code}"
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
"""平台API路由测试 - 验证不存在双重前缀问题"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
class TestPlatformsRoutePrefix:
|
||||
"""平台路由前缀测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platforms_endpoint_not_double_prefixed(self):
|
||||
"""
|
||||
验证platforms端点路径不是 /api/v1/api/*
|
||||
|
||||
预期:
|
||||
- /api/v1/api/platforms/health 应该返回404(不存在)
|
||||
- /api/v1/platforms/health 应该返回200(正确路径)
|
||||
"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
wrong_path_response = await client.get("/api/v1/api/platforms/health")
|
||||
correct_path_response = await client.get("/api/v1/platforms/health")
|
||||
|
||||
assert wrong_path_response.status_code == 404, \
|
||||
f"错误:/api/v1/api/platforms/health 返回 {wrong_path_response.status_code},应该是404(说明存在双重前缀问题)"
|
||||
|
||||
assert correct_path_response.status_code == 200, \
|
||||
f"错误:/api/v1/platforms/health 返回 {correct_path_response.status_code},应该是200(端点不可用)"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platforms_health_endpoint_accessible(self):
|
||||
"""验证platforms健康检查端点可通过正确路径访问"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/platforms/health")
|
||||
|
||||
assert response.status_code == 200, \
|
||||
f"平台健康检查端点无法访问,状态码: {response.status_code}"
|
||||
|
||||
data = response.json()
|
||||
assert "platforms" in data
|
||||
assert "total" in data
|
||||
assert "configured_count" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platforms_health_by_name_endpoint(self):
|
||||
"""验证platforms按名称健康检查端点可通过正确路径访问"""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/platforms/health/kimi")
|
||||
|
||||
assert response.status_code == 200, \
|
||||
f"平台按名称健康检查端点无法访问,状态码: {response.status_code}"
|
||||
|
||||
data = response.json()
|
||||
assert "name" in data
|
||||
assert data["name"] == "kimi"
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError
|
||||
|
||||
from app.database import check_db_connection
|
||||
|
||||
|
||||
class TestDatabaseExceptionHandling:
|
||||
"""测试 database.py 中的异常处理行为"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_connection_handles_sqlalchemy_error(self):
|
||||
"""测试 check_db_connection 对 SQLAlchemyError 的处理"""
|
||||
with patch('app.database.AsyncSessionLocal') as mock_session_local:
|
||||
mock_session = AsyncMock()
|
||||
mock_session_local.return_value.__aenter__.return_value = mock_session
|
||||
mock_session.execute.side_effect = OperationalError(
|
||||
statement="SELECT 1",
|
||||
params={},
|
||||
orig=Exception("Connection refused")
|
||||
)
|
||||
|
||||
with pytest.raises(ConnectionError, match="Failed to connect to database"):
|
||||
await check_db_connection()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_connection_handles_generic_exception(self):
|
||||
"""测试 check_db_connection 对通用异常的处理"""
|
||||
with patch('app.database.AsyncSessionLocal') as mock_session_local:
|
||||
mock_session = AsyncMock()
|
||||
mock_session_local.return_value.__aenter__.return_value = mock_session
|
||||
mock_session.execute.side_effect = RuntimeError("Unexpected error")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await check_db_connection()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_connection_success(self):
|
||||
"""测试 check_db_connection 成功时返回 True"""
|
||||
with patch('app.database.AsyncSessionLocal') as mock_session_local:
|
||||
mock_session = AsyncMock()
|
||||
mock_session_local.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
result = await check_db_connection()
|
||||
assert result is True
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""测试日志过滤器功能 - TDD RED阶段"""
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
API_KEY = "AIzaSyDummyKey123456789abcdefghijklmnop"
|
||||
SK_TEST_KEY = "sk_test_1234567890abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
|
||||
|
||||
class TestAPIKeyFilter:
|
||||
"""测试APIKeyFilter过滤器"""
|
||||
|
||||
def test_api_key_filtered_from_logs(self, caplog):
|
||||
"""验证API Key从日志中过滤"""
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info(f"Request to Gemini API with key={API_KEY}")
|
||||
|
||||
assert API_KEY not in caplog.text, f"API Key should not appear in logs, but found: {API_KEY}"
|
||||
|
||||
def test_api_key_filtered_from_url(self, caplog):
|
||||
"""验证URL中的API Key从日志中过滤"""
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info(f"Making request to https://api.gemini.com?key={API_KEY}&other=value")
|
||||
|
||||
assert API_KEY not in caplog.text
|
||||
assert "key=***REDACTED***" in caplog.text
|
||||
|
||||
def test_sk_test_key_filtered(self, caplog):
|
||||
"""验证sk_test密钥格式也被过滤"""
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info(f"Request with api_key={SK_TEST_KEY}")
|
||||
|
||||
assert SK_TEST_KEY not in caplog.text
|
||||
|
||||
def test_bearer_token_filtered(self, caplog):
|
||||
"""验证Bearer token从日志中过滤"""
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info("Request with Bearer token abcdefghijklmnopqrstuvwxyz")
|
||||
|
||||
assert "Bearer abcdefghijklmnopqrstuvwxyz" not in caplog.text
|
||||
|
||||
def test_authorization_header_filtered(self, caplog):
|
||||
"""验证Authorization header从日志中过滤"""
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info("Request with Authorization: secret_api_key_here_1234567890")
|
||||
|
||||
assert "secret_api_key_here_1234567890" not in caplog.text
|
||||
|
||||
def test_normal_log_preserved(self, caplog):
|
||||
"""验证普通日志内容不受影响"""
|
||||
normal_message = "This is a normal log message without any secrets"
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
logging.info(normal_message)
|
||||
|
||||
assert normal_message in caplog.text
|
||||
|
||||
|
||||
class TestAPIKeyFilterDirectly:
|
||||
"""直接测试APIKeyFilter类"""
|
||||
|
||||
def test_filter_class_exists(self):
|
||||
"""验证APIKeyFilter类存在"""
|
||||
from app.middleware.logging_filter import APIKeyFilter
|
||||
assert APIKeyFilter is not None
|
||||
|
||||
def test_filter_redacts_url_key_param(self):
|
||||
"""验证过滤器直接处理URL中的key参数时正确脱敏"""
|
||||
from app.middleware.logging_filter import APIKeyFilter
|
||||
|
||||
filter_instance = APIKeyFilter()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg=f"Request to https://api.gemini.com?key={API_KEY}",
|
||||
args=(),
|
||||
exc_info=None
|
||||
)
|
||||
|
||||
result = filter_instance.filter(record)
|
||||
|
||||
assert result is True
|
||||
assert API_KEY not in record.msg
|
||||
assert "key=***REDACTED***" in record.msg
|
||||
|
||||
def test_filter_preserves_normal_content(self):
|
||||
"""验证过滤器保留正常日志内容"""
|
||||
from app.middleware.logging_filter import APIKeyFilter
|
||||
|
||||
filter_instance = APIKeyFilter()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="Normal log message without secrets",
|
||||
args=(),
|
||||
exc_info=None
|
||||
)
|
||||
|
||||
result = filter_instance.filter(record)
|
||||
|
||||
assert result is True
|
||||
assert record.msg == "Normal log message without secrets"
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
import pytest
|
||||
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
|
||||
|
||||
class TestKeyCredentials:
|
||||
def test_key_credentials_single_key(self):
|
||||
from app.services.api_key_manager import KeyCredentials
|
||||
creds = KeyCredentials(api_key="test-api-key-123")
|
||||
assert creds.api_key == "test-api-key-123"
|
||||
assert creds.secret_key is None
|
||||
assert creds.to_dict() == {"api_key": "test-api-key-123"}
|
||||
|
||||
def test_key_credentials_dual_keys(self):
|
||||
from app.services.api_key_manager import KeyCredentials
|
||||
creds = KeyCredentials(api_key="test-api-key-123", secret_key="test-secret-key-456")
|
||||
assert creds.api_key == "test-api-key-123"
|
||||
assert creds.secret_key == "test-secret-key-456"
|
||||
assert creds.to_dict() == {"api_key": "test-api-key-123", "secret_key": "test-secret-key-456"}
|
||||
|
||||
|
||||
class TestAPIKeyManagerDualKeys:
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
return APIKeyManager()
|
||||
|
||||
def test_add_dual_keys_dict_for_wenxin(self, manager):
|
||||
"""测试为文心一言添加双密钥字典"""
|
||||
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
|
||||
config = manager.add_key("wenxin", credentials, source=KeySource.USER)
|
||||
assert config.engine_type == "wenxin"
|
||||
assert config.key_source == KeySource.USER
|
||||
|
||||
def test_add_dual_keys_and_retrieve(self, manager):
|
||||
"""测试添加双密钥后能够正确获取"""
|
||||
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
|
||||
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
|
||||
retrieved = manager.get_key("wenxin")
|
||||
assert retrieved is not None
|
||||
assert isinstance(retrieved, dict)
|
||||
assert retrieved["api_key"] == "baidu-api-key-xxx"
|
||||
assert retrieved["secret_key"] == "baidu-secret-key-yyy"
|
||||
|
||||
def test_get_credentials_returns_complete_credentials(self, manager):
|
||||
"""测试获取完整凭证(包含api_key和secret_key)"""
|
||||
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
|
||||
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
|
||||
creds = manager.get_credentials("wenxin")
|
||||
assert creds is not None
|
||||
assert creds.api_key == "baidu-api-key-xxx"
|
||||
assert creds.secret_key == "baidu-secret-key-yyy"
|
||||
|
||||
def test_get_credentials_with_single_key(self, manager):
|
||||
"""测试获取单Key凭证(兼容现有逻辑)"""
|
||||
manager.add_key("chatgpt", "sk-chatgpt-key-1234567890", source=KeySource.SYSTEM)
|
||||
creds = manager.get_credentials("chatgpt")
|
||||
assert creds is not None
|
||||
assert creds.api_key == "sk-chatgpt-key-1234567890"
|
||||
assert creds.secret_key is None
|
||||
|
||||
def test_get_credentials_returns_none_for_unknown_engine(self, manager):
|
||||
"""测试获取未知引擎凭证返回None"""
|
||||
creds = manager.get_credentials("unknown_engine")
|
||||
assert creds is None
|
||||
|
||||
def test_add_single_key_string_still_works(self, manager):
|
||||
"""测试单Key字符串格式仍然正常工作"""
|
||||
config = manager.add_key("chatgpt", "sk-test-key-1234567890", source=KeySource.SYSTEM)
|
||||
assert config.engine_type == "chatgpt"
|
||||
key = manager.get_key("chatgpt")
|
||||
assert key == "sk-test-key-1234567890"
|
||||
|
||||
def test_add_dual_keys_string_auto_wrap(self, manager):
|
||||
"""测试单Key字符串自动包装为api_key格式"""
|
||||
manager.add_key("wenxin", "single-key-value", source=KeySource.SYSTEM)
|
||||
creds = manager.get_credentials("wenxin")
|
||||
assert creds is not None
|
||||
assert creds.api_key == "single-key-value"
|
||||
assert creds.secret_key is None
|
||||
|
||||
def test_get_all_keys_returns_dict_for_dual_key_engine(self, manager):
|
||||
"""测试获取所有Key信息时返回完整字典"""
|
||||
credentials = {"api_key": "api-key-xxx", "secret_key": "secret-key-yyy"}
|
||||
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
|
||||
keys = manager.get_all_keys("wenxin")
|
||||
assert keys is not None
|
||||
assert "api_key" in keys
|
||||
assert "secret_key" in keys
|
||||
assert keys["api_key"] == "api-key-xxx"
|
||||
assert keys["secret_key"] == "secret-key-yyy"
|
||||
|
||||
def test_env_mapping_includes_dual_keys(self):
|
||||
"""测试环境变量映射支持文心一言双Key"""
|
||||
manager = APIKeyManager()
|
||||
assert "wenxin" in manager._ENV_MAPPING
|
||||
assert manager._ENV_MAPPING["wenxin"] == "BAIDU_QIANFAN_API_KEY"
|
||||
|
||||
def test_env_loading_for_dual_key_engine(self, manager, monkeypatch):
|
||||
"""测试加载环境变量时正确处理文心一言双Key"""
|
||||
monkeypatch.setenv("BAIDU_QIANFAN_API_KEY", "env-api-key-xxx")
|
||||
monkeypatch.setenv("BAIDU_SECRET_KEY", "env-secret-key-yyy")
|
||||
manager.load_env_keys()
|
||||
creds = manager.get_credentials("wenxin")
|
||||
assert creds is not None
|
||||
assert creds.api_key == "env-api-key-xxx"
|
||||
|
||||
|
||||
class TestWenxinAdapterDualKeys:
|
||||
def test_wenxin_adapter_gets_dual_keys_from_manager(self):
|
||||
"""测试文心一言适配器从KeyManager获取双密钥"""
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
|
||||
manager = APIKeyManager()
|
||||
credentials = {"api_key": "test-api-key-123", "secret_key": "test-secret-key-456"}
|
||||
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
|
||||
|
||||
adapter = WenxinAdapter(key_manager=manager)
|
||||
assert adapter.api_key == "test-api-key-123"
|
||||
assert adapter.secret_key == "test-secret-key-456"
|
||||
|
||||
def test_wenxin_adapter_with_single_key_fallback(self):
|
||||
"""测试文心一言适配器单Key时的降级处理"""
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
|
||||
manager = APIKeyManager()
|
||||
manager.add_key("wenxin", "single-key-only", source=KeySource.SYSTEM)
|
||||
|
||||
adapter = WenxinAdapter(key_manager=manager)
|
||||
assert adapter.api_key == "single-key-only"
|
||||
assert adapter.secret_key == ""
|
||||
|
||||
def test_wenxin_adapter_direct_credentials_take_precedence(self):
|
||||
"""测试直接传入的credentials优先于KeyManager"""
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
|
||||
manager = APIKeyManager()
|
||||
manager.add_key("wenxin", {"api_key": "manager-api", "secret_key": "manager-secret"}, source=KeySource.SYSTEM)
|
||||
|
||||
adapter = WenxinAdapter(
|
||||
api_key="direct-api",
|
||||
secret_key="direct-secret",
|
||||
key_manager=manager
|
||||
)
|
||||
assert adapter.api_key == "direct-api"
|
||||
assert adapter.secret_key == "direct-secret"
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
import httpx
|
||||
import asyncio
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, EngineType
|
||||
|
||||
|
||||
class TestBatchQueryExceptionHandling:
|
||||
"""测试 batch_query.py 中的异常处理行为"""
|
||||
|
||||
def test_build_adapters_handles_generic_exception(self, caplog):
|
||||
"""测试 _build_adapters 对通用异常的处理(应记录 error 日志)"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
|
||||
|
||||
class FailingAdapter(AIEngineAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(api_key="test")
|
||||
raise RuntimeError("Simulated initialization failure")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.CHATGPT: FailingAdapter}):
|
||||
_build_adapters.cache_clear()
|
||||
result = _build_adapters()
|
||||
assert EngineType.CHATGPT.value not in result
|
||||
|
||||
def test_build_adapters_handles_httpx_http_error(self, caplog):
|
||||
"""测试 _build_adapters 对 httpx.HTTPError 的处理(应记录 warning 日志)"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
|
||||
|
||||
class HttpErrorAdapter(AIEngineAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(api_key="test")
|
||||
raise httpx.HTTPError("HTTP connection failed")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.PERPLEXITY
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.PERPLEXITY: HttpErrorAdapter}):
|
||||
_build_adapters.cache_clear()
|
||||
result = _build_adapters()
|
||||
assert EngineType.PERPLEXITY.value not in result
|
||||
assert any("HTTP error" in record.message for record in caplog.records)
|
||||
|
||||
def test_build_adapters_handles_timeout_error(self, caplog):
|
||||
"""测试 _build_adapters 对 asyncio.TimeoutError 的处理(应记录 warning 日志)"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
|
||||
|
||||
class TimeoutAdapter(AIEngineAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(api_key="test")
|
||||
raise asyncio.TimeoutError("Connection timeout")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.KIMI
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.KIMI: TimeoutAdapter}):
|
||||
_build_adapters.cache_clear()
|
||||
result = _build_adapters()
|
||||
assert EngineType.KIMI.value not in result
|
||||
assert any("Timeout" in record.message for record in caplog.records)
|
||||
|
||||
def test_build_adapters_with_key_manager_handles_httpx_http_error(self, caplog):
|
||||
"""测试 _build_adapters_with_key_manager 对 httpx.HTTPError 的处理"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
|
||||
|
||||
class HttpErrorAdapter(AIEngineAdapter):
|
||||
def __init__(self, key_manager=None, user_id=None):
|
||||
super().__init__(api_key="test")
|
||||
raise httpx.HTTPError("HTTP connection failed")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.WENXIN
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.WENXIN: HttpErrorAdapter}):
|
||||
result = _build_adapters_with_key_manager(
|
||||
key_manager=MagicMock(),
|
||||
user_id="test_user"
|
||||
)
|
||||
assert EngineType.WENXIN.value not in result
|
||||
assert any("HTTP error" in record.message for record in caplog.records)
|
||||
|
||||
def test_build_adapters_with_key_manager_handles_timeout_error(self, caplog):
|
||||
"""测试 _build_adapters_with_key_manager 对 asyncio.TimeoutError 的处理"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
|
||||
|
||||
class TimeoutAdapter(AIEngineAdapter):
|
||||
def __init__(self, key_manager=None, user_id=None):
|
||||
super().__init__(api_key="test")
|
||||
raise asyncio.TimeoutError("Connection timeout")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.DEEPSEEK
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.DEEPSEEK: TimeoutAdapter}):
|
||||
result = _build_adapters_with_key_manager(
|
||||
key_manager=MagicMock(),
|
||||
user_id="test_user"
|
||||
)
|
||||
assert EngineType.DEEPSEEK.value not in result
|
||||
assert any("Timeout" in record.message for record in caplog.records)
|
||||
|
||||
def test_build_adapters_with_key_manager_handles_generic_exception(self, caplog):
|
||||
"""测试 _build_adapters_with_key_manager 对通用异常的处理"""
|
||||
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
|
||||
|
||||
class FailingAdapter(AIEngineAdapter):
|
||||
def __init__(self, key_manager=None, user_id=None):
|
||||
super().__init__(api_key="test")
|
||||
raise RuntimeError("Simulated key manager failure")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.QWEN
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
with patch.dict(_ADAPTER_CLASSES, {EngineType.QWEN: FailingAdapter}):
|
||||
result = _build_adapters_with_key_manager(
|
||||
key_manager=MagicMock(),
|
||||
user_id="test_user"
|
||||
)
|
||||
assert EngineType.QWEN.value not in result
|
||||
|
||||
def test_register_adapter_handles_exception(self):
|
||||
"""测试 register_adapter 函数对异常的处理"""
|
||||
from app.services.ai_engine.batch_query import register_adapter
|
||||
|
||||
class BadAdapter(AIEngineAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(api_key="test")
|
||||
raise ValueError("Adapter registration failed")
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.GEMINI
|
||||
|
||||
def _get_env_key(self):
|
||||
return "TEST_KEY"
|
||||
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
register_adapter(BadAdapter)
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import logging
|
||||
|
||||
from app.services.usage_tracker import UsageTracker
|
||||
|
||||
|
||||
class TestUsageTrackerExceptionHandling:
|
||||
"""测试 usage_tracker.py 中的异常处理行为"""
|
||||
|
||||
def test_usage_tracker_record_handles_exception(self, caplog):
|
||||
"""测试 UsageTracker.record 方法对异常的处理"""
|
||||
with caplog.at_level(logging.WARNING):
|
||||
tracker = UsageTracker()
|
||||
|
||||
with patch.object(tracker, '_records', side_effect=Exception("Simulated error")):
|
||||
try:
|
||||
tracker.record(
|
||||
user_id="test_user",
|
||||
brand_id="test_brand",
|
||||
engine_type="chatgpt",
|
||||
query="test query",
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
cost=0.05,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_usage_tracker_get_summary_handles_empty_records(self):
|
||||
"""测试 UsageTracker.get_summary 方法处理空记录"""
|
||||
tracker = UsageTracker()
|
||||
summary = tracker.get_summary(user_id="test_user")
|
||||
|
||||
assert summary.total_queries == 0
|
||||
assert summary.total_input_tokens == 0
|
||||
assert summary.total_output_tokens == 0
|
||||
assert summary.total_cost == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_tracker_record_async_requires_session(self):
|
||||
"""测试 UsageTracker.record_async 需要有效的 session"""
|
||||
tracker = UsageTracker()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
|
||||
await tracker.record_async(
|
||||
user_id="test_user",
|
||||
brand_id="test_brand",
|
||||
engine_type="chatgpt",
|
||||
query="test query",
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
cost=0.05,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_tracker_get_summary_async_requires_session(self):
|
||||
"""测试 UsageTracker.get_summary_async 需要有效的 session"""
|
||||
tracker = UsageTracker()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
|
||||
await tracker.get_summary_async(user_id="test_user")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_tracker_check_quota_async_requires_session(self):
|
||||
"""测试 UsageTracker.check_quota_async 需要有效的 session"""
|
||||
tracker = UsageTracker()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
|
||||
await tracker.check_quota_async(user_id="test_user")
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Suggestions API 路径测试
|
||||
*
|
||||
* 验证前端 API 路径与后端路由一致
|
||||
* 后端路由: /api/v1/brands/${brandId}/suggestions/*
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { suggestionsApi } from "@/lib/api/suggestions";
|
||||
import { API_BASE } from "@/lib/api/client";
|
||||
|
||||
const mockFetch = vi.fn();
|
||||
const originalFetch = global.fetch;
|
||||
|
||||
beforeEach(() => {
|
||||
global.fetch = mockFetch;
|
||||
vi.clearAllMocks();
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: () => Promise.resolve([]),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
global.fetch = originalFetch;
|
||||
});
|
||||
|
||||
function getCalledUrl() {
|
||||
return mockFetch.mock.calls[0][0] as string;
|
||||
}
|
||||
|
||||
function getCalledMethod() {
|
||||
return mockFetch.mock.calls[0][1]?.method as string;
|
||||
}
|
||||
|
||||
describe("suggestionsApi 路径测试", () => {
|
||||
describe("getSuggestions", () => {
|
||||
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions", async () => {
|
||||
await suggestionsApi.getSuggestions("token-123", "brand-abc");
|
||||
|
||||
const url = getCalledUrl();
|
||||
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
|
||||
});
|
||||
|
||||
it("应正确传递查询参数", async () => {
|
||||
await suggestionsApi.getSuggestions("token-123", "brand-abc", {
|
||||
status: "pending",
|
||||
page: 1,
|
||||
});
|
||||
|
||||
const url = getCalledUrl();
|
||||
expect(url).toContain(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
|
||||
expect(url).toContain("status=pending");
|
||||
expect(url).toContain("page=1");
|
||||
});
|
||||
|
||||
it("空参数时不应包含查询字符串", async () => {
|
||||
await suggestionsApi.getSuggestions("token-123", "brand-abc", {});
|
||||
|
||||
const url = getCalledUrl();
|
||||
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
|
||||
});
|
||||
});
|
||||
|
||||
describe("regenerateSuggestions", () => {
|
||||
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions/regenerate", async () => {
|
||||
await suggestionsApi.regenerateSuggestions("token-123", "brand-abc");
|
||||
|
||||
const url = getCalledUrl();
|
||||
const method = getCalledMethod();
|
||||
|
||||
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions/regenerate`);
|
||||
expect(method).toBe("POST");
|
||||
});
|
||||
});
|
||||
|
||||
describe("updateSuggestionStatus", () => {
|
||||
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions/${suggestionId}/status", async () => {
|
||||
await suggestionsApi.updateSuggestionStatus(
|
||||
"token-123",
|
||||
"brand-abc",
|
||||
"suggestion-xyz",
|
||||
"completed"
|
||||
);
|
||||
|
||||
const url = getCalledUrl();
|
||||
const method = getCalledMethod();
|
||||
|
||||
expect(url).toBe(
|
||||
`${API_BASE}/api/v1/brands/brand-abc/suggestions/suggestion-xyz/status`
|
||||
);
|
||||
expect(method).toBe("PUT");
|
||||
});
|
||||
|
||||
it("应正确传递状态数据", async () => {
|
||||
await suggestionsApi.updateSuggestionStatus(
|
||||
"token-123",
|
||||
"brand-abc",
|
||||
"suggestion-xyz",
|
||||
"completed"
|
||||
);
|
||||
|
||||
const options = mockFetch.mock.calls[0][1];
|
||||
expect(JSON.parse(options.body)).toEqual({ status: "completed" });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -16,7 +16,7 @@ export const suggestionsApi = {
|
|||
params?: Record<string, string | number>
|
||||
) =>
|
||||
fetchWithAuth(
|
||||
`/api/v1/suggestions/${brandId}/suggestions${buildQuery(params || {})}`,
|
||||
`/api/v1/brands/${brandId}/suggestions${buildQuery(params || {})}`,
|
||||
{},
|
||||
token
|
||||
),
|
||||
|
|
@ -24,7 +24,7 @@ export const suggestionsApi = {
|
|||
/** 重新生成建议 */
|
||||
regenerateSuggestions: (token: string, brandId: string) =>
|
||||
fetchWithAuth(
|
||||
`/api/v1/suggestions/${brandId}/suggestions/regenerate`,
|
||||
`/api/v1/brands/${brandId}/suggestions/regenerate`,
|
||||
{ method: "POST" },
|
||||
token
|
||||
),
|
||||
|
|
@ -37,7 +37,7 @@ export const suggestionsApi = {
|
|||
status: string
|
||||
) =>
|
||||
fetchWithAuth(
|
||||
`/api/v1/suggestions/${brandId}/suggestions/${suggestionId}/status`,
|
||||
`/api/v1/brands/${brandId}/suggestions/${suggestionId}/status`,
|
||||
{ method: "PUT", body: JSON.stringify({ status }) },
|
||||
token
|
||||
),
|
||||
|
|
|
|||
Loading…
Reference in New Issue