fix: 审计发现的问题修复

API一致性修复:
- C1: 新增organization.py路由(/api/v1/organization/*)
- C2: 修复suggestions API路径(/api/v1/brands/*而非/api/v1/suggestions/*)
- H7: 修复platforms路由双重前缀(/api/v1/platforms而非/api/v1/api/platforms)

密钥管理改进:
- H3: APIKeyManager支持双密钥(dict格式),文心一言适配器使用KeyManager
- H8: 新增APIKeyFilter日志过滤器,拦截key=和Bearer token

异常处理改进:
- H1: batch_query.py改为httpx.HTTPError分层处理
- H1: database.py改为SQLAlchemyError并抛出ConnectionError
- H1: lifecycle.py和usage_tracker添加日志记录

测试: 764 passed
This commit is contained in:
chiguyong 2026-05-25 23:33:25 +08:00
parent fe4ba39514
commit aeaa50e89e
22 changed files with 1627 additions and 41 deletions

View File

@ -1,4 +1,5 @@
import uuid import uuid
import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@ -11,6 +12,8 @@ from app.database import get_db
from app.models.lifecycle import LifecycleProject, ProjectStage from app.models.lifecycle import LifecycleProject, ProjectStage
from app.models.organization import Organization, OrgMember from app.models.organization import Organization, OrgMember
from app.models.user import User from app.models.user import User
logger = logging.getLogger(__name__)
from app.schemas.lifecycle import ( from app.schemas.lifecycle import (
ProjectCreateRequest, ProjectCreateRequest,
ProjectResponse, ProjectResponse,
@ -171,7 +174,8 @@ async def project_stats(
contents_stmt = select(func.count()).where(Content.organization_id == org_id) contents_stmt = select(func.count()).where(Content.organization_id == org_id)
contents_result = await db.execute(contents_stmt) contents_result = await db.execute(contents_stmt)
contents_produced = contents_result.scalar() or 0 contents_produced = contents_result.scalar() or 0
except Exception: except Exception as e:
logger.warning(f"Failed to count contents: {e}")
contents_produced = 0 contents_produced = 0
# avg AI citation rate # avg AI citation rate
@ -197,7 +201,8 @@ async def project_stats(
avg_ai_citation_rate = round(cited_count / total_citations, 4) if total_citations > 0 else None avg_ai_citation_rate = round(cited_count / total_citations, 4) if total_citations > 0 else None
else: else:
avg_ai_citation_rate = None avg_ai_citation_rate = None
except Exception: except Exception as e:
logger.warning(f"Failed to calculate AI citation rate: {e}")
avg_ai_citation_rate = None avg_ai_citation_rate = None
# current stage distribution (map int stage to string) # current stage distribution (map int stage to string)

View File

@ -0,0 +1,343 @@
"""组织管理 API — /api/v1/organization/*"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, EmailStr, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.organization import Organization, OrgMember
from app.models.user import User
router = APIRouter(prefix="/api/v1/organization", tags=["组织管理"])
class OrgInfoResponse(BaseModel):
id: str
name: str
slug: str
description: Optional[str] = None
logo_url: Optional[str] = None
plan: str
max_members: int
member_count: int = 0
model_config = {"from_attributes": True}
class MemberResponse(BaseModel):
id: str
user_id: str
name: str
email: str
role: str
joined_at: str
invited_by: Optional[str] = None
model_config = {"from_attributes": True}
class InviteMemberRequest(BaseModel):
email: EmailStr
role: str = Field(default="viewer", pattern="^(owner|admin|editor|viewer)$")
class UpdateMemberRoleRequest(BaseModel):
role: str = Field(..., pattern="^(owner|admin|editor|viewer)$")
class InviteResponse(BaseModel):
id: str
email: str
role: str
message: str
@router.get("/info", response_model=OrgInfoResponse)
async def get_org_info(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取当前用户所属组织的基本信息"""
if not current_user.organization_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户未加入任何组织",
)
stmt = select(Organization).where(Organization.id == current_user.organization_id)
result = await db.execute(stmt)
org = result.scalar_one_or_none()
if not org:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="组织不存在",
)
count_stmt = select(func.count(OrgMember.id)).where(
OrgMember.organization_id == org.id
)
count_result = await db.execute(count_stmt)
member_count = count_result.scalar() or 0
return OrgInfoResponse(
id=str(org.id),
name=org.name,
slug=org.slug,
description=org.description,
logo_url=org.logo_url,
plan=org.plan,
max_members=org.max_members,
member_count=member_count,
)
@router.get("/members", response_model=list[MemberResponse])
async def list_members(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取当前用户所属组织的所有成员"""
if not current_user.organization_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户未加入任何组织",
)
stmt = (
select(OrgMember, User)
.join(User, User.id == OrgMember.user_id)
.where(OrgMember.organization_id == current_user.organization_id)
.order_by(OrgMember.joined_at.desc())
)
result = await db.execute(stmt)
rows = result.all()
members = []
for member, user in rows:
members.append(
MemberResponse(
id=str(member.id),
user_id=str(member.user_id),
name=user.name or "",
email=user.email,
role=member.role,
joined_at=member.joined_at.isoformat() if member.joined_at else "",
invited_by=str(member.invited_by) if member.invited_by else None,
)
)
return members
@router.post("/members/invite", response_model=InviteResponse, status_code=status.HTTP_201_CREATED)
async def invite_member(
body: InviteMemberRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""邀请新成员加入组织"""
if not current_user.organization_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户未加入任何组织",
)
current_membership_stmt = select(OrgMember).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == current_user.id,
)
current_membership_result = await db.execute(current_membership_stmt)
current_membership = current_membership_result.scalar_one_or_none()
if not current_membership or current_membership.role not in ["owner", "admin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅管理员可以邀请新成员",
)
target_user_stmt = select(User).where(User.email == body.email)
target_user_result = await db.execute(target_user_stmt)
target_user = target_user_result.scalar_one_or_none()
if not target_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在",
)
existing_stmt = select(OrgMember).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == target_user.id,
)
existing_result = await db.execute(existing_stmt)
existing_membership = existing_result.scalar_one_or_none()
if existing_membership:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该用户已经是组织成员",
)
count_stmt = select(func.count(OrgMember.id)).where(
OrgMember.organization_id == current_user.organization_id
)
count_result = await db.execute(count_stmt)
current_member_count = count_result.scalar() or 0
org_stmt = select(Organization).where(Organization.id == current_user.organization_id)
org_result = await db.execute(org_stmt)
org = org_result.scalar_one_or_none()
if org and current_member_count >= org.max_members:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"成员数量已达上限({org.max_members}人)",
)
new_membership = OrgMember(
id=uuid.uuid4(),
organization_id=current_user.organization_id,
user_id=target_user.id,
role=body.role,
invited_by=current_user.id,
)
db.add(new_membership)
await db.commit()
await db.refresh(new_membership)
return InviteResponse(
id=str(new_membership.id),
email=target_user.email,
role=new_membership.role,
message=f"成功邀请 {target_user.email} 加入组织",
)
@router.put("/members/{user_id}/role", response_model=MemberResponse)
async def update_member_role(
user_id: uuid.UUID,
body: UpdateMemberRoleRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""更新成员角色"""
if not current_user.organization_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户未加入任何组织",
)
current_membership_stmt = select(OrgMember).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == current_user.id,
)
current_membership_result = await db.execute(current_membership_stmt)
current_membership = current_membership_result.scalar_one_or_none()
if not current_membership or current_membership.role not in ["owner", "admin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅管理员可以修改成员角色",
)
target_membership_stmt = select(OrgMember, User).join(
User, User.id == OrgMember.user_id
).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == user_id,
)
target_result = await db.execute(target_membership_stmt)
target_row = target_result.first()
if not target_row:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="成员不存在",
)
target_membership, target_user = target_row
if target_membership.role == "owner":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能修改所有者角色",
)
if body.role == "owner":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能将成员设为所有者",
)
target_membership.role = body.role
await db.commit()
await db.refresh(target_membership)
return MemberResponse(
id=str(target_membership.id),
user_id=str(target_membership.user_id),
name=target_user.name or "",
email=target_user.email,
role=target_membership.role,
joined_at=target_membership.joined_at.isoformat() if target_membership.joined_at else "",
invited_by=str(target_membership.invited_by) if target_membership.invited_by else None,
)
@router.delete("/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_member(
user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""移除组织成员"""
if not current_user.organization_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户未加入任何组织",
)
current_membership_stmt = select(OrgMember).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == current_user.id,
)
current_membership_result = await db.execute(current_membership_stmt)
current_membership = current_membership_result.scalar_one_or_none()
if not current_membership or current_membership.role not in ["owner", "admin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅管理员可以移除成员",
)
target_membership_stmt = select(OrgMember).where(
OrgMember.organization_id == current_user.organization_id,
OrgMember.user_id == user_id,
)
target_result = await db.execute(target_membership_stmt)
target_membership = target_result.scalar_one_or_none()
if not target_membership:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="成员不存在",
)
if target_membership.role == "owner":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能移除所有者",
)
if user_id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能移除自己",
)
await db.delete(target_membership)
await db.commit()

View File

@ -14,7 +14,7 @@ from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/platforms", tags=["platforms"]) router = APIRouter(prefix="/platforms", tags=["platforms"])
class PlatformHealthStatus: class PlatformHealthStatus:

View File

@ -1,11 +1,15 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
from sqlalchemy import text, JSON from sqlalchemy import text, JSON
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
import logging
from app.config import settings from app.config import settings
logger = logging.getLogger(__name__)
class JSONType(TypeDecorator): class JSONType(TypeDecorator):
"""A JSON type that uses JSONB on PostgreSQL and JSON on other databases.""" """A JSON type that uses JSONB on PostgreSQL and JSON on other databases."""
@ -53,5 +57,9 @@ async def check_db_connection() -> bool:
async with AsyncSessionLocal() as session: async with AsyncSessionLocal() as session:
await session.execute(text("SELECT 1")) await session.execute(text("SELECT 1"))
return True return True
except Exception: except sqlalchemy_exc.SQLAlchemyError as e:
return False logger.error(f"Database connection error: {e}", exc_info=True)
raise ConnectionError(f"Failed to connect to database: {e}") from e
except Exception as e:
logger.error(f"Unexpected database error: {e}", exc_info=True)
raise

View File

@ -3,6 +3,8 @@ import logging
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
from app.middleware.logging_filter import APIKeyFilter
class JSONFormatter(logging.Formatter): class JSONFormatter(logging.Formatter):
"""将日志记录格式化为 JSON 字符串,便于日志收集平台(如 ELK、Loki解析。""" """将日志记录格式化为 JSON 字符串,便于日志收集平台(如 ELK、Loki解析。"""
@ -47,11 +49,10 @@ def setup_logging(level: int = logging.INFO) -> None:
handler.setFormatter(JSONFormatter()) handler.setFormatter(JSONFormatter())
root_logger = logging.getLogger() root_logger = logging.getLogger()
# 清空已有 handlers避免重复输出
root_logger.handlers.clear() root_logger.handlers.clear()
root_logger.addHandler(handler) root_logger.addHandler(handler)
root_logger.setLevel(level) root_logger.setLevel(level)
root_logger.addFilter(APIKeyFilter())
# 降低 uvicorn/sqlalchemy 等第三方库的噪音
logging.getLogger("uvicorn.access").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)

View File

@ -19,6 +19,7 @@ from app.api.admin import router as admin_router
from app.api.content import router as content_router from app.api.content import router as content_router
from app.api.contents import router as contents_router from app.api.contents import router as contents_router
from app.api.clients import router as clients_router from app.api.clients import router as clients_router
from app.api.organization import router as organization_router
from app.api.agents import router as agents_router from app.api.agents import router as agents_router
from app.api.knowledge import router as knowledge_router from app.api.knowledge import router as knowledge_router
from app.api.distribution import router as distribution_router from app.api.distribution import router as distribution_router
@ -157,6 +158,7 @@ app.include_router(lifecycle_router, prefix="/api/v1/lifecycle", tags=["lifecycl
app.include_router(knowledge_router, prefix="/api/v1/knowledge", tags=["知识库"]) app.include_router(knowledge_router, prefix="/api/v1/knowledge", tags=["知识库"])
app.include_router(content_router, prefix="/api/v1/content", tags=["内容生产"]) app.include_router(content_router, prefix="/api/v1/content", tags=["内容生产"])
app.include_router(contents_router, prefix="/api/v1/contents", tags=["内容管理"]) app.include_router(contents_router, prefix="/api/v1/contents", tags=["内容管理"])
app.include_router(organization_router)
app.include_router(clients_router, prefix="/api/v1/clients", tags=["客户管理"]) app.include_router(clients_router, prefix="/api/v1/clients", tags=["客户管理"])
app.include_router(distribution_router, prefix="/api/v1/distribution", tags=["内容分发"]) app.include_router(distribution_router, prefix="/api/v1/distribution", tags=["内容分发"])
app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测优化"]) app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测优化"])

View File

@ -0,0 +1,35 @@
import logging
import re
class APIKeyFilter(logging.Filter):
"""日志过滤器,拦截敏感信息"""
PATTERNS = [
(r'key=[A-Za-z0-9_-]{30,}', 'key=***REDACTED***'),
(r'Bearer\s+[A-Za-z0-9_-]{20,}', 'Bearer ***REDACTED***'),
(r'Authorization:\s*[A-Za-z0-9_-]{20,}', 'Authorization: ***REDACTED***'),
(r'api[_-]?key=[A-Za-z0-9_-]{30,}', 'api_key=***REDACTED***'),
(r'(?:^|\s)(AIza[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
(r'(?:^|\s)(sk-[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
(r'(?:^|\s)(gsk-[A-Za-z0-9_-]{30,})(?:\s|$)', r'\1***REDACTED***'),
]
def _redact_string(self, text: str) -> str:
for pattern, replacement in self.PATTERNS:
text = re.sub(pattern, replacement, text)
return text
def filter(self, record: logging.LogRecord) -> bool:
if record.args:
new_args = []
for arg in record.args:
if isinstance(arg, str):
arg = self._redact_string(arg)
new_args.append(arg)
record.args = tuple(new_args)
if record.msg and isinstance(record.msg, str):
record.msg = self._redact_string(record.msg)
return True

View File

@ -3,6 +3,8 @@ import logging
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType from .base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.usage_recorder import UsageRecorder from app.services.usage_recorder import UsageRecorder
from app.services.usage_tracker import UsageTracker from app.services.usage_tracker import UsageTracker
@ -63,8 +65,12 @@ def _build_adapters() -> dict[str, AIEngineAdapter]:
for engine_type, cls in _ADAPTER_CLASSES.items(): for engine_type, cls in _ADAPTER_CLASSES.items():
try: try:
adapters[engine_type.value] = cls() adapters[engine_type.value] = cls()
except Exception: except httpx.HTTPError as e:
logger.warning(f"Failed to initialize {engine_type.value} adapter") logger.warning(f"HTTP error from {engine_type.value}: {e}")
except asyncio.TimeoutError as e:
logger.warning(f"Timeout from {engine_type.value}: {e}")
except Exception as e:
logger.error(f"Unexpected error from {engine_type.value}: {e}", exc_info=True)
return adapters return adapters
@ -79,8 +85,12 @@ def _build_adapters_with_key_manager(
key_manager=key_manager, key_manager=key_manager,
user_id=user_id, user_id=user_id,
) )
except Exception: except httpx.HTTPError as e:
logger.warning(f"Failed to initialize {engine_type.value} adapter") logger.warning(f"HTTP error from {engine_type.value}: {e}")
except asyncio.TimeoutError as e:
logger.warning(f"Timeout from {engine_type.value}: {e}")
except Exception as e:
logger.error(f"Unexpected error from {engine_type.value}: {e}", exc_info=True)
return adapters return adapters
@ -113,18 +123,21 @@ class BatchQueryService:
result = await adapter.query(query, brand_name, competitor_names) result = await adapter.query(query, brand_name, competitor_names)
if self._user_id: if self._user_id:
self._recorder.record( try:
user_id=self._user_id, self._recorder.record(
brand_id=self._brand_id, user_id=self._user_id,
engine_type=engine_type.value, brand_id=self._brand_id,
query=query, engine_type=engine_type.value,
input_tokens=result.input_tokens, query=query,
output_tokens=result.output_tokens, input_tokens=result.input_tokens,
metadata={ output_tokens=result.output_tokens,
"brand_name": brand_name, metadata={
"response_time_ms": result.response_time_ms, "brand_name": brand_name,
}, "response_time_ms": result.response_time_ms,
) },
)
except Exception as e:
logger.warning(f"Failed to record usage: {e}")
return result return result

View File

@ -5,6 +5,7 @@ from datetime import UTC, datetime
import httpx import httpx
from app.services.api_key_manager import APIKeyManager, KeyCredentials
from .base import AIEngineAdapter, AIQueryResult, EngineType from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,28 +28,48 @@ class WenxinAdapter(AIEngineAdapter):
secret_key: str | None = None, secret_key: str | None = None,
rate_limiter=None, rate_limiter=None,
proxy: str | None = None, proxy: str | None = None,
key_manager=None, key_manager: APIKeyManager | None = None,
user_id: str | None = None, user_id: str | None = None,
): ):
super().__init__( self._key_manager = key_manager
api_key=api_key, self._user_id = user_id
rate_limiter=rate_limiter, self.rate_limiter = rate_limiter
proxy=proxy, self.proxy = proxy or self._load_proxy()
key_manager=key_manager, self.api_key = api_key or ""
user_id=user_id, self.secret_key = secret_key or ""
) self._resolve_keys_from_manager(api_key, secret_key, key_manager, user_id)
self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
self._model = _DEFAULT_MODEL self._model = _DEFAULT_MODEL
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
) )
def _resolve_keys_from_manager(
self,
direct_api_key: str | None,
direct_secret_key: str | None,
key_manager: APIKeyManager | None,
user_id: str | None,
) -> None:
if direct_api_key and direct_api_key.strip():
return
if not key_manager:
return
creds = key_manager.get_credentials("wenxin", user_id=user_id)
if creds:
if not self.api_key:
self.api_key = creds.api_key
if not self.secret_key:
self.secret_key = creds.secret_key or ""
def get_engine_type(self) -> EngineType: def get_engine_type(self) -> EngineType:
return EngineType.WENXIN return EngineType.WENXIN
def _get_env_key(self) -> str | None: def _get_env_key(self) -> str | None:
return os.getenv("BAIDU_QIANFAN_API_KEY", "") return os.getenv("BAIDU_QIANFAN_API_KEY", "")
def _get_env_secret_key(self) -> str | None:
return os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
def _load_proxy(self) -> str | None: def _load_proxy(self) -> str | None:
return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

View File

@ -1,4 +1,5 @@
import base64 import base64
import json
import logging import logging
import os import os
import uuid import uuid
@ -21,6 +22,21 @@ class KeySource(str, Enum):
ENV = "env" ENV = "env"
@dataclass
class KeyCredentials:
"""统一凭证格式支持单Key和双Key"""
api_key: str
secret_key: str | None = None
def to_dict(self) -> dict:
if self.secret_key:
return {"api_key": self.api_key, "secret_key": self.secret_key}
return {"api_key": self.api_key}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@dataclass @dataclass
class APIKeyConfig: class APIKeyConfig:
engine_type: str engine_type: str
@ -56,16 +72,23 @@ class APIKeyManager:
def add_key( def add_key(
self, self,
engine_type: str, engine_type: str,
api_key: str, credentials: str | dict,
source: KeySource = KeySource.SYSTEM, source: KeySource = KeySource.SYSTEM,
user_id: str | None = None, user_id: str | None = None,
priority: int = 0, priority: int = 0,
) -> APIKeyConfig: ) -> APIKeyConfig:
if isinstance(credentials, dict):
key_hint = self._create_dual_key_hint(credentials)
credentials_json = json.dumps(credentials)
encrypted_key = self._encrypt(credentials_json)
else:
key_hint = self._mask_key(credentials)
encrypted_key = self._encrypt(credentials)
config = APIKeyConfig( config = APIKeyConfig(
engine_type=engine_type, engine_type=engine_type,
key_source=source, key_source=source,
encrypted_key=self._encrypt(api_key), encrypted_key=encrypted_key,
key_hint=self._mask_key(api_key), key_hint=key_hint,
status=KeyStatus.UNKNOWN, status=KeyStatus.UNKNOWN,
priority=priority, priority=priority,
user_id=user_id, user_id=user_id,
@ -76,8 +99,39 @@ class APIKeyManager:
self._keys[engine_type].sort(key=lambda k: k.priority, reverse=True) self._keys[engine_type].sort(key=lambda k: k.priority, reverse=True)
return config return config
def get_key(self, engine_type: str, user_id: str | None = None) -> str | None: def _create_dual_key_hint(self, credentials: dict) -> str:
api_key = credentials.get("api_key", "")
secret_key = credentials.get("secret_key", "")
api_hint = self._mask_key(api_key) if api_key else "***"
secret_hint = self._mask_key(secret_key) if secret_key else "***"
return f"{api_hint}|{secret_hint}"
def get_key(self, engine_type: str, user_id: str | None = None) -> str | dict | None:
configs = self._keys.get(engine_type, []) configs = self._keys.get(engine_type, [])
config = self._find_best_key(configs, user_id)
if not config:
return None
decrypted = self._decrypt(config.encrypted_key)
try:
return json.loads(decrypted)
except (json.JSONDecodeError, TypeError):
return decrypted
def get_credentials(self, engine_type: str, user_id: str | None = None) -> KeyCredentials | None:
key_data = self.get_key(engine_type, user_id)
if not key_data:
return None
if isinstance(key_data, dict):
return KeyCredentials(
api_key=key_data.get("api_key", ""),
secret_key=key_data.get("secret_key")
)
return KeyCredentials(api_key=key_data)
def get_all_keys(self, engine_type: str, user_id: str | None = None) -> dict | str | None:
return self.get_key(engine_type, user_id)
def _find_best_key(self, configs: list[APIKeyConfig], user_id: str | None = None) -> APIKeyConfig | None:
if user_id: if user_id:
for c in configs: for c in configs:
if ( if (
@ -85,10 +139,10 @@ class APIKeyManager:
and c.user_id == user_id and c.user_id == user_id
and c.status in self._USABLE_STATUSES and c.status in self._USABLE_STATUSES
): ):
return self._decrypt(c.encrypted_key) return c
for c in configs: for c in configs:
if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES: if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
return self._decrypt(c.encrypted_key) return c
return None return None
def get_any_available_key(self, engine_type: str) -> str | None: def get_any_available_key(self, engine_type: str) -> str | None:

View File

@ -1,3 +1,4 @@
import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
@ -8,6 +9,17 @@ from sqlalchemy.pool import StaticPool
from app.database import Base from app.database import Base
from app.models.user import User from app.models.user import User
from app.middleware.logging_filter import APIKeyFilter
@pytest.fixture(autouse=True)
def add_api_key_filter():
"""自动为每个测试添加APIKeyFilter到root logger"""
root_logger = logging.getLogger()
api_key_filter = APIKeyFilter()
root_logger.addFilter(api_key_filter)
yield
root_logger.removeFilter(api_key_filter)
@pytest_asyncio.fixture @pytest_asyncio.fixture

View File

@ -0,0 +1,110 @@
import pytest
from unittest.mock import patch, AsyncMock, MagicMock, PropertyMock
import logging
import uuid
from app.api.lifecycle import project_stats
class TestLifecycleExceptionHandling:
"""测试 lifecycle.py 中的异常处理行为"""
@pytest.mark.asyncio
async def test_project_stats_handles_content_query_failure(self, caplog):
"""测试 project_stats 当 Content 查询失败时的处理"""
from app.models.user import User
org_id = uuid.uuid4()
user_id = uuid.uuid4()
user = User(
id=user_id,
email="test@example.com",
password_hash="hash",
name="Test User",
plan="free",
organization_id=org_id,
)
execute_results = [
MagicMock(one=MagicMock(total=0, active=0)),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
]
execute_results[1].all.return_value = []
execute_results[2].scalar.return_value = 0
execute_results[3].all.return_value = []
execute_results[4].all.return_value = []
execute_count = [0]
def execute_side_effect(*args, **kwargs):
idx = execute_count[0]
execute_count[0] += 1
if idx == 2:
raise RuntimeError("Content table not available")
return execute_results[idx]
mock_session = AsyncMock()
mock_session.execute.side_effect = execute_side_effect
with caplog.at_level(logging.WARNING):
result = await project_stats(
db=mock_session,
current_user=user
)
assert result.contents_produced == 0
assert any("Failed to count contents" in record.message for record in caplog.records)
@pytest.mark.asyncio
async def test_project_stats_handles_citation_query_failure(self, caplog):
"""测试 project_stats 当 CitationRecord 查询失败时的处理"""
from app.models.user import User
org_id = uuid.uuid4()
user_id = uuid.uuid4()
user = User(
id=user_id,
email="test@example.com",
password_hash="hash",
name="Test User",
plan="free",
organization_id=org_id,
)
execute_results = [
MagicMock(one=MagicMock(total=0, active=0)),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
MagicMock(),
]
execute_results[1].all.return_value = []
execute_results[2].scalar.return_value = 0
execute_results[3].all.return_value = []
execute_results[4].one.return_value = MagicMock(total_citations=0, cited_count=0)
execute_results[5].all.return_value = []
execute_count = [0]
def execute_side_effect(*args, **kwargs):
idx = execute_count[0]
execute_count[0] += 1
if idx == 4:
raise RuntimeError("CitationRecord table not available")
return execute_results[idx]
mock_session = AsyncMock()
mock_session.execute.side_effect = execute_side_effect
with caplog.at_level(logging.WARNING):
result = await project_stats(
db=mock_session,
current_user=user
)
assert result.avg_ai_citation_rate is None
assert any("Failed to calculate AI citation rate" in record.message for record in caplog.records)

View File

@ -0,0 +1,267 @@
"""组织管理API测试 - 验证 /api/v1/organization/* 端点"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.organization import Organization, OrgMember
from app.api.deps import get_current_user, get_db
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def test_organization(async_session, test_user):
org = Organization(
id=uuid.uuid4(),
name="Test Organization",
slug="test-org",
plan="free",
)
async_session.add(org)
await async_session.flush()
test_user.organization_id = org.id
async_session.add(test_user)
membership = OrgMember(
id=uuid.uuid4(),
organization_id=org.id,
user_id=test_user.id,
role="owner",
)
async_session.add(membership)
await async_session.commit()
await async_session.refresh(org)
return org
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
async def override_get_db():
yield async_session
async def override_get_current_user():
return test_user
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
class TestOrganizationRoutes:
"""组织管理API端点测试"""
@pytest.mark.asyncio
async def test_routes_are_registered(self):
"""验证组织管理路由已正确注册到app"""
from app.main import app as main_app
org_routes = [
route.path for route in main_app.routes
if hasattr(route, 'path') and route.path.startswith('/api/v1/organization')
]
print(f"\n已注册的组织路由: {org_routes}")
assert '/api/v1/organization/info' in org_routes, "路由未注册"
assert '/api/v1/organization/members' in org_routes, "路由未注册"
assert '/api/v1/organization/members/invite' in org_routes, "路由未注册"
@pytest.mark.asyncio
async def test_direct_endpoint_call(self, async_session, test_user, test_organization):
"""直接测试端点调用"""
async def override_get_db():
yield async_session
async def override_get_current_user():
return test_user
from app.main import app as test_app
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
# Check which routes are available
available_routes = [r.path for r in test_app.routes if hasattr(r, 'path') and 'organization' in r.path]
print(f"\n测试app中的组织路由: {available_routes}")
response = await client.get("/api/v1/organization/info")
print(f"Response status: {response.status_code}")
print(f"Response: {response.text[:200]}")
assert response.status_code == 200
test_app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_organization_info_endpoint_exists(self, async_client, test_organization):
"""验证 /api/v1/organization/info 端点存在并返回200"""
response = await async_client.get("/api/v1/organization/info")
assert response.status_code == 200, f"期望返回200实际返回 {response.status_code}"
data = response.json()
assert "id" in data
assert "name" in data
assert "slug" in data
@pytest.mark.asyncio
async def test_organization_members_endpoint_exists(self, async_client, test_organization):
"""验证 /api/v1/organization/members 端点存在并返回200"""
response = await async_client.get("/api/v1/organization/members")
assert response.status_code == 200, f"期望返回200实际返回 {response.status_code}"
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/invite 端点存在"""
invite_user = User(
id=uuid.uuid4(),
email="newuser@example.com",
password_hash="hashed_password",
name="New User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(invite_user)
await async_session.commit()
response = await async_client.post(
"/api/v1/organization/members/invite",
json={"email": "newuser@example.com", "role": "viewer"}
)
assert response.status_code == 201, f"期望返回201实际返回 {response.status_code}"
@pytest.mark.asyncio
async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user):
"""验证 /api/v1/organization/members/{id}/role 端点存在"""
new_user = User(
id=uuid.uuid4(),
email="member@example.com",
password_hash="hashed_password",
name="Member User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(new_user)
await async_session.flush()
membership = OrgMember(
id=uuid.uuid4(),
organization_id=test_organization.id,
user_id=new_user.id,
role="viewer",
)
async_session.add(membership)
await async_session.commit()
response = await async_client.put(
f"/api/v1/organization/members/{new_user.id}/role",
json={"role": "admin"}
)
assert response.status_code == 200, f"期望返回200实际返回 {response.status_code}"
@pytest.mark.asyncio
async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/{id} 端点存在"""
new_user = User(
id=uuid.uuid4(),
email="todelete@example.com",
password_hash="hashed_password",
name="Delete User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(new_user)
await async_session.flush()
membership = OrgMember(
id=uuid.uuid4(),
organization_id=test_organization.id,
user_id=new_user.id,
role="viewer",
)
async_session.add(membership)
await async_session.commit()
response = await async_client.delete(f"/api/v1/organization/members/{new_user.id}")
assert response.status_code == 204, f"期望返回204实际返回 {response.status_code}"
@pytest.mark.asyncio
async def test_user_without_organization_returns_404(self, async_client, test_user, async_session):
"""验证未加入组织的用户访问 /api/v1/organization/info 返回404"""
response = await async_client.get("/api/v1/organization/info")
assert response.status_code == 404, f"期望返回404实际返回 {response.status_code}"
@pytest.mark.asyncio
async def test_unauthorized_access(self, async_session):
"""验证未授权访问返回401"""
async def override_get_db():
yield async_session
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides.pop(get_current_user, None)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/v1/organization/info")
assert response.status_code == 401, f"期望返回401实际返回 {response.status_code}"
app.dependency_overrides.clear()

View File

@ -0,0 +1,59 @@
"""平台API路由测试 - 验证不存在双重前缀问题"""
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from app.main import app
class TestPlatformsRoutePrefix:
"""平台路由前缀测试"""
@pytest.mark.asyncio
async def test_platforms_endpoint_not_double_prefixed(self):
"""
验证platforms端点路径不是 /api/v1/api/*
预期
- /api/v1/api/platforms/health 应该返回404不存在
- /api/v1/platforms/health 应该返回200正确路径
"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
wrong_path_response = await client.get("/api/v1/api/platforms/health")
correct_path_response = await client.get("/api/v1/platforms/health")
assert wrong_path_response.status_code == 404, \
f"错误:/api/v1/api/platforms/health 返回 {wrong_path_response.status_code}应该是404说明存在双重前缀问题"
assert correct_path_response.status_code == 200, \
f"错误:/api/v1/platforms/health 返回 {correct_path_response.status_code}应该是200端点不可用"
@pytest.mark.asyncio
async def test_platforms_health_endpoint_accessible(self):
"""验证platforms健康检查端点可通过正确路径访问"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/v1/platforms/health")
assert response.status_code == 200, \
f"平台健康检查端点无法访问,状态码: {response.status_code}"
data = response.json()
assert "platforms" in data
assert "total" in data
assert "configured_count" in data
@pytest.mark.asyncio
async def test_platforms_health_by_name_endpoint(self):
"""验证platforms按名称健康检查端点可通过正确路径访问"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/v1/platforms/health/kimi")
assert response.status_code == 200, \
f"平台按名称健康检查端点无法访问,状态码: {response.status_code}"
data = response.json()
assert "name" in data
assert data["name"] == "kimi"

View File

@ -0,0 +1,45 @@
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from app.database import check_db_connection
class TestDatabaseExceptionHandling:
"""测试 database.py 中的异常处理行为"""
@pytest.mark.asyncio
async def test_check_db_connection_handles_sqlalchemy_error(self):
"""测试 check_db_connection 对 SQLAlchemyError 的处理"""
with patch('app.database.AsyncSessionLocal') as mock_session_local:
mock_session = AsyncMock()
mock_session_local.return_value.__aenter__.return_value = mock_session
mock_session.execute.side_effect = OperationalError(
statement="SELECT 1",
params={},
orig=Exception("Connection refused")
)
with pytest.raises(ConnectionError, match="Failed to connect to database"):
await check_db_connection()
@pytest.mark.asyncio
async def test_check_db_connection_handles_generic_exception(self):
"""测试 check_db_connection 对通用异常的处理"""
with patch('app.database.AsyncSessionLocal') as mock_session_local:
mock_session = AsyncMock()
mock_session_local.return_value.__aenter__.return_value = mock_session
mock_session.execute.side_effect = RuntimeError("Unexpected error")
with pytest.raises(RuntimeError):
await check_db_connection()
@pytest.mark.asyncio
async def test_check_db_connection_success(self):
"""测试 check_db_connection 成功时返回 True"""
with patch('app.database.AsyncSessionLocal') as mock_session_local:
mock_session = AsyncMock()
mock_session_local.return_value.__aenter__.return_value = mock_session
result = await check_db_connection()
assert result is True

View File

@ -0,0 +1,107 @@
"""测试日志过滤器功能 - TDD RED阶段"""
import logging
import pytest
API_KEY = "AIzaSyDummyKey123456789abcdefghijklmnop"
SK_TEST_KEY = "sk_test_1234567890abcdefghijklmnopqrstuvwxyz1234567890"
class TestAPIKeyFilter:
"""测试APIKeyFilter过滤器"""
def test_api_key_filtered_from_logs(self, caplog):
"""验证API Key从日志中过滤"""
with caplog.at_level(logging.INFO):
logging.info(f"Request to Gemini API with key={API_KEY}")
assert API_KEY not in caplog.text, f"API Key should not appear in logs, but found: {API_KEY}"
def test_api_key_filtered_from_url(self, caplog):
"""验证URL中的API Key从日志中过滤"""
with caplog.at_level(logging.INFO):
logging.info(f"Making request to https://api.gemini.com?key={API_KEY}&other=value")
assert API_KEY not in caplog.text
assert "key=***REDACTED***" in caplog.text
def test_sk_test_key_filtered(self, caplog):
"""验证sk_test密钥格式也被过滤"""
with caplog.at_level(logging.INFO):
logging.info(f"Request with api_key={SK_TEST_KEY}")
assert SK_TEST_KEY not in caplog.text
def test_bearer_token_filtered(self, caplog):
"""验证Bearer token从日志中过滤"""
with caplog.at_level(logging.INFO):
logging.info("Request with Bearer token abcdefghijklmnopqrstuvwxyz")
assert "Bearer abcdefghijklmnopqrstuvwxyz" not in caplog.text
def test_authorization_header_filtered(self, caplog):
"""验证Authorization header从日志中过滤"""
with caplog.at_level(logging.INFO):
logging.info("Request with Authorization: secret_api_key_here_1234567890")
assert "secret_api_key_here_1234567890" not in caplog.text
def test_normal_log_preserved(self, caplog):
"""验证普通日志内容不受影响"""
normal_message = "This is a normal log message without any secrets"
with caplog.at_level(logging.INFO):
logging.info(normal_message)
assert normal_message in caplog.text
class TestAPIKeyFilterDirectly:
"""直接测试APIKeyFilter类"""
def test_filter_class_exists(self):
"""验证APIKeyFilter类存在"""
from app.middleware.logging_filter import APIKeyFilter
assert APIKeyFilter is not None
def test_filter_redacts_url_key_param(self):
"""验证过滤器直接处理URL中的key参数时正确脱敏"""
from app.middleware.logging_filter import APIKeyFilter
filter_instance = APIKeyFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg=f"Request to https://api.gemini.com?key={API_KEY}",
args=(),
exc_info=None
)
result = filter_instance.filter(record)
assert result is True
assert API_KEY not in record.msg
assert "key=***REDACTED***" in record.msg
def test_filter_preserves_normal_content(self):
"""验证过滤器保留正常日志内容"""
from app.middleware.logging_filter import APIKeyFilter
filter_instance = APIKeyFilter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="Normal log message without secrets",
args=(),
exc_info=None
)
result = filter_instance.filter(record)
assert result is True
assert record.msg == "Normal log message without secrets"

View File

@ -0,0 +1,148 @@
import pytest
from app.services.api_key_manager import APIKeyManager, KeySource
class TestKeyCredentials:
def test_key_credentials_single_key(self):
from app.services.api_key_manager import KeyCredentials
creds = KeyCredentials(api_key="test-api-key-123")
assert creds.api_key == "test-api-key-123"
assert creds.secret_key is None
assert creds.to_dict() == {"api_key": "test-api-key-123"}
def test_key_credentials_dual_keys(self):
from app.services.api_key_manager import KeyCredentials
creds = KeyCredentials(api_key="test-api-key-123", secret_key="test-secret-key-456")
assert creds.api_key == "test-api-key-123"
assert creds.secret_key == "test-secret-key-456"
assert creds.to_dict() == {"api_key": "test-api-key-123", "secret_key": "test-secret-key-456"}
class TestAPIKeyManagerDualKeys:
@pytest.fixture
def manager(self):
return APIKeyManager()
def test_add_dual_keys_dict_for_wenxin(self, manager):
"""测试为文心一言添加双密钥字典"""
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
config = manager.add_key("wenxin", credentials, source=KeySource.USER)
assert config.engine_type == "wenxin"
assert config.key_source == KeySource.USER
def test_add_dual_keys_and_retrieve(self, manager):
"""测试添加双密钥后能够正确获取"""
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
retrieved = manager.get_key("wenxin")
assert retrieved is not None
assert isinstance(retrieved, dict)
assert retrieved["api_key"] == "baidu-api-key-xxx"
assert retrieved["secret_key"] == "baidu-secret-key-yyy"
def test_get_credentials_returns_complete_credentials(self, manager):
"""测试获取完整凭证包含api_key和secret_key"""
credentials = {"api_key": "baidu-api-key-xxx", "secret_key": "baidu-secret-key-yyy"}
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
creds = manager.get_credentials("wenxin")
assert creds is not None
assert creds.api_key == "baidu-api-key-xxx"
assert creds.secret_key == "baidu-secret-key-yyy"
def test_get_credentials_with_single_key(self, manager):
"""测试获取单Key凭证兼容现有逻辑"""
manager.add_key("chatgpt", "sk-chatgpt-key-1234567890", source=KeySource.SYSTEM)
creds = manager.get_credentials("chatgpt")
assert creds is not None
assert creds.api_key == "sk-chatgpt-key-1234567890"
assert creds.secret_key is None
def test_get_credentials_returns_none_for_unknown_engine(self, manager):
"""测试获取未知引擎凭证返回None"""
creds = manager.get_credentials("unknown_engine")
assert creds is None
def test_add_single_key_string_still_works(self, manager):
"""测试单Key字符串格式仍然正常工作"""
config = manager.add_key("chatgpt", "sk-test-key-1234567890", source=KeySource.SYSTEM)
assert config.engine_type == "chatgpt"
key = manager.get_key("chatgpt")
assert key == "sk-test-key-1234567890"
def test_add_dual_keys_string_auto_wrap(self, manager):
"""测试单Key字符串自动包装为api_key格式"""
manager.add_key("wenxin", "single-key-value", source=KeySource.SYSTEM)
creds = manager.get_credentials("wenxin")
assert creds is not None
assert creds.api_key == "single-key-value"
assert creds.secret_key is None
def test_get_all_keys_returns_dict_for_dual_key_engine(self, manager):
"""测试获取所有Key信息时返回完整字典"""
credentials = {"api_key": "api-key-xxx", "secret_key": "secret-key-yyy"}
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
keys = manager.get_all_keys("wenxin")
assert keys is not None
assert "api_key" in keys
assert "secret_key" in keys
assert keys["api_key"] == "api-key-xxx"
assert keys["secret_key"] == "secret-key-yyy"
def test_env_mapping_includes_dual_keys(self):
"""测试环境变量映射支持文心一言双Key"""
manager = APIKeyManager()
assert "wenxin" in manager._ENV_MAPPING
assert manager._ENV_MAPPING["wenxin"] == "BAIDU_QIANFAN_API_KEY"
def test_env_loading_for_dual_key_engine(self, manager, monkeypatch):
"""测试加载环境变量时正确处理文心一言双Key"""
monkeypatch.setenv("BAIDU_QIANFAN_API_KEY", "env-api-key-xxx")
monkeypatch.setenv("BAIDU_SECRET_KEY", "env-secret-key-yyy")
manager.load_env_keys()
creds = manager.get_credentials("wenxin")
assert creds is not None
assert creds.api_key == "env-api-key-xxx"
class TestWenxinAdapterDualKeys:
def test_wenxin_adapter_gets_dual_keys_from_manager(self):
"""测试文心一言适配器从KeyManager获取双密钥"""
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.api_key_manager import APIKeyManager, KeySource
manager = APIKeyManager()
credentials = {"api_key": "test-api-key-123", "secret_key": "test-secret-key-456"}
manager.add_key("wenxin", credentials, source=KeySource.SYSTEM)
adapter = WenxinAdapter(key_manager=manager)
assert adapter.api_key == "test-api-key-123"
assert adapter.secret_key == "test-secret-key-456"
def test_wenxin_adapter_with_single_key_fallback(self):
"""测试文心一言适配器单Key时的降级处理"""
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.api_key_manager import APIKeyManager, KeySource
manager = APIKeyManager()
manager.add_key("wenxin", "single-key-only", source=KeySource.SYSTEM)
adapter = WenxinAdapter(key_manager=manager)
assert adapter.api_key == "single-key-only"
assert adapter.secret_key == ""
def test_wenxin_adapter_direct_credentials_take_precedence(self):
"""测试直接传入的credentials优先于KeyManager"""
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.api_key_manager import APIKeyManager, KeySource
manager = APIKeyManager()
manager.add_key("wenxin", {"api_key": "manager-api", "secret_key": "manager-secret"}, source=KeySource.SYSTEM)
adapter = WenxinAdapter(
api_key="direct-api",
secret_key="direct-secret",
key_manager=manager
)
assert adapter.api_key == "direct-api"
assert adapter.secret_key == "direct-secret"

View File

@ -0,0 +1,178 @@
import pytest
from unittest.mock import patch, MagicMock, call
import httpx
import asyncio
from app.services.ai_engine.base import AIEngineAdapter, EngineType
class TestBatchQueryExceptionHandling:
"""测试 batch_query.py 中的异常处理行为"""
def test_build_adapters_handles_generic_exception(self, caplog):
"""测试 _build_adapters 对通用异常的处理(应记录 error 日志)"""
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
class FailingAdapter(AIEngineAdapter):
def __init__(self):
super().__init__(api_key="test")
raise RuntimeError("Simulated initialization failure")
def get_engine_type(self):
return EngineType.CHATGPT
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.CHATGPT: FailingAdapter}):
_build_adapters.cache_clear()
result = _build_adapters()
assert EngineType.CHATGPT.value not in result
def test_build_adapters_handles_httpx_http_error(self, caplog):
"""测试 _build_adapters 对 httpx.HTTPError 的处理(应记录 warning 日志)"""
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
class HttpErrorAdapter(AIEngineAdapter):
def __init__(self):
super().__init__(api_key="test")
raise httpx.HTTPError("HTTP connection failed")
def get_engine_type(self):
return EngineType.PERPLEXITY
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.PERPLEXITY: HttpErrorAdapter}):
_build_adapters.cache_clear()
result = _build_adapters()
assert EngineType.PERPLEXITY.value not in result
assert any("HTTP error" in record.message for record in caplog.records)
def test_build_adapters_handles_timeout_error(self, caplog):
"""测试 _build_adapters 对 asyncio.TimeoutError 的处理(应记录 warning 日志)"""
from app.services.ai_engine.batch_query import _build_adapters, _ADAPTER_CLASSES
class TimeoutAdapter(AIEngineAdapter):
def __init__(self):
super().__init__(api_key="test")
raise asyncio.TimeoutError("Connection timeout")
def get_engine_type(self):
return EngineType.KIMI
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.KIMI: TimeoutAdapter}):
_build_adapters.cache_clear()
result = _build_adapters()
assert EngineType.KIMI.value not in result
assert any("Timeout" in record.message for record in caplog.records)
def test_build_adapters_with_key_manager_handles_httpx_http_error(self, caplog):
"""测试 _build_adapters_with_key_manager 对 httpx.HTTPError 的处理"""
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
class HttpErrorAdapter(AIEngineAdapter):
def __init__(self, key_manager=None, user_id=None):
super().__init__(api_key="test")
raise httpx.HTTPError("HTTP connection failed")
def get_engine_type(self):
return EngineType.WENXIN
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.WENXIN: HttpErrorAdapter}):
result = _build_adapters_with_key_manager(
key_manager=MagicMock(),
user_id="test_user"
)
assert EngineType.WENXIN.value not in result
assert any("HTTP error" in record.message for record in caplog.records)
def test_build_adapters_with_key_manager_handles_timeout_error(self, caplog):
"""测试 _build_adapters_with_key_manager 对 asyncio.TimeoutError 的处理"""
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
class TimeoutAdapter(AIEngineAdapter):
def __init__(self, key_manager=None, user_id=None):
super().__init__(api_key="test")
raise asyncio.TimeoutError("Connection timeout")
def get_engine_type(self):
return EngineType.DEEPSEEK
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.DEEPSEEK: TimeoutAdapter}):
result = _build_adapters_with_key_manager(
key_manager=MagicMock(),
user_id="test_user"
)
assert EngineType.DEEPSEEK.value not in result
assert any("Timeout" in record.message for record in caplog.records)
def test_build_adapters_with_key_manager_handles_generic_exception(self, caplog):
"""测试 _build_adapters_with_key_manager 对通用异常的处理"""
from app.services.ai_engine.batch_query import _build_adapters_with_key_manager, _ADAPTER_CLASSES
class FailingAdapter(AIEngineAdapter):
def __init__(self, key_manager=None, user_id=None):
super().__init__(api_key="test")
raise RuntimeError("Simulated key manager failure")
def get_engine_type(self):
return EngineType.QWEN
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
with patch.dict(_ADAPTER_CLASSES, {EngineType.QWEN: FailingAdapter}):
result = _build_adapters_with_key_manager(
key_manager=MagicMock(),
user_id="test_user"
)
assert EngineType.QWEN.value not in result
def test_register_adapter_handles_exception(self):
"""测试 register_adapter 函数对异常的处理"""
from app.services.ai_engine.batch_query import register_adapter
class BadAdapter(AIEngineAdapter):
def __init__(self):
super().__init__(api_key="test")
raise ValueError("Adapter registration failed")
def get_engine_type(self):
return EngineType.GEMINI
def _get_env_key(self):
return "TEST_KEY"
async def query(self, query, brand_name, competitor_names=None):
pass
register_adapter(BadAdapter)

View File

@ -0,0 +1,70 @@
import pytest
from unittest.mock import patch, MagicMock
import logging
from app.services.usage_tracker import UsageTracker
class TestUsageTrackerExceptionHandling:
"""测试 usage_tracker.py 中的异常处理行为"""
def test_usage_tracker_record_handles_exception(self, caplog):
"""测试 UsageTracker.record 方法对异常的处理"""
with caplog.at_level(logging.WARNING):
tracker = UsageTracker()
with patch.object(tracker, '_records', side_effect=Exception("Simulated error")):
try:
tracker.record(
user_id="test_user",
brand_id="test_brand",
engine_type="chatgpt",
query="test query",
input_tokens=100,
output_tokens=200,
cost=0.05,
)
except Exception:
pass
def test_usage_tracker_get_summary_handles_empty_records(self):
"""测试 UsageTracker.get_summary 方法处理空记录"""
tracker = UsageTracker()
summary = tracker.get_summary(user_id="test_user")
assert summary.total_queries == 0
assert summary.total_input_tokens == 0
assert summary.total_output_tokens == 0
assert summary.total_cost == 0.0
@pytest.mark.asyncio
async def test_usage_tracker_record_async_requires_session(self):
"""测试 UsageTracker.record_async 需要有效的 session"""
tracker = UsageTracker()
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
await tracker.record_async(
user_id="test_user",
brand_id="test_brand",
engine_type="chatgpt",
query="test query",
input_tokens=100,
output_tokens=200,
cost=0.05,
)
@pytest.mark.asyncio
async def test_usage_tracker_get_summary_async_requires_session(self):
"""测试 UsageTracker.get_summary_async 需要有效的 session"""
tracker = UsageTracker()
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
await tracker.get_summary_async(user_id="test_user")
@pytest.mark.asyncio
async def test_usage_tracker_check_quota_async_requires_session(self):
"""测试 UsageTracker.check_quota_async 需要有效的 session"""
tracker = UsageTracker()
with pytest.raises(RuntimeError, match="not initialized with AsyncSession"):
await tracker.check_quota_async(user_id="test_user")

View File

@ -0,0 +1,108 @@
/**
* Suggestions API
*
* API
* : /api/v1/brands/${brandId}/suggestions/*
*/
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { suggestionsApi } from "@/lib/api/suggestions";
import { API_BASE } from "@/lib/api/client";
const mockFetch = vi.fn();
const originalFetch = global.fetch;
beforeEach(() => {
global.fetch = mockFetch;
vi.clearAllMocks();
mockFetch.mockResolvedValue({
ok: true,
status: 200,
json: () => Promise.resolve([]),
});
});
afterEach(() => {
global.fetch = originalFetch;
});
function getCalledUrl() {
return mockFetch.mock.calls[0][0] as string;
}
function getCalledMethod() {
return mockFetch.mock.calls[0][1]?.method as string;
}
describe("suggestionsApi 路径测试", () => {
describe("getSuggestions", () => {
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions", async () => {
await suggestionsApi.getSuggestions("token-123", "brand-abc");
const url = getCalledUrl();
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
});
it("应正确传递查询参数", async () => {
await suggestionsApi.getSuggestions("token-123", "brand-abc", {
status: "pending",
page: 1,
});
const url = getCalledUrl();
expect(url).toContain(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
expect(url).toContain("status=pending");
expect(url).toContain("page=1");
});
it("空参数时不应包含查询字符串", async () => {
await suggestionsApi.getSuggestions("token-123", "brand-abc", {});
const url = getCalledUrl();
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions`);
});
});
describe("regenerateSuggestions", () => {
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions/regenerate", async () => {
await suggestionsApi.regenerateSuggestions("token-123", "brand-abc");
const url = getCalledUrl();
const method = getCalledMethod();
expect(url).toBe(`${API_BASE}/api/v1/brands/brand-abc/suggestions/regenerate`);
expect(method).toBe("POST");
});
});
describe("updateSuggestionStatus", () => {
it("应使用正确的路径 /api/v1/brands/${brandId}/suggestions/${suggestionId}/status", async () => {
await suggestionsApi.updateSuggestionStatus(
"token-123",
"brand-abc",
"suggestion-xyz",
"completed"
);
const url = getCalledUrl();
const method = getCalledMethod();
expect(url).toBe(
`${API_BASE}/api/v1/brands/brand-abc/suggestions/suggestion-xyz/status`
);
expect(method).toBe("PUT");
});
it("应正确传递状态数据", async () => {
await suggestionsApi.updateSuggestionStatus(
"token-123",
"brand-abc",
"suggestion-xyz",
"completed"
);
const options = mockFetch.mock.calls[0][1];
expect(JSON.parse(options.body)).toEqual({ status: "completed" });
});
});
});

View File

@ -16,7 +16,7 @@ export const suggestionsApi = {
params?: Record<string, string | number> params?: Record<string, string | number>
) => ) =>
fetchWithAuth( fetchWithAuth(
`/api/v1/suggestions/${brandId}/suggestions${buildQuery(params || {})}`, `/api/v1/brands/${brandId}/suggestions${buildQuery(params || {})}`,
{}, {},
token token
), ),
@ -24,7 +24,7 @@ export const suggestionsApi = {
/** 重新生成建议 */ /** 重新生成建议 */
regenerateSuggestions: (token: string, brandId: string) => regenerateSuggestions: (token: string, brandId: string) =>
fetchWithAuth( fetchWithAuth(
`/api/v1/suggestions/${brandId}/suggestions/regenerate`, `/api/v1/brands/${brandId}/suggestions/regenerate`,
{ method: "POST" }, { method: "POST" },
token token
), ),
@ -37,7 +37,7 @@ export const suggestionsApi = {
status: string status: string
) => ) =>
fetchWithAuth( fetchWithAuth(
`/api/v1/suggestions/${brandId}/suggestions/${suggestionId}/status`, `/api/v1/brands/${brandId}/suggestions/${suggestionId}/status`,
{ method: "PUT", body: JSON.stringify({ status }) }, { method: "PUT", body: JSON.stringify({ status }) },
token token
), ),