chore: geo production readiness improvements

This commit is contained in:
chiguyong 2026-06-04 22:08:06 +08:00
parent 435fec2b00
commit 79139bc504
56 changed files with 1368 additions and 423 deletions

View File

@ -29,6 +29,10 @@ from app.services.alert.alert_engine import AlertEngine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def verify_brand_ownership( async def verify_brand_ownership(
@ -40,7 +44,7 @@ async def verify_brand_ownership(
stmt = select(Brand).where( stmt = select(Brand).where(
and_( and_(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@ -76,7 +80,7 @@ async def get_alerts(
支持按类型严重程度已读状态和品牌筛选按创建时间倒序排列 支持按类型严重程度已读状态和品牌筛选按创建时间倒序排列
""" """
# 构建查询条件 # 构建查询条件
conditions = [Alert.user_id == current_user.id] conditions = [Alert.user_id == _to_uuid(current_user.id)]
if alert_type: if alert_type:
conditions.append(Alert.alert_type == alert_type) conditions.append(Alert.alert_type == alert_type)
if severity: if severity:
@ -113,7 +117,7 @@ async def get_unread_count(
"""获取当前用户的未读告警数量""" """获取当前用户的未读告警数量"""
stmt = select(func.count()).select_from(Alert).where( stmt = select(func.count()).select_from(Alert).where(
and_( and_(
Alert.user_id == current_user.id, Alert.user_id == _to_uuid(current_user.id),
Alert.is_read == False, Alert.is_read == False,
) )
) )
@ -131,7 +135,7 @@ async def mark_all_read(
"""将当前用户的所有告警标记为已读""" """将当前用户的所有告警标记为已读"""
stmt = select(Alert).where( stmt = select(Alert).where(
and_( and_(
Alert.user_id == current_user.id, Alert.user_id == _to_uuid(current_user.id),
Alert.is_read == False, Alert.is_read == False,
) )
) )
@ -158,7 +162,7 @@ async def mark_read(
stmt = select(Alert).where( stmt = select(Alert).where(
and_( and_(
Alert.id == alert_id, Alert.id == alert_id,
Alert.user_id == current_user.id, Alert.user_id == _to_uuid(current_user.id),
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@ -193,14 +197,14 @@ async def get_alert_settings(
如果指定 brand_id则返回该品牌的告警设置 如果指定 brand_id则返回该品牌的告警设置
否则返回所有品牌的告警设置 否则返回所有品牌的告警设置
""" """
conditions = [AlertSetting.user_id == current_user.id] conditions = [AlertSetting.user_id == _to_uuid(current_user.id)]
if brand_id: if brand_id:
conditions.append(AlertSetting.brand_id == brand_id) conditions.append(AlertSetting.brand_id == brand_id)
# 如果指定了品牌但该品牌没有设置,则自动初始化默认设置 # 如果指定了品牌但该品牌没有设置,则自动初始化默认设置
if brand_id: if brand_id:
engine = AlertEngine(db) engine = AlertEngine(db)
await engine.ensure_default_settings(brand_id, current_user.id) await engine.ensure_default_settings(brand_id, _to_uuid(current_user.id))
count_stmt = select(func.count()).select_from(AlertSetting).where(and_(*conditions)) count_stmt = select(func.count()).select_from(AlertSetting).where(and_(*conditions))
count_result = await db.execute(count_stmt) count_result = await db.execute(count_stmt)
@ -257,7 +261,7 @@ async def update_alert_settings(
# 创建 # 创建
setting = AlertSetting( setting = AlertSetting(
brand_id=item.brand_id, brand_id=item.brand_id,
user_id=current_user.id, user_id=_to_uuid(current_user.id),
alert_type=item.alert_type, alert_type=item.alert_type,
enabled=item.enabled, enabled=item.enabled,
threshold=item.threshold, threshold=item.threshold,
@ -286,7 +290,7 @@ async def update_single_setting(
stmt = select(AlertSetting).where( stmt = select(AlertSetting).where(
and_( and_(
AlertSetting.id == setting_id, AlertSetting.id == setting_id,
AlertSetting.user_id == current_user.id, AlertSetting.user_id == _to_uuid(current_user.id),
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@ -338,7 +342,7 @@ async def create_alert_setting(
# 创建新设置 # 创建新设置
setting = AlertSetting( setting = AlertSetting(
brand_id=data.brand_id, brand_id=data.brand_id,
user_id=current_user.id, user_id=_to_uuid(current_user.id),
alert_type=data.alert_type, alert_type=data.alert_type,
enabled=data.enabled, enabled=data.enabled,
threshold=data.threshold, threshold=data.threshold,
@ -361,7 +365,7 @@ async def delete_alert_setting(
stmt = select(AlertSetting).where( stmt = select(AlertSetting).where(
and_( and_(
AlertSetting.id == setting_id, AlertSetting.id == setting_id,
AlertSetting.user_id == current_user.id, AlertSetting.user_id == _to_uuid(current_user.id),
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)

View File

@ -42,7 +42,7 @@ async def add_key(
source = KeySource.USER if body.source == "user" else KeySource.SYSTEM source = KeySource.USER if body.source == "user" else KeySource.SYSTEM
config = get_key_manager().add_key( config = get_key_manager().add_key(
engine_type=body.engine_type, engine_type=body.engine_type,
api_key=body.api_key, credentials=body.api_key,
source=source, source=source,
user_id=str(current_user.id), user_id=str(current_user.id),
) )

View File

@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
# Include competitors router under brands # Include competitors router under brands
router.include_router(competitors_router) router.include_router(competitors_router)
@ -47,7 +53,7 @@ async def get_brands(
# 修复 N+1一次性加载 competitors 和 suggestions # 修复 N+1一次性加载 competitors 和 suggestions
stmt = ( stmt = (
select(Brand) select(Brand)
.where(Brand.user_id == current_user.id) .where(Brand.user_id == _to_uuid(current_user.id))
.options( .options(
selectinload(Brand.competitors), selectinload(Brand.competitors),
selectinload(Brand.suggestions), selectinload(Brand.suggestions),
@ -73,7 +79,7 @@ async def create_brand(
): ):
"""Create a new brand.""" """Create a new brand."""
brand = Brand( brand = Brand(
user_id=current_user.id, user_id=_to_uuid(current_user.id),
name=brand_data.name, name=brand_data.name,
aliases=brand_data.aliases, aliases=brand_data.aliases,
website=brand_data.website, website=brand_data.website,
@ -99,7 +105,7 @@ async def get_brand(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Get a specific brand by ID.""" """Get a specific brand by ID."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -119,7 +125,7 @@ async def update_brand(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Update a brand.""" """Update a brand."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -173,7 +179,7 @@ async def delete_brand(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Delete a brand.""" """Delete a brand."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()

View File

@ -19,6 +19,10 @@ from app.services.citation.citation import (
) )
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/", response_model=CitationListResponse) @router.get("/", response_model=CitationListResponse)
@ -34,7 +38,7 @@ async def list_citations(
): ):
items, total = await get_citations( items, total = await get_citations(
db, db,
current_user.id, _to_uuid(current_user.id),
query_id=query_id, query_id=query_id,
platform=platform, platform=platform,
start_date=start_date, start_date=start_date,
@ -56,7 +60,7 @@ async def citation_stats(
if brand_id is not None: if brand_id is not None:
brand_stmt = select(Brand).where( brand_stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
brand_result = await db.execute(brand_stmt) brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none() brand = brand_result.scalar_one_or_none()
@ -68,7 +72,7 @@ async def citation_stats(
) )
stats = await get_citation_stats( stats = await get_citation_stats(
db, current_user.id, query_id=query_id, brand_id=brand_id db, _to_uuid(current_user.id), query_id=query_id, brand_id=brand_id
) )
return stats return stats

View File

@ -23,12 +23,18 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_if_owned( async def _get_brand_if_owned(
brand_id: uuid.UUID, brand_id: uuid.UUID,
current_user: User, current_user: User,
db: AsyncSession, db: AsyncSession,
) -> Brand: ) -> Brand:
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
if not brand: if not brand:

View File

@ -28,13 +28,19 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def get_brand_if_owned( async def get_brand_if_owned(
brand_id: uuid.UUID, brand_id: uuid.UUID,
current_user: User, current_user: User,
db: AsyncSession, db: AsyncSession,
) -> Brand: ) -> Brand:
"""Helper to get brand if owned by current user.""" """Helper to get brand if owned by current user."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()

View File

@ -28,6 +28,12 @@ from app.services.cache import get_cache_service, TTL_DASHBOARD
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/stats", response_model=DashboardStatsResponse) @router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats( async def get_dashboard_stats(
brand_id: uuid.UUID | None = Query(None), brand_id: uuid.UUID | None = Query(None),
@ -47,7 +53,7 @@ async def get_dashboard_stats(
""" """
cache = get_cache_service() cache = get_cache_service()
if brand_id is None: if brand_id is None:
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1) brand_stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id)).limit(1)
brand_result = await db.execute(brand_stmt) brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none() brand = brand_result.scalar_one_or_none()
@ -80,7 +86,7 @@ async def get_dashboard_stats(
scoring_data_service = get_brand_scoring_data_service() scoring_data_service = get_brand_scoring_data_service()
scoring_data = await scoring_data_service.get_brand_scoring_data( scoring_data = await scoring_data_service.get_brand_scoring_data(
db, current_user.id, brand db, _to_uuid(current_user.id), brand
) )
overall_score = scoring_data.v2_result.overall_score overall_score = scoring_data.v2_result.overall_score
@ -125,7 +131,7 @@ async def get_dashboard_stats(
platform_scores_dict = scoring_data.platform_scores platform_scores_dict = scoring_data.platform_scores
competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores( competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
db, current_user.id, brand_id db, _to_uuid(current_user.id), brand_id
) )
competitor_stmt = select(Competitor).where( competitor_stmt = select(Competitor).where(

View File

@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def verify_brand_ownership( async def verify_brand_ownership(
brand_id: uuid.UUID, brand_id: uuid.UUID,
current_user: User, current_user: User,
@ -31,7 +37,7 @@ async def verify_brand_ownership(
stmt = select(Brand).where( stmt = select(Brand).where(
and_( and_(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@ -59,7 +65,7 @@ async def create_detection_task(
task = await service.create_task( task = await service.create_task(
task_data=data.model_dump(exclude={"brand_id"}), task_data=data.model_dump(exclude={"brand_id"}),
brand_id=data.brand_id, brand_id=data.brand_id,
user_id=current_user.id, user_id=_to_uuid(current_user.id),
db=db, db=db,
) )
except ValueError as e: except ValueError as e:
@ -82,7 +88,7 @@ async def get_detection_tasks(
await verify_brand_ownership(brand_id, current_user, db) await verify_brand_ownership(brand_id, current_user, db)
service = DetectionSchedulerService() service = DetectionSchedulerService()
tasks = await service.get_tasks(brand_id, current_user.id, db) tasks = await service.get_tasks(brand_id, _to_uuid(current_user.id), db)
return {"items": tasks[skip : skip + limit], "total": len(tasks)} return {"items": tasks[skip : skip + limit], "total": len(tasks)}
@ -99,7 +105,7 @@ async def update_detection_task(
task = await service.update_task( task = await service.update_task(
task_id=task_id, task_id=task_id,
task_data=data.model_dump(exclude_unset=True), task_data=data.model_dump(exclude_unset=True),
user_id=current_user.id, user_id=_to_uuid(current_user.id),
db=db, db=db,
) )
except TaskNotFoundError: except TaskNotFoundError:
@ -123,7 +129,7 @@ async def delete_detection_task(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
service = DetectionSchedulerService() service = DetectionSchedulerService()
result = await service.delete_task(task_id, current_user.id, db) result = await service.delete_task(task_id, _to_uuid(current_user.id), db)
if not result: if not result:
raise HTTPException( raise HTTPException(
@ -141,7 +147,7 @@ async def trigger_detection_task(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
service = DetectionSchedulerService() service = DetectionSchedulerService()
result = await service.trigger_task(task_id, current_user.id, db) result = await service.trigger_task(task_id, _to_uuid(current_user.id), db)
if result.get("status") == "error": if result.get("status") == "error":
raise HTTPException( raise HTTPException(

View File

@ -19,6 +19,10 @@ from app.schemas.monitoring import (
from app.services.monitoring.monitor_service import MonitorService from app.services.monitoring.monitor_service import MonitorService
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -28,7 +32,7 @@ async def _get_brand_with_access(
) -> Brand: ) -> Brand:
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()

View File

@ -22,6 +22,12 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _compute_v2_scores( async def _compute_v2_scores(
db: AsyncSession, db: AsyncSession,
user_id: uuid.UUID, user_id: uuid.UUID,
@ -98,10 +104,10 @@ async def export_report(
try: try:
v2_result = None v2_result = None
if brand_id is not None: if brand_id is not None:
v2_result = await _compute_v2_scores(db, current_user.id, brand_id) v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
csv_content = await export_citations_csv( csv_content = await export_citations_csv(
db, current_user.id, query_id, v2_result=v2_result db, _to_uuid(current_user.id), query_id, v2_result=v2_result
) )
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
@ -131,10 +137,10 @@ async def export_pdf(
try: try:
v2_result = None v2_result = None
if brand_id is not None: if brand_id is not None:
v2_result = await _compute_v2_scores(db, current_user.id, brand_id) v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
pdf_bytes = await export_citations_pdf( pdf_bytes = await export_citations_pdf(
db, current_user.id, query_id, v2_result=v2_result db, _to_uuid(current_user.id), query_id, v2_result=v2_result
) )
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(

View File

@ -20,6 +20,10 @@ from app.services.schema.schema_advisor_service import SchemaAdvisorService
from app.services.scoring.scoring_service import ScoringService from app.services.scoring.scoring_service import ScoringService
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -29,7 +33,7 @@ async def _get_brand_with_access(
) -> Brand: ) -> Brand:
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -152,7 +156,7 @@ async def generate_schema_advise(
): ):
brand = await _get_brand_with_access(request.brand_id, db, current_user) brand = await _get_brand_with_access(request.brand_id, db, current_user)
diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand) diagnosis_data = await _get_brand_diagnosis_data(db, _to_uuid(current_user.id), brand)
brand_info = { brand_info = {
"name": brand.name, "name": brand.name,

View File

@ -33,6 +33,10 @@ from app.services.analysis.sentiment_service import get_sentiment_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -43,7 +47,7 @@ async def _get_brand_with_access(
"""Verify brand exists and user has access.""" """Verify brand exists and user has access."""
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -398,7 +402,7 @@ async def get_brand_score(
# Get citations data # Get citations data
total_queries, brand_citations, competitor_citations, competitor_mentions = ( total_queries, brand_citations, competitor_citations, competitor_mentions = (
await _get_citations_for_brand(db, current_user.id, brand_id) await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
) )
# 情感分析 # 情感分析
@ -471,7 +475,7 @@ async def get_brand_score(
await alert_engine.detect_after_scoring( await alert_engine.detect_after_scoring(
brand_id=brand_id, brand_id=brand_id,
brand_name=brand.name, brand_name=brand.name,
user_id=current_user.id, user_id=_to_uuid(current_user.id),
current_score=v2_result.overall_score, current_score=v2_result.overall_score,
sentiment_counts=sentiment_counts, sentiment_counts=sentiment_counts,
brand_mentions=len(brand_citations), brand_mentions=len(brand_citations),
@ -539,7 +543,7 @@ async def get_brand_score_v1(
# Get citations data # Get citations data
total_queries, brand_citations, competitor_citations, _ = ( total_queries, brand_citations, competitor_citations, _ = (
await _get_citations_for_brand(db, current_user.id, brand_id) await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
) )
# Calculate scores using scoring service # Calculate scores using scoring service
@ -682,7 +686,7 @@ async def get_brand_comparison(
# Get brand's own score # Get brand's own score
total_queries, brand_citations, competitor_citations, competitor_mentions = ( total_queries, brand_citations, competitor_citations, competitor_mentions = (
await _get_citations_for_brand(db, current_user.id, brand_id) await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
) )
scoring_service = ScoringService() scoring_service = ScoringService()

View File

@ -24,6 +24,10 @@ from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService from app.services.content.content_generation_service import ContentGenerationService
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -33,7 +37,7 @@ async def _get_brand_with_access(
) -> Brand: ) -> Brand:
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -117,7 +121,7 @@ async def generate_geo_plan_endpoint(
platform_scores, platform_scores,
total_queries, total_queries,
mentioned_count, mentioned_count,
) = await _get_brand_scoring_data(db, current_user.id, brand) ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
target_score = request.target_score or 75 target_score = request.target_score or 75

View File

@ -30,6 +30,10 @@ from app.services.advisor.optimization_advisor import (
) )
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -40,7 +44,7 @@ async def _get_brand_with_access(
"""验证品牌存在且用户有访问权限""" """验证品牌存在且用户有访问权限"""
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()
@ -428,7 +432,7 @@ async def _generate_and_save_suggestions(
platform_scores, platform_scores,
total_queries, total_queries,
mentioned_count, mentioned_count,
) = await _get_brand_scoring_data(db, current_user.id, brand) ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
# 构建分析上下文 # 构建分析上下文
ctx = build_context_from_scoring_result( ctx = build_context_from_scoring_result(

View File

@ -18,6 +18,10 @@ from app.schemas.trend_insight import (
from app.services.trend.trend_analyzer_service import TrendAnalyzerService from app.services.trend.trend_analyzer_service import TrendAnalyzerService
router = APIRouter() router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access( async def _get_brand_with_access(
@ -27,7 +31,7 @@ async def _get_brand_with_access(
) -> Brand: ) -> Brand:
stmt = select(Brand).where( stmt = select(Brand).where(
Brand.id == brand_id, Brand.id == brand_id,
Brand.user_id == current_user.id, Brand.user_id == _to_uuid(current_user.id),
) )
result = await db.execute(stmt) result = await db.execute(stmt)
brand = result.scalar_one_or_none() brand = result.scalar_one_or_none()

View File

@ -328,22 +328,14 @@ async def readiness(db: AsyncSession = Depends(get_db)):
db_result = await checker.check_database() db_result = await checker.check_database()
redis_result = await checker.check_redis() redis_result = await checker.check_redis()
if db_result.healthy and redis_result.healthy: all_ok = db_result.healthy and redis_result.healthy
return { return JSONResponse(
"status": "ready", status_code=200 if all_ok else 503,
content={
"status": "ready" if all_ok else "not_ready",
"checks": { "checks": {
"database": db_result.healthy, "database": db_result.healthy,
"redis": redis_result.healthy, "redis": redis_result.healthy,
} },
} },
else:
raise HTTPException(
status_code=503,
detail={
"status": "not_ready",
"checks": {
"database": db_result.healthy,
"redis": redis_result.healthy,
}
}
) )

View File

@ -41,7 +41,7 @@ class UpdateProfileRequest(BaseModel):
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: uuid.UUID | str
email: str email: str
name: str | None = None name: str | None = None
is_active: bool = True is_active: bool = True
@ -53,13 +53,17 @@ class UserResponse(BaseModel):
@classmethod @classmethod
def from_user(cls, user) -> "UserResponse": def from_user(cls, user) -> "UserResponse":
avatar = user.avatar_url
# 防止 mock 对象或非字符串值
if avatar is not None and not isinstance(avatar, str):
avatar = None
return cls( return cls(
id=user.id, id=str(user.id) if not isinstance(user.id, str) else user.id,
email=user.email, email=user.email,
name=user.name, name=user.name,
is_active=user.is_active, is_active=user.is_active,
email_verified=user.email_verified, email_verified=user.email_verified,
avatar_url=user.avatar_url, avatar_url=avatar,
created_at=user.createdAt if hasattr(user, "createdAt") else None, created_at=user.createdAt if hasattr(user, "createdAt") else None,
) )

View File

@ -122,6 +122,7 @@ class HealthChecker:
import os import os
storage_path = "/data/documents" storage_path = "/data/documents"
start = time.perf_counter()
try: try:
if os.path.exists(storage_path): if os.path.exists(storage_path):
@ -131,23 +132,29 @@ class HealthChecker:
f.write("ok") f.write("ok")
os.remove(test_file) os.remove(test_file)
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult( return HealthCheckResult(
name="storage", name="storage",
healthy=True, healthy=True,
latency_ms=round(latency, 2),
message=f"Storage path {storage_path} is writable", message=f"Storage path {storage_path} is writable",
details={"path": storage_path}, details={"path": storage_path},
) )
else: else:
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult( return HealthCheckResult(
name="storage", name="storage",
healthy=True, healthy=True,
latency_ms=round(latency, 2),
message=f"Storage path {storage_path} does not exist (will be created)", message=f"Storage path {storage_path} does not exist (will be created)",
details={"path": storage_path, "created": True}, details={"path": storage_path, "created": True},
) )
except Exception as e: except Exception as e:
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult( return HealthCheckResult(
name="storage", name="storage",
healthy=False, healthy=False,
latency_ms=round(latency, 2),
message=f"Storage check failed: {str(e)}", message=f"Storage check failed: {str(e)}",
) )

View File

@ -5,13 +5,13 @@ from app.services.auth import hash_password
def _make_user( def _make_user(
user_id: str | None = None, user_id: str | uuid.UUID | None = None,
email: str = "test@example.com", email: str = "test@example.com",
plan: str = "free", plan: str = "free",
) -> User: ) -> User:
uid = user_id or str(uuid.uuid4()) uid = user_id or str(uuid.uuid4())
user = User( user = User(
id=uid, id=str(uid),
email=email, email=email,
password=hash_password("Test@123456"), password=hash_password("Test@123456"),
firstName="Test", firstName="Test",

View File

@ -82,7 +82,6 @@ class TestTaskDispatcher:
"""测试分发器初始化""" """测试分发器初始化"""
assert dispatcher is not None assert dispatcher is not None
assert dispatcher._redis_url == settings.REDIS_URL assert dispatcher._redis_url == settings.REDIS_URL
assert dispatcher._redis is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_task_status_not_found(self, dispatcher): async def test_get_task_status_not_found(self, dispatcher):

View File

@ -45,14 +45,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -102,7 +102,7 @@ class TestSingleQueryEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_single_engine(self, async_client): async def test_query_single_engine(self, async_client):
mock_result = _make_result(EngineType.CHATGPT, has_brand=True) mock_result = _make_result(EngineType.CHATGPT, has_brand=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service: with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock() mock_service = AsyncMock()
mock_service.query_single.return_value = mock_result mock_service.query_single.return_value = mock_result
mock_get_service.return_value = mock_service mock_get_service.return_value = mock_service
@ -129,7 +129,7 @@ class TestBatchQueryEndpoint:
async def test_query_batch_parallel(self, async_client): async def test_query_batch_parallel(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True) r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True) r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service: with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock() mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2] mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={ mock_service.calculate_citation_rate = MagicMock(return_value={
@ -164,7 +164,7 @@ class TestGetResultsEndpoint:
async def test_get_results(self, async_client): async def test_get_results(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True) r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.KIMI, has_brand=False) r2 = _make_result(EngineType.KIMI, has_brand=False)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service: with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock() mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2] mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={ mock_service.calculate_citation_rate = MagicMock(return_value={

View File

@ -14,6 +14,7 @@ from app.models.brand import Brand
from app.models.alert_setting import AlertSetting from app.models.alert_setting import AlertSetting
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ==================== # ==================== Fixtures ====================
@ -50,14 +51,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""创建测试用户""" """创建测试用户"""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌""" """创建测试品牌"""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -91,7 +92,7 @@ async def test_alert_setting(async_session, test_user, test_brand):
setting = AlertSetting( setting = AlertSetting(
id=uuid.uuid4(), id=uuid.uuid4(),
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
alert_type="score_drop", alert_type="score_drop",
enabled=True, enabled=True,
threshold=5.0, threshold=5.0,

View File

@ -44,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()

View File

@ -15,6 +15,7 @@ from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User from app.models.user import User
from app.services.auth import hash_password from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
def _make_user( def _make_user(

View File

@ -55,7 +55,7 @@ async def test_register_duplicate_email(async_client):
) )
assert response.status_code == 400 assert response.status_code == 400
data = response.json() data = response.json()
assert "Email already registered" in data["detail"] assert "注册失败" in data["detail"] or "已被使用" in data["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -81,7 +81,7 @@ async def test_login_wrong_password(async_client):
) )
assert response.status_code == 401 assert response.status_code == 401
data = response.json() data = response.json()
assert "Incorrect email or password" in data["detail"] assert "邮箱或密码错误" in data["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -48,14 +48,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""创建测试用户""" """创建测试用户"""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -166,14 +166,14 @@ class TestAuthAPI:
# 先创建用户 # 先创建用户
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email=email, email=email,
password_hash=hash_password(password), password=hash_password(password),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -209,14 +209,14 @@ class TestAuthAPI:
# 创建用户 # 创建用户
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email=email, email=email,
password_hash=hash_password("Correct@123456"), password=hash_password("Correct@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()

View File

@ -12,6 +12,7 @@ from app.models.user import User
from app.models.brand import Brand from app.models.brand import Brand
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand.""" """Create a test brand."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -195,7 +196,7 @@ class TestBrandsAPI:
# Create multiple brands # Create multiple brands
for i in range(3): for i in range(3):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name=f"Brand {i}", name=f"Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],
) )

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand from app.models.brand import Brand
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ==================== # ==================== Fixtures ====================
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""创建测试用户""" """创建测试用户"""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌""" """创建测试品牌"""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -171,7 +172,7 @@ class TestBrandsAPI:
# 创建多个品牌 # 创建多个品牌
for i in range(3): for i in range(3):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name=f"Brand {i}", name=f"Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],
) )

View File

@ -16,6 +16,9 @@ def mock_citation_record():
record.citation_position = 1 record.citation_position = 1
record.citation_text = "Test citation text" record.citation_text = "Test citation text"
record.competitor_brands = [] record.competitor_brands = []
record.match_type = "exact"
record.data_source = "ai_response"
record.ai_response_text = "AI response text"
record.queried_at = datetime.now() record.queried_at = datetime.now()
return record return record

View File

@ -13,6 +13,7 @@ from app.models.brand import Brand
from app.models.competitor import Competitor from app.models.competitor import Competitor
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -50,14 +51,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand.""" """Create a test brand."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand from app.models.brand import Brand
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ==================== # ==================== Fixtures ====================
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""创建测试用户""" """创建测试用户"""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
organization_id=uuid.uuid4(), # 需要organization_id用于内容API organization_id=uuid.uuid4(), # 需要organization_id用于内容API
) )
async_session.add(user) async_session.add(user)
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌""" """创建测试品牌"""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",

View File

@ -12,6 +12,7 @@ from app.main import app
from app.models.brand import Brand from app.models.brand import Brand
from app.models.user import User from app.models.user import User
from app.services.auth import hash_password from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -43,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -62,7 +63,7 @@ async def test_user(async_session):
async def test_brand(async_session, test_user): async def test_brand(async_session, test_user):
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand from app.models.brand import Brand
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -47,14 +48,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""创建测试用户""" """创建测试用户"""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -67,7 +68,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌""" """创建测试品牌"""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -123,17 +124,11 @@ class TestDiagnosisAPI:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_geo_diagnosis_success(self, async_client, test_brand): async def test_geo_diagnosis_success(self, async_client, test_brand):
"""测试GEO诊断端点成功返回""" """测试GEO诊断端点成功返回"""
response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}") response = await async_client.post(f"/api/v1/diagnosis/geo/{test_brand.id}")
assert response.status_code == 200 assert response.status_code == 202
data = response.json() data = response.json()
assert "overall_score" in data assert "task_id" in data or "status" in data
assert "health_level" in data
assert "dimensions" in data
assert "recommendations" in data
assert isinstance(data["overall_score"], (int, float))
assert isinstance(data["dimensions"], list)
assert isinstance(data["recommendations"], list)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_combined_diagnosis_success(self, async_client, test_brand): async def test_combined_diagnosis_success(self, async_client, test_brand):
@ -159,7 +154,7 @@ class TestDiagnosisAPI:
seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}") seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
assert seo_response.status_code == 404 assert seo_response.status_code == 404
geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}") geo_response = await async_client.post(f"/api/v1/diagnosis/geo/{non_existent_id}")
assert geo_response.status_code == 404 assert geo_response.status_code == 404
combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}") combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
@ -180,7 +175,7 @@ class TestDiagnosisAPI:
seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers) seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
assert seo_response.status_code == 401 assert seo_response.status_code == 401
geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers) geo_response = await client.post(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
assert geo_response.status_code == 401 assert geo_response.status_code == 401
combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers) combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)

View File

@ -16,6 +16,7 @@ from app.models.subscription import Subscription
from app.models.user import User from app.models.user import User
from app.services.email.email_scheduler import EmailScheduler from app.services.email.email_scheduler import EmailScheduler
from app.services.email_service import EmailService, EmailMessage from app.services.email_service import EmailService, EmailMessage
from tests.fixtures.auth import _to_uuid
TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "app" / "templates" TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "app" / "templates"

View File

@ -44,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()

View File

@ -20,8 +20,8 @@ class TestLifecycleExceptionHandling:
user = User( user = User(
id=user_id, id=user_id,
email="test@example.com", email="test@example.com",
password_hash="hash", password="hash",
name="Test User", firstName="Test User",
plan="free", plan="free",
organization_id=org_id, organization_id=org_id,
) )
@ -69,8 +69,8 @@ class TestLifecycleExceptionHandling:
user = User( user = User(
id=user_id, id=user_id,
email="test@example.com", email="test@example.com",
password_hash="hash", password="hash",
name="Test User", firstName="Test User",
plan="free", plan="free",
organization_id=org_id, organization_id=org_id,
) )

View File

@ -12,6 +12,7 @@ from app.main import app
from app.models.user import User from app.models.user import User
from app.models.organization import Organization, OrgMember from app.models.organization import Organization, OrgMember
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -43,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -75,7 +76,7 @@ async def test_organization(async_session, test_user):
membership = OrgMember( membership = OrgMember(
id=uuid.uuid4(), id=uuid.uuid4(),
organization_id=org.id, organization_id=org.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
role="owner", role="owner",
) )
async_session.add(membership) async_session.add(membership)
@ -167,14 +168,14 @@ class TestOrganizationRoutes:
async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session): async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/invite 端点存在""" """验证 /api/v1/organization/members/invite 端点存在"""
invite_user = User( invite_user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="newuser@example.com", email="newuser@example.com",
password_hash="hashed_password", password="hashed_password",
name="New User", firstName="New User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(invite_user) async_session.add(invite_user)
await async_session.commit() await async_session.commit()
@ -189,14 +190,14 @@ class TestOrganizationRoutes:
async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user): async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user):
"""验证 /api/v1/organization/members/{id}/role 端点存在""" """验证 /api/v1/organization/members/{id}/role 端点存在"""
new_user = User( new_user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="member@example.com", email="member@example.com",
password_hash="hashed_password", password="hashed_password",
name="Member User", firstName="Member User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(new_user) async_session.add(new_user)
await async_session.flush() await async_session.flush()
@ -220,14 +221,14 @@ class TestOrganizationRoutes:
async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session): async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/{id} 端点存在""" """验证 /api/v1/organization/members/{id} 端点存在"""
new_user = User( new_user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="todelete@example.com", email="todelete@example.com",
password_hash="hashed_password", password="hashed_password",
name="Delete User", firstName="Delete User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(new_user) async_session.add(new_user)
await async_session.flush() await async_session.flush()

View File

@ -15,6 +15,7 @@ from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -52,14 +53,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -72,7 +73,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand.""" """Create a test brand."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="TestBrand", name="TestBrand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -92,7 +93,7 @@ async def test_query(async_session, test_user, test_brand):
"""Create a test query with citation records.""" """Create a test query with citation records."""
query = QueryModel( query = QueryModel(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword="AI assistant", keyword="AI assistant",
target_brand="TestBrand", target_brand="TestBrand",
brand_aliases=["TestBrand"], brand_aliases=["TestBrand"],

View File

@ -16,6 +16,7 @@ from app.models.query import Query
from app.models.citation_record import CitationRecord from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -53,14 +54,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -73,7 +74,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand.""" """Create a test brand."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="TestBrand", name="TestBrand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -93,7 +94,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
"""Create a test brand with query and citation data.""" """Create a test brand with query and citation data."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="TestBrand", name="TestBrand",
aliases=["TestAlias"], aliases=["TestAlias"],
website="https://test.com", website="https://test.com",
@ -118,7 +119,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
# Create a query # Create a query
query = Query( query = Query(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword="AI assistant", keyword="AI assistant",
target_brand="TestBrand", target_brand="TestBrand",
brand_aliases=["TestAlias"], brand_aliases=["TestAlias"],

View File

@ -40,12 +40,13 @@ class TestPrometheusMetrics:
# 从注册表获取指标值 # 从注册表获取指标值
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total") # prometheus_client strips _total suffix from Counter names in collect()
api_requests = metrics.get("geo_api_requests")
assert api_requests is not None assert api_requests is not None
# 验证指标包含正确的标签 # 验证指标包含正确的标签
samples = list(api_requests.collect())[0].samples samples = api_requests.samples
sample = next((s for s in samples if s.labels.get("method") == "GET" and s.labels.get("endpoint") == "/test"), None) sample = next((s for s in samples if s.labels.get("method") == "GET" and s.labels.get("endpoint") == "/test"), None)
assert sample is not None assert sample is not None
assert sample.value >= 1 assert sample.value >= 1
@ -55,7 +56,7 @@ class TestPrometheusMetrics:
AGENT_EXECUTIONS_TOTAL.labels(agent_name="test_agent", status="success").inc() AGENT_EXECUTIONS_TOTAL.labels(agent_name="test_agent", status="success").inc()
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
agent_executions = metrics.get("geo_agent_executions_total") agent_executions = metrics.get("geo_agent_executions")
assert agent_executions is not None assert agent_executions is not None
@ -65,7 +66,7 @@ class TestPrometheusMetrics:
LLM_TOKENS_TOTAL.labels(provider="openai", model="gpt-4", token_type="completion").inc(50) LLM_TOKENS_TOTAL.labels(provider="openai", model="gpt-4", token_type="completion").inc(50)
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
llm_tokens = metrics.get("geo_llm_tokens_total") llm_tokens = metrics.get("geo_llm_tokens")
assert llm_tokens is not None assert llm_tokens is not None
@ -94,7 +95,7 @@ class TestPrometheusMetrics:
QUERY_COUNT_TOTAL.labels(platform="kimi", status="failed").inc() QUERY_COUNT_TOTAL.labels(platform="kimi", status="failed").inc()
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
query_count = metrics.get("geo_queries_total") query_count = metrics.get("geo_queries")
assert query_count is not None assert query_count is not None
@ -103,7 +104,7 @@ class TestPrometheusMetrics:
CONTENT_GENERATED_TOTAL.inc() CONTENT_GENERATED_TOTAL.inc()
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
content_count = metrics.get("geo_content_generated_total") content_count = metrics.get("geo_content_generated")
assert content_count is not None assert content_count is not None
@ -113,7 +114,7 @@ class TestPrometheusMetrics:
CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc() CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc()
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
citation_count = metrics.get("geo_citations_detected_total") citation_count = metrics.get("geo_citations_detected")
assert citation_count is not None assert citation_count is not None
@ -211,15 +212,15 @@ class TestMetricsCollection:
"""测试注册表收集所有指标""" """测试注册表收集所有指标"""
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
# 验证关键指标存在 # 验证关键指标存在 (prometheus_client strips _total from Counter names)
assert "geo_api_requests_total" in metrics assert "geo_api_requests" in metrics
assert "geo_agent_executions_total" in metrics assert "geo_agent_executions" in metrics
assert "geo_llm_tokens_total" in metrics assert "geo_llm_tokens" in metrics
assert "geo_llm_cost_estimated" in metrics assert "geo_llm_cost_estimated" in metrics
assert "geo_brands_total" in metrics assert "geo_brands_total" in metrics
assert "geo_queries_total" in metrics assert "geo_queries" in metrics
assert "geo_content_generated_total" in metrics assert "geo_content_generated" in metrics
assert "geo_citations_detected_total" in metrics assert "geo_citations_detected" in metrics
def test_metric_labels_are_valid(self): def test_metric_labels_are_valid(self):
"""测试指标标签有效性""" """测试指标标签有效性"""
@ -253,8 +254,9 @@ class TestMetricsHistory:
# 获取初始值 # 获取初始值
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total") # prometheus_client strips _total suffix from Counter names in collect()
for sample in api_requests.collect()[0].samples: api_requests = metrics.get("geo_api_requests")
for sample in api_requests.samples:
if sample.labels.get("endpoint") == test_endpoint: if sample.labels.get("endpoint") == test_endpoint:
initial_count = sample.value initial_count = sample.value
break break
@ -266,8 +268,9 @@ class TestMetricsHistory:
# 验证增加 # 验证增加
metrics = {m.name: m for m in REGISTRY.collect()} metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total") # prometheus_client strips _total suffix from Counter names in collect()
for sample in api_requests.collect()[0].samples: api_requests = metrics.get("geo_api_requests")
for sample in api_requests.samples:
if sample.labels.get("endpoint") == test_endpoint: if sample.labels.get("endpoint") == test_endpoint:
if initial_count is not None: if initial_count is not None:
assert sample.value >= initial_count + 3 assert sample.value >= initial_count + 3
@ -285,7 +288,7 @@ class TestMetricsHistory:
llm_cost = metrics.get("geo_llm_cost_estimated") llm_cost = metrics.get("geo_llm_cost_estimated")
found = False found = False
for sample in llm_cost.collect()[0].samples: for sample in llm_cost.samples:
if sample.labels.get("provider") == "test": if sample.labels.get("provider") == "test":
assert sample.value == test_value assert sample.value == test_value
found = True found = True

View File

@ -18,6 +18,7 @@ from app.models.competitor import Competitor
from app.models.suggestion import Suggestion from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, hash_password from app.services.auth import create_access_token, hash_password
from tests.fixtures.auth import _to_uuid
# Only the tables needed for performance tests (avoids JSONB/SQLite incompatibility) # Only the tables needed for performance tests (avoids JSONB/SQLite incompatibility)
_TEST_TABLES = ( _TEST_TABLES = (
@ -72,14 +73,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user with properly hashed password.""" """Create a test user with properly hashed password."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="perf_test@example.com", email="perf_test@example.com",
password_hash=hash_password("PerfTest123!"), password=hash_password("PerfTest123!"),
name="Performance Test User", firstName="Performance Test User",
plan="free", plan="free",
max_queries=50, max_queries=50,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -155,7 +156,7 @@ class TestAPIPerformance:
# Create several brands for a more realistic test # Create several brands for a more realistic test
for i in range(10): for i in range(10):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name=f"Brand {i}", name=f"Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],
status="active", status="active",
@ -176,7 +177,7 @@ class TestAPIPerformance:
# Create several queries for a more realistic test # Create several queries for a more realistic test
for i in range(10): for i in range(10):
query = Query( query = Query(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword=f"query keyword {i}", keyword=f"query keyword {i}",
target_brand=f"Brand {i}", target_brand=f"Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],
@ -247,7 +248,7 @@ class TestConcurrency:
# Pre-create data # Pre-create data
for i in range(5): for i in range(5):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name=f"Concurrent Brand {i}", name=f"Concurrent Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],
status="active", status="active",
@ -277,7 +278,7 @@ class TestConcurrency:
"""Concurrent query list reads should all succeed.""" """Concurrent query list reads should all succeed."""
for i in range(5): for i in range(5):
query = Query( query = Query(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword=f"concurrent query {i}", keyword=f"concurrent query {i}",
target_brand=f"Brand {i}", target_brand=f"Brand {i}",
platforms=["wenxin"], platforms=["wenxin"],

View File

@ -19,6 +19,7 @@ from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, create_refresh_token, hash_password from app.services.auth import create_access_token, create_refresh_token, hash_password
from app.config import settings from app.config import settings
from tests.fixtures.auth import _to_uuid
# Only the tables needed for security tests (avoids JSONB/SQLite incompatibility) # Only the tables needed for security tests (avoids JSONB/SQLite incompatibility)
_TEST_TABLES = ( _TEST_TABLES = (
@ -73,14 +74,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user with properly hashed password.""" """Create a test user with properly hashed password."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="security_test@example.com", email="security_test@example.com",
password_hash=hash_password("SecurePass123!"), password=hash_password("SecurePass123!"),
name="Security Test User", firstName="Security Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -92,14 +93,14 @@ async def test_user(async_session):
async def second_user(async_session): async def second_user(async_session):
"""Create a second test user for cross-user isolation tests.""" """Create a second test user for cross-user isolation tests."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="second_user@example.com", email="second_user@example.com",
password_hash=hash_password("SecondPass456!"), password=hash_password("SecondPass456!"),
name="Second User", firstName="Second User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -376,7 +377,7 @@ class TestXSSProtection:
"""XSS payloads in brand aliases should be stored as plain text.""" """XSS payloads in brand aliases should be stored as plain text."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Safe Brand", name="Safe Brand",
platforms=["wenxin"], platforms=["wenxin"],
status="active", status="active",
@ -555,7 +556,7 @@ class TestAuthSecurity:
# Create a brand for second_user # Create a brand for second_user
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=second_user.id, user_id=_to_uuid(second_user.id),
name="Second User's Brand", name="Second User's Brand",
platforms=["wenxin"], platforms=["wenxin"],
status="active", status="active",

View File

@ -16,6 +16,7 @@ from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -53,14 +54,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -134,7 +135,7 @@ class TestFullBrandQueryFlow:
# Step 3: Create a query (using Query model directly) # Step 3: Create a query (using Query model directly)
query = QueryModel( query = QueryModel(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword="AI assistant", keyword="AI assistant",
target_brand="TestBrand", target_brand="TestBrand",
brand_aliases=["TestBrand", "TB"], brand_aliases=["TestBrand", "TB"],
@ -250,7 +251,7 @@ class TestCSVExportFlow:
# Step 2: Create a query with citations # Step 2: Create a query with citations
query = QueryModel( query = QueryModel(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
keyword="export test keyword", keyword="export test keyword",
target_brand="ExportTestBrand", target_brand="ExportTestBrand",
brand_aliases=["ETB"], brand_aliases=["ETB"],

View File

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import select, and_ from sqlalchemy import select, and_
from app.models.api_key import APIKey from app.models.api_key import APIKey
from tests.fixtures.auth import _to_uuid
class TestAPIKeyModel: class TestAPIKeyModel:
@ -16,7 +17,7 @@ class TestAPIKeyModel:
"""Test creating a new API key.""" """Test creating a new API key."""
api_key = APIKey( api_key = APIKey(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
encrypted_key="encrypted_test_key", encrypted_key="encrypted_test_key",
key_hint="sk-...abc", key_hint="sk-...abc",
@ -29,7 +30,7 @@ class TestAPIKeyModel:
await async_session.refresh(api_key) await async_session.refresh(api_key)
assert api_key.id is not None assert api_key.id is not None
assert api_key.user_id == test_user.id assert api_key.user_id == _to_uuid(test_user.id)
assert api_key.engine_type == "chatgpt" assert api_key.engine_type == "chatgpt"
assert api_key.encrypted_key == "encrypted_test_key" assert api_key.encrypted_key == "encrypted_test_key"
assert api_key.key_hint == "sk-...abc" assert api_key.key_hint == "sk-...abc"
@ -43,7 +44,7 @@ class TestAPIKeyModel:
async def test_api_key_default_values(self, async_session, test_user): async def test_api_key_default_values(self, async_session, test_user):
"""Test API key default values.""" """Test API key default values."""
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="kimi", engine_type="kimi",
encrypted_key="encrypted_kimi_key", encrypted_key="encrypted_kimi_key",
key_hint="sk-...xyz", key_hint="sk-...xyz",
@ -62,7 +63,7 @@ class TestAPIKeyModel:
"""Test API key field validation and constraints.""" """Test API key field validation and constraints."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="deepseek", engine_type="deepseek",
encrypted_key="encrypted_deepseek_key_data", encrypted_key="encrypted_deepseek_key_data",
key_hint="sk-...def", key_hint="sk-...def",
@ -90,7 +91,7 @@ class TestAPIKeyModel:
key_id = uuid.uuid4() key_id = uuid.uuid4()
api_key = APIKey( api_key = APIKey(
id=key_id, id=key_id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="gemini", engine_type="gemini",
encrypted_key="encrypted_gemini_key", encrypted_key="encrypted_gemini_key",
key_hint="AIza...123", key_hint="AIza...123",
@ -111,13 +112,13 @@ class TestAPIKeyModel:
async def test_api_key_query_by_user_id(self, async_session, test_user): async def test_api_key_query_by_user_id(self, async_session, test_user):
"""Test querying API keys by user ID.""" """Test querying API keys by user ID."""
key1 = APIKey( key1 = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
encrypted_key="encrypted_key_1", encrypted_key="encrypted_key_1",
key_hint="sk-...1", key_hint="sk-...1",
) )
key2 = APIKey( key2 = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="kimi", engine_type="kimi",
encrypted_key="encrypted_key_2", encrypted_key="encrypted_key_2",
key_hint="sk-...2", key_hint="sk-...2",
@ -127,7 +128,7 @@ class TestAPIKeyModel:
await async_session.commit() await async_session.commit()
result = await async_session.execute( result = await async_session.execute(
select(APIKey).where(APIKey.user_id == test_user.id) select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
) )
keys = result.scalars().all() keys = result.scalars().all()
@ -137,13 +138,13 @@ class TestAPIKeyModel:
async def test_api_key_query_by_user_and_engine(self, async_session, test_user): async def test_api_key_query_by_user_and_engine(self, async_session, test_user):
"""Test querying API keys by user ID and engine type.""" """Test querying API keys by user ID and engine type."""
key1 = APIKey( key1 = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
encrypted_key="encrypted_chatgpt_key", encrypted_key="encrypted_chatgpt_key",
key_hint="sk-...chat", key_hint="sk-...chat",
) )
key2 = APIKey( key2 = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="kimi", engine_type="kimi",
encrypted_key="encrypted_kimi_key", encrypted_key="encrypted_kimi_key",
key_hint="sk-...kimi", key_hint="sk-...kimi",
@ -155,7 +156,7 @@ class TestAPIKeyModel:
result = await async_session.execute( result = await async_session.execute(
select(APIKey).where( select(APIKey).where(
and_( and_(
APIKey.user_id == test_user.id, APIKey.user_id == _to_uuid(test_user.id),
APIKey.engine_type == "chatgpt" APIKey.engine_type == "chatgpt"
) )
) )
@ -169,7 +170,7 @@ class TestAPIKeyModel:
async def test_api_key_timestamps(self, async_session, test_user): async def test_api_key_timestamps(self, async_session, test_user):
"""Test API key created_at and updated_at timestamps.""" """Test API key created_at and updated_at timestamps."""
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="qwen", engine_type="qwen",
encrypted_key="encrypted_qwen_key", encrypted_key="encrypted_qwen_key",
key_hint="sk-...qwen", key_hint="sk-...qwen",
@ -187,7 +188,7 @@ class TestAPIKeyModel:
async def test_api_key_update(self, async_session, test_user): async def test_api_key_update(self, async_session, test_user):
"""Test updating API key fields.""" """Test updating API key fields."""
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="wenxin", engine_type="wenxin",
encrypted_key="encrypted_wenxin_key", encrypted_key="encrypted_wenxin_key",
key_hint="sk-...wenxin", key_hint="sk-...wenxin",
@ -211,7 +212,7 @@ class TestAPIKeyModel:
async def test_api_key_delete(self, async_session, test_user): async def test_api_key_delete(self, async_session, test_user):
"""Test deleting an API key.""" """Test deleting an API key."""
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="doubao", engine_type="doubao",
encrypted_key="encrypted_doubao_key", encrypted_key="encrypted_doubao_key",
key_hint="sk-...doubao", key_hint="sk-...doubao",
@ -238,7 +239,7 @@ class TestAPIKeyModel:
for i, status in enumerate(statuses): for i, status in enumerate(statuses):
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type=f"engine_{i}", engine_type=f"engine_{i}",
encrypted_key=f"encrypted_key_{i}", encrypted_key=f"encrypted_key_{i}",
key_hint=f"sk-...{i}", key_hint=f"sk-...{i}",
@ -263,7 +264,7 @@ class TestAPIKeyModel:
for data in keys_data: for data in keys_data:
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type=data["engine_type"], engine_type=data["engine_type"],
encrypted_key=f"encrypted_{data['engine_type']}", encrypted_key=f"encrypted_{data['engine_type']}",
key_hint=data["key_hint"], key_hint=data["key_hint"],
@ -275,7 +276,7 @@ class TestAPIKeyModel:
result = await async_session.execute( result = await async_session.execute(
select(APIKey) select(APIKey)
.where(APIKey.user_id == test_user.id) .where(APIKey.user_id == _to_uuid(test_user.id))
.order_by(APIKey.priority.desc()) .order_by(APIKey.priority.desc())
) )
keys = result.scalars().all() keys = result.scalars().all()
@ -289,7 +290,7 @@ class TestAPIKeyModel:
"""Test that user_id field has an index.""" """Test that user_id field has an index."""
for i in range(5): for i in range(5):
api_key = APIKey( api_key = APIKey(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type=f"engine_{i}", engine_type=f"engine_{i}",
encrypted_key=f"encrypted_key_{i}", encrypted_key=f"encrypted_key_{i}",
key_hint=f"hint_{i}", key_hint=f"hint_{i}",
@ -299,7 +300,7 @@ class TestAPIKeyModel:
await async_session.commit() await async_session.commit()
result = await async_session.execute( result = await async_session.execute(
select(APIKey).where(APIKey.user_id == test_user.id) select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
) )
keys = result.scalars().all() keys = result.scalars().all()

View File

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import select from sqlalchemy import select
from app.models.brand import Brand from app.models.brand import Brand
from tests.fixtures.auth import _to_uuid
class TestBrandModel: class TestBrandModel:
@ -16,7 +17,7 @@ class TestBrandModel:
"""Test creating a new brand.""" """Test creating a new brand."""
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -30,7 +31,7 @@ class TestBrandModel:
await async_session.refresh(brand) await async_session.refresh(brand)
assert brand.id is not None assert brand.id is not None
assert brand.user_id == test_user.id assert brand.user_id == _to_uuid(test_user.id)
assert brand.name == "Test Brand" assert brand.name == "Test Brand"
assert brand.aliases == ["TestBrand", "TB"] assert brand.aliases == ["TestBrand", "TB"]
assert brand.website == "https://testbrand.com" assert brand.website == "https://testbrand.com"
@ -45,7 +46,7 @@ class TestBrandModel:
async def test_brand_default_values(self, async_session, test_user): async def test_brand_default_values(self, async_session, test_user):
"""Test brand default values.""" """Test brand default values."""
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Default Brand", name="Default Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -59,13 +60,12 @@ class TestBrandModel:
assert brand.frequency == "weekly" assert brand.frequency == "weekly"
assert brand.status == "active" assert brand.status == "active"
assert brand.last_queried_at is None assert brand.last_queried_at is None
assert brand.next_query_at is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_brand_fields(self, async_session, test_user): async def test_brand_fields(self, async_session, test_user):
"""Test brand field validation and constraints.""" """Test brand field validation and constraints."""
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Field Test Brand", name="Field Test Brand",
aliases=["FTA", "FieldTest"], aliases=["FTA", "FieldTest"],
website="https://fieldtest.com", website="https://fieldtest.com",
@ -80,7 +80,6 @@ class TestBrandModel:
await async_session.commit() await async_session.commit()
await async_session.refresh(brand) await async_session.refresh(brand)
# Verify all fields
assert brand.name == "Field Test Brand" assert brand.name == "Field Test Brand"
assert len(brand.name) == 16 assert len(brand.name) == 16
assert brand.aliases == ["FTA", "FieldTest"] assert brand.aliases == ["FTA", "FieldTest"]
@ -91,7 +90,6 @@ class TestBrandModel:
assert brand.frequency == "daily" assert brand.frequency == "daily"
assert brand.status == "active" assert brand.status == "active"
assert brand.last_queried_at is not None assert brand.last_queried_at is not None
assert brand.next_query_at is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_brand_query_by_id(self, async_session, test_user): async def test_brand_query_by_id(self, async_session, test_user):
@ -99,7 +97,7 @@ class TestBrandModel:
brand_id = uuid.uuid4() brand_id = uuid.uuid4()
brand = Brand( brand = Brand(
id=brand_id, id=brand_id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Query Test Brand", name="Query Test Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -118,14 +116,14 @@ class TestBrandModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_brand_query_by_user_id(self, async_session, test_user): async def test_brand_query_by_user_id(self, async_session, test_user):
"""Test querying brands by user ID.""" """Test querying brands by user ID."""
# Create multiple brands for the same user uid = _to_uuid(test_user.id)
brand1 = Brand( brand1 = Brand(
user_id=test_user.id, user_id=uid,
name="User Brand 1", name="User Brand 1",
platforms=["wenxin"], platforms=["wenxin"],
) )
brand2 = Brand( brand2 = Brand(
user_id=test_user.id, user_id=uid,
name="User Brand 2", name="User Brand 2",
platforms=["kimi"], platforms=["kimi"],
) )
@ -134,7 +132,7 @@ class TestBrandModel:
await async_session.commit() await async_session.commit()
result = await async_session.execute( result = await async_session.execute(
select(Brand).where(Brand.user_id == test_user.id) select(Brand).where(Brand.user_id == uid)
) )
brands = result.scalars().all() brands = result.scalars().all()
@ -144,7 +142,7 @@ class TestBrandModel:
async def test_brand_timestamps(self, async_session, test_user): async def test_brand_timestamps(self, async_session, test_user):
"""Test brand created_at and updated_at timestamps.""" """Test brand created_at and updated_at timestamps."""
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Timestamp Brand", name="Timestamp Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -161,7 +159,7 @@ class TestBrandModel:
async def test_brand_update(self, async_session, test_user): async def test_brand_update(self, async_session, test_user):
"""Test updating brand fields.""" """Test updating brand fields."""
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Update Test Brand", name="Update Test Brand",
platforms=["wenxin"], platforms=["wenxin"],
frequency="weekly", frequency="weekly",
@ -169,7 +167,6 @@ class TestBrandModel:
async_session.add(brand) async_session.add(brand)
await async_session.commit() await async_session.commit()
# Update brand
brand.name = "Updated Brand Name" brand.name = "Updated Brand Name"
brand.frequency = "daily" brand.frequency = "daily"
brand.aliases = ["Updated", "Alias"] brand.aliases = ["Updated", "Alias"]
@ -184,7 +181,7 @@ class TestBrandModel:
async def test_brand_delete(self, async_session, test_user): async def test_brand_delete(self, async_session, test_user):
"""Test deleting a brand.""" """Test deleting a brand."""
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Delete Test Brand", name="Delete Test Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from app.models.brand import Brand from app.models.brand import Brand
from app.models.competitor import Competitor from app.models.competitor import Competitor
from tests.fixtures.auth import _to_uuid
class TestCompetitorModel: class TestCompetitorModel:
@ -18,7 +19,7 @@ class TestCompetitorModel:
# First create a brand # First create a brand
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand for Competitor", name="Test Brand for Competitor",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -48,7 +49,7 @@ class TestCompetitorModel:
"""Test competitor default values.""" """Test competitor default values."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Default Competitor", name="Brand for Default Competitor",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -72,7 +73,7 @@ class TestCompetitorModel:
"""Test competitor field validation.""" """Test competitor field validation."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Field Test", name="Brand for Field Test",
platforms=["wenxin", "kimi"], platforms=["wenxin", "kimi"],
) )
@ -98,7 +99,7 @@ class TestCompetitorModel:
"""Test querying competitor by ID.""" """Test querying competitor by ID."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Query Test", name="Brand for Query Test",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -129,7 +130,7 @@ class TestCompetitorModel:
"""Test querying competitors by brand ID.""" """Test querying competitors by brand ID."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Multi Competitor Test", name="Brand for Multi Competitor Test",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -164,7 +165,7 @@ class TestCompetitorModel:
"""Test competitor created_at timestamp.""" """Test competitor created_at timestamp."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Timestamp Test", name="Brand for Timestamp Test",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -188,7 +189,7 @@ class TestCompetitorModel:
"""Test updating competitor fields.""" """Test updating competitor fields."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Update Test", name="Brand for Update Test",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -218,7 +219,7 @@ class TestCompetitorModel:
"""Test deleting a competitor.""" """Test deleting a competitor."""
# Create brand first # Create brand first
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Delete Test", name="Brand for Delete Test",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -249,7 +250,7 @@ class TestCompetitorModel:
"""Test that competitors are deleted when brand is deleted.""" """Test that competitors are deleted when brand is deleted."""
# Create brand with competitors # Create brand with competitors
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Brand for Cascade Test", name="Brand for Cascade Test",
platforms=["wenxin"], platforms=["wenxin"],
) )

View File

@ -6,6 +6,7 @@ from sqlalchemy import select
from app.models.organization import Organization, OrgMember from app.models.organization import Organization, OrgMember
from app.models.user import User from app.models.user import User
from tests.fixtures.auth import _to_uuid
class TestOrganizationModel: class TestOrganizationModel:

View File

@ -29,7 +29,7 @@ class TestSubscriptionModel:
def test_subscription_field_types(self): def test_subscription_field_types(self):
columns = Subscription.__table__.columns columns = Subscription.__table__.columns
assert "UUID" in str(columns["id"].type).upper() assert "UUID" in str(columns["id"].type).upper()
assert "UUID" in str(columns["user_id"].type).upper() assert "VARCHAR" in str(columns["user_id"].type).upper() or "STRING" in str(columns["user_id"].type).upper()
assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper() assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper()
assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper() assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper()
assert "DATE" in str(columns["start_date"].type).upper() assert "DATE" in str(columns["start_date"].type).upper()

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from app.models.brand import Brand from app.models.brand import Brand
from app.models.suggestion import Suggestion from app.models.suggestion import Suggestion
from app.models.user import User from app.models.user import User
from tests.fixtures.auth import _to_uuid
class TestSuggestionModel: class TestSuggestionModel:
@ -81,7 +82,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_suggestion_create(self, async_session, test_user): async def test_suggestion_create(self, async_session, test_user):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Suggestion Test Brand", name="Suggestion Test Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )
@ -124,7 +125,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_suggestion_default_values(self, async_session, test_user): async def test_suggestion_default_values(self, async_session, test_user):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Default Suggestion Brand", name="Default Suggestion Brand",
platforms=["kimi"], platforms=["kimi"],
) )
@ -152,7 +153,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_suggestion_query_by_brand(self, async_session, test_user): async def test_suggestion_query_by_brand(self, async_session, test_user):
brand = Brand( brand = Brand(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Query Suggestion Brand", name="Query Suggestion Brand",
platforms=["wenxin"], platforms=["wenxin"],
) )

View File

@ -7,6 +7,7 @@ from sqlalchemy import select, func, and_
from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import insert as pg_insert
from app.models.usage_record import UsageRecord from app.models.usage_record import UsageRecord
from tests.fixtures.auth import _to_uuid
class TestUsageRecordModel: class TestUsageRecordModel:
@ -16,7 +17,7 @@ class TestUsageRecordModel:
async def test_usage_record_create(self, async_session, test_user): async def test_usage_record_create(self, async_session, test_user):
"""Test creating a new usage record.""" """Test creating a new usage record."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
query="What is SEO optimization?", query="What is SEO optimization?",
input_tokens=100, input_tokens=100,
@ -29,7 +30,7 @@ class TestUsageRecordModel:
await async_session.refresh(record) await async_session.refresh(record)
assert record.id is not None assert record.id is not None
assert record.user_id == test_user.id assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "chatgpt" assert record.engine_type == "chatgpt"
assert record.query == "What is SEO optimization?" assert record.query == "What is SEO optimization?"
assert record.input_tokens == 100 assert record.input_tokens == 100
@ -43,7 +44,7 @@ class TestUsageRecordModel:
async def test_usage_record_default_values(self, async_session, test_user): async def test_usage_record_default_values(self, async_session, test_user):
"""Test usage record default values.""" """Test usage record default values."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="kimi", engine_type="kimi",
query="Test query", query="Test query",
) )
@ -60,7 +61,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_query_by_user_id(self, async_session, test_user): async def test_usage_record_query_by_user_id(self, async_session, test_user):
"""Test querying usage records by user ID.""" """Test querying usage records by user ID."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
for i in range(3): for i in range(3):
record = UsageRecord( record = UsageRecord(
user_id=user_id, user_id=user_id,
@ -81,7 +82,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_query_by_user_and_engine(self, async_session, test_user): async def test_usage_record_query_by_user_and_engine(self, async_session, test_user):
"""Test querying usage records by user ID and engine type.""" """Test querying usage records by user ID and engine type."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
record1 = UsageRecord( record1 = UsageRecord(
user_id=user_id, user_id=user_id,
engine_type="chatgpt", engine_type="chatgpt",
@ -114,7 +115,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_query_by_time_range(self, async_session, test_user): async def test_usage_record_query_by_time_range(self, async_session, test_user):
"""Test querying usage records by time range.""" """Test querying usage records by time range."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
old_record = UsageRecord( old_record = UsageRecord(
@ -152,7 +153,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_aggregate_by_user(self, async_session, test_user): async def test_usage_record_aggregate_by_user(self, async_session, test_user):
"""Test aggregating usage records by user.""" """Test aggregating usage records by user."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
records_data = [ records_data = [
{"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01}, {"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01},
{"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02}, {"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02},
@ -192,7 +193,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_aggregate_by_day(self, async_session, test_user): async def test_usage_record_aggregate_by_day(self, async_session, test_user):
"""Test aggregating usage records by day.""" """Test aggregating usage records by day."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
@ -238,7 +239,7 @@ class TestUsageRecordModel:
"""Test usage record with brand association.""" """Test usage record with brand association."""
brand_id = uuid.uuid4() brand_id = uuid.uuid4()
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
brand_id=brand_id, brand_id=brand_id,
engine_type="wenxin", engine_type="wenxin",
query="Brand query", query="Brand query",
@ -253,7 +254,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_index_user_engine(self, async_session, test_user): async def test_usage_record_index_user_engine(self, async_session, test_user):
"""Test composite index on user_id and engine_type.""" """Test composite index on user_id and engine_type."""
user_id = test_user.id user_id = _to_uuid(test_user.id)
for i in range(5): for i in range(5):
record = UsageRecord( record = UsageRecord(
user_id=user_id, user_id=user_id,
@ -280,7 +281,7 @@ class TestUsageRecordModel:
async def test_usage_record_update(self, async_session, test_user): async def test_usage_record_update(self, async_session, test_user):
"""Test updating usage record fields.""" """Test updating usage record fields."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="xinghuo", engine_type="xinghuo",
query="Original query", query="Original query",
cost=1.0, cost=1.0,
@ -300,7 +301,7 @@ class TestUsageRecordModel:
async def test_usage_record_delete(self, async_session, test_user): async def test_usage_record_delete(self, async_session, test_user):
"""Test deleting a usage record.""" """Test deleting a usage record."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="yuanbao", engine_type="yuanbao",
query="Delete me", query="Delete me",
cost=1.0, cost=1.0,
@ -323,7 +324,7 @@ class TestUsageRecordModel:
async def test_usage_record_timestamps(self, async_session, test_user): async def test_usage_record_timestamps(self, async_session, test_user):
"""Test usage record timestamp fields.""" """Test usage record timestamp fields."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="perplexity", engine_type="perplexity",
query="Timestamp test", query="Timestamp test",
cost=1.0, cost=1.0,
@ -343,7 +344,7 @@ class TestUsageRecordModel:
other_user_id = uuid.uuid4() other_user_id = uuid.uuid4()
user1_record = UsageRecord( user1_record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
query="User 1 query", query="User 1 query",
cost=1.0, cost=1.0,
@ -359,7 +360,7 @@ class TestUsageRecordModel:
await async_session.commit() await async_session.commit()
result_user1 = await async_session.execute( result_user1 = await async_session.execute(
select(UsageRecord).where(UsageRecord.user_id == test_user.id) select(UsageRecord).where(UsageRecord.user_id == _to_uuid(test_user.id))
) )
result_user2 = await async_session.execute( result_user2 = await async_session.execute(
select(UsageRecord).where(UsageRecord.user_id == other_user_id) select(UsageRecord).where(UsageRecord.user_id == other_user_id)
@ -372,7 +373,7 @@ class TestUsageRecordModel:
async def test_usage_record_empty_query_field(self, async_session, test_user): async def test_usage_record_empty_query_field(self, async_session, test_user):
"""Test usage record with empty query field.""" """Test usage record with empty query field."""
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="deepseek", engine_type="deepseek",
query="", query="",
cost=0.0, cost=0.0,
@ -394,7 +395,7 @@ class TestUsageRecordModel:
"nested": {"key": {"deep": {"value": [1, 2, 3]}}}, "nested": {"key": {"deep": {"value": [1, 2, 3]}}},
} }
record = UsageRecord( record = UsageRecord(
user_id=test_user.id, user_id=_to_uuid(test_user.id),
engine_type="chatgpt", engine_type="chatgpt",
query="Large metadata test", query="Large metadata test",
extra_data=large_metadata, extra_data=large_metadata,

View File

@ -12,6 +12,7 @@ from app.models.user import User
from app.models.usage_record import UsageRecord from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository from app.repositories.usage_repository import UsageRepository
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -47,14 +48,14 @@ async def async_session(async_engine):
async def test_user_free(async_session): async def test_user_free(async_session):
"""Create a free plan test user.""" """Create a free plan test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="free@example.com", email="free@example.com",
password_hash="hashed_password", password="hashed_password",
name="Free User", firstName="Free User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -66,14 +67,14 @@ async def test_user_free(async_session):
async def test_user_basic(async_session): async def test_user_basic(async_session):
"""Create a basic plan test user.""" """Create a basic plan test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="basic@example.com", email="basic@example.com",
password_hash="hashed_password", password="hashed_password",
name="Basic User", firstName="Basic User",
plan="basic", plan="basic",
max_queries=50, max_queries=50,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -85,14 +86,14 @@ async def test_user_basic(async_session):
async def test_user_pro(async_session): async def test_user_pro(async_session):
"""Create a pro plan test user.""" """Create a pro plan test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="pro@example.com", email="pro@example.com",
password_hash="hashed_password", password="hashed_password",
name="Pro User", firstName="Pro User",
plan="pro", plan="pro",
max_queries=500, max_queries=500,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -110,7 +111,7 @@ class TestByDayAggregation:
for i in range(3): for i in range(3):
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 0.01, "cost": 0.01,
@ -130,7 +131,7 @@ class TestByDayAggregation:
for i in range(5): for i in range(5):
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "deepseek", "engine_type": "deepseek",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 0.02, "cost": 0.02,
@ -149,7 +150,7 @@ class TestByDayAggregation:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "qwen", "engine_type": "qwen",
"query": "Query 1", "query": "Query 1",
"input_tokens": 100, "input_tokens": 100,
@ -157,7 +158,7 @@ class TestByDayAggregation:
"cost": 0.01, "cost": 0.01,
}) })
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "qwen", "engine_type": "qwen",
"query": "Query 2", "query": "Query 2",
"input_tokens": 150, "input_tokens": 150,
@ -188,14 +189,14 @@ class TestByDayAggregation:
yesterday = datetime.now(timezone.utc) - timedelta(days=1) yesterday = datetime.now(timezone.utc) - timedelta(days=1)
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini", "engine_type": "gemini",
"query": "Yesterday query", "query": "Yesterday query",
"cost": 0.05, "cost": 0.05,
"timestamp": yesterday, "timestamp": yesterday,
}) })
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini", "engine_type": "gemini",
"query": "Today query", "query": "Today query",
"cost": 0.05, "cost": 0.05,
@ -268,7 +269,7 @@ class TestUserQuotaService:
for i in range(5): for i in range(5):
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 1.0, "cost": 1.0,
@ -291,7 +292,7 @@ class TestUserQuotaService:
for i in range(10): for i in range(10):
await repo.create({ await repo.create({
"user_id": test_user_basic.id, "user_id": _to_uuid(test_user_basic.id),
"engine_type": "deepseek", "engine_type": "deepseek",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 2.0, "cost": 2.0,
@ -314,7 +315,7 @@ class TestUserQuotaService:
for i in range(10): for i in range(10):
await repo.create({ await repo.create({
"user_id": test_user_pro.id, "user_id": _to_uuid(test_user_pro.id),
"engine_type": "qwen", "engine_type": "qwen",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 10.0, "cost": 10.0,
@ -336,7 +337,7 @@ class TestUserQuotaService:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user_free.id, "user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini", "engine_type": "gemini",
"query": "Expensive query", "query": "Expensive query",
"cost": 15.0, "cost": 15.0,

View File

@ -11,6 +11,7 @@ from app.database import Base
from app.models.user import User from app.models.user import User
from app.models.usage_record import UsageRecord from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository from app.repositories.usage_repository import UsageRepository
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -46,14 +47,14 @@ async def async_session(async_engine):
async def test_user(async_session): async def test_user(async_session):
"""Create a test user.""" """Create a test user."""
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash="hashed_password", password="hashed_password",
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -70,7 +71,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
data = { data = {
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": "Test query", "query": "Test query",
"input_tokens": 100, "input_tokens": 100,
@ -82,7 +83,7 @@ class TestUsageRepository:
record = await repo.create(data) record = await repo.create(data)
assert record.id is not None assert record.id is not None
assert record.user_id == test_user.id assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "chatgpt" assert record.engine_type == "chatgpt"
assert record.query == "Test query" assert record.query == "Test query"
assert record.input_tokens == 100 assert record.input_tokens == 100
@ -96,7 +97,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
data = { data = {
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "deepseek", "engine_type": "deepseek",
"query": "Minimal query", "query": "Minimal query",
} }
@ -104,7 +105,7 @@ class TestUsageRepository:
record = await repo.create(data) record = await repo.create(data)
assert record.id is not None assert record.id is not None
assert record.user_id == test_user.id assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "deepseek" assert record.engine_type == "deepseek"
assert record.query == "Minimal query" assert record.query == "Minimal query"
assert record.input_tokens == 0 assert record.input_tokens == 0
@ -119,7 +120,7 @@ class TestUsageRepository:
for i in range(3): for i in range(3):
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": f"Query {i}", "query": f"Query {i}",
"input_tokens": 100, "input_tokens": 100,
@ -143,13 +144,13 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": "ChatGPT query", "query": "ChatGPT query",
"cost": 0.02, "cost": 0.02,
}) })
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "deepseek", "engine_type": "deepseek",
"query": "DeepSeek query", "query": "DeepSeek query",
"cost": 0.01, "cost": 0.01,
@ -170,7 +171,7 @@ class TestUsageRepository:
for i in range(2): for i in range(2):
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "qwen", "engine_type": "qwen",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 0.01, "cost": 0.01,
@ -190,14 +191,14 @@ class TestUsageRepository:
brand_id = uuid.uuid4() brand_id = uuid.uuid4()
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"brand_id": brand_id, "brand_id": brand_id,
"engine_type": "kimi", "engine_type": "kimi",
"query": "Brand query", "query": "Brand query",
"cost": 0.05, "cost": 0.05,
}) })
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "kimi", "engine_type": "kimi",
"query": "No brand query", "query": "No brand query",
"cost": 0.03, "cost": 0.03,
@ -219,7 +220,7 @@ class TestUsageRepository:
for i in range(5): for i in range(5):
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "gemini", "engine_type": "gemini",
"query": f"Query {i}", "query": f"Query {i}",
"cost": 1.0, "cost": 1.0,
@ -238,7 +239,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": "Expensive query", "query": "Expensive query",
"cost": 85.0, "cost": 85.0,
@ -255,7 +256,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt", "engine_type": "chatgpt",
"query": "Very expensive query", "query": "Very expensive query",
"cost": 120.0, "cost": 120.0,
@ -272,7 +273,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
created = await repo.create({ created = await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "wenxin", "engine_type": "wenxin",
"query": "Get by ID test", "query": "Get by ID test",
"cost": 0.5, "cost": 0.5,
@ -291,7 +292,7 @@ class TestUsageRepository:
for i in range(3): for i in range(3):
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "doubao", "engine_type": "doubao",
"query": f"User query {i}", "query": f"User query {i}",
"cost": 0.1, "cost": 0.1,
@ -307,13 +308,13 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "xinghuo", "engine_type": "xinghuo",
"query": "Xinghuo query", "query": "Xinghuo query",
"cost": 0.1, "cost": 0.1,
}) })
await repo.create({ await repo.create({
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "perplexity", "engine_type": "perplexity",
"query": "Perplexity query", "query": "Perplexity query",
"cost": 0.2, "cost": 0.2,
@ -358,7 +359,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session) repo = UsageRepository(async_session)
data = { data = {
"user_id": test_user.id, "user_id": _to_uuid(test_user.id),
"engine_type": "yuanbao", "engine_type": "yuanbao",
"query": "UUID test", "query": "UUID test",
"cost": 0.5, "cost": 0.5,
@ -368,4 +369,4 @@ class TestUsageRepository:
summary = await repo.get_summary(test_user.id, period="month") summary = await repo.get_summary(test_user.id, period="month")
assert summary["total_queries"] == 1 assert summary["total_queries"] == 1
assert record.user_id == test_user.id assert record.user_id == _to_uuid(test_user.id)

View File

@ -11,6 +11,7 @@ from app.database import Base
from app.models.brand import Brand from app.models.brand import Brand
from app.models.user import User from app.models.user import User
from app.services.auth import hash_password from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture @pytest_asyncio.fixture
@ -42,14 +43,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(async_session): async def test_user(async_session):
user = User( user = User(
id=uuid.uuid4(), id=str(uuid.uuid4()),
email="test@example.com", email="test@example.com",
password_hash=hash_password("Test@123456"), password=hash_password("Test@123456"),
name="Test User", firstName="Test User",
plan="free", plan="free",
max_queries=5, max_queries=5,
is_active=True, isActive=True,
email_verified=True, emailVerified=True,
) )
async_session.add(user) async_session.add(user)
await async_session.commit() await async_session.commit()
@ -61,7 +62,7 @@ async def test_user(async_session):
async def test_brand(async_session, test_user): async def test_brand(async_session, test_user):
brand = Brand( brand = Brand(
id=uuid.uuid4(), id=uuid.uuid4(),
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="Test Brand", name="Test Brand",
aliases=["TestBrand", "TB"], aliases=["TestBrand", "TB"],
website="https://testbrand.com", website="https://testbrand.com",
@ -83,7 +84,7 @@ class TestDetectionTaskModel:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="每日品牌检测", name="每日品牌检测",
frequency="daily", frequency="daily",
engines=["chatgpt", "perplexity"], engines=["chatgpt", "perplexity"],
@ -96,7 +97,7 @@ class TestDetectionTaskModel:
assert task.id is not None assert task.id is not None
assert task.brand_id == test_brand.id assert task.brand_id == test_brand.id
assert task.user_id == test_user.id assert task.user_id == _to_uuid(test_user.id)
assert task.name == "每日品牌检测" assert task.name == "每日品牌检测"
assert task.frequency == "daily" assert task.frequency == "daily"
assert task.engines == ["chatgpt", "perplexity"] assert task.engines == ["chatgpt", "perplexity"]
@ -114,7 +115,7 @@ class TestDetectionTaskModel:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="简单检测", name="简单检测",
frequency="weekly", frequency="weekly",
engines=["chatgpt"], engines=["chatgpt"],
@ -142,13 +143,13 @@ class TestDetectionSchedulerService:
"queries": ["最佳保险品牌", "保险推荐"], "queries": ["最佳保险品牌", "保险推荐"],
"competitor_names": ["竞品A"], "competitor_names": ["竞品A"],
} }
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session) task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.id is not None assert task.id is not None
assert task.name == "每日品牌检测" assert task.name == "每日品牌检测"
assert task.frequency == "daily" assert task.frequency == "daily"
assert task.brand_id == test_brand.id assert task.brand_id == test_brand.id
assert task.user_id == test_user.id assert task.user_id == _to_uuid(test_user.id)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_task(self, async_session, test_brand, test_user): async def test_update_task(self, async_session, test_brand, test_user):
@ -157,7 +158,7 @@ class TestDetectionSchedulerService:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="旧名称", name="旧名称",
frequency="weekly", frequency="weekly",
engines=["chatgpt"], engines=["chatgpt"],
@ -173,7 +174,7 @@ class TestDetectionSchedulerService:
"frequency": "daily", "frequency": "daily",
"engines": ["chatgpt", "perplexity"], "engines": ["chatgpt", "perplexity"],
} }
updated = await service.update_task(task.id, update_data, test_user.id, async_session) updated = await service.update_task(task.id, update_data, _to_uuid(test_user.id), async_session)
assert updated.name == "新名称" assert updated.name == "新名称"
assert updated.frequency == "daily" assert updated.frequency == "daily"
@ -186,7 +187,7 @@ class TestDetectionSchedulerService:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="待删除", name="待删除",
frequency="weekly", frequency="weekly",
engines=["chatgpt"], engines=["chatgpt"],
@ -197,7 +198,7 @@ class TestDetectionSchedulerService:
await async_session.refresh(task) await async_session.refresh(task)
service = DetectionSchedulerService() service = DetectionSchedulerService()
result = await service.delete_task(task.id, test_user.id, async_session) result = await service.delete_task(task.id, _to_uuid(test_user.id), async_session)
assert result is True assert result is True
stmt = select(DetectionTask).where(DetectionTask.id == task.id) stmt = select(DetectionTask).where(DetectionTask.id == task.id)
@ -212,7 +213,7 @@ class TestDetectionSchedulerService:
for i in range(3): for i in range(3):
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name=f"任务{i}", name=f"任务{i}",
frequency="daily", frequency="daily",
engines=["chatgpt"], engines=["chatgpt"],
@ -222,7 +223,7 @@ class TestDetectionSchedulerService:
await async_session.commit() await async_session.commit()
service = DetectionSchedulerService() service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session) tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
assert len(tasks) == 3 assert len(tasks) == 3
@pytest.mark.asyncio @pytest.mark.asyncio
@ -232,7 +233,7 @@ class TestDetectionSchedulerService:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="手动触发测试", name="手动触发测试",
frequency="daily", frequency="daily",
engines=["chatgpt"], engines=["chatgpt"],
@ -245,7 +246,7 @@ class TestDetectionSchedulerService:
service = DetectionSchedulerService() service = DetectionSchedulerService()
with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute: with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
mock_execute.return_value = {"status": "success", "results": []} mock_execute.return_value = {"status": "success", "results": []}
result = await service.trigger_task(task.id, test_user.id, async_session) result = await service.trigger_task(task.id, _to_uuid(test_user.id), async_session)
assert result["status"] == "success" assert result["status"] == "success"
mock_execute.assert_called_once() mock_execute.assert_called_once()
@ -261,7 +262,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"], "engines": ["chatgpt"],
"queries": ["查询"], "queries": ["查询"],
} }
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session) task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "hourly" assert task.frequency == "hourly"
assert task.next_run_at is not None assert task.next_run_at is not None
@ -276,7 +277,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"], "engines": ["chatgpt"],
"queries": ["查询"], "queries": ["查询"],
} }
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session) task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "daily" assert task.frequency == "daily"
assert task.next_run_at is not None assert task.next_run_at is not None
@ -291,7 +292,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"], "engines": ["chatgpt"],
"queries": ["查询"], "queries": ["查询"],
} }
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session) task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "weekly" assert task.frequency == "weekly"
assert task.next_run_at is not None assert task.next_run_at is not None
@ -307,7 +308,7 @@ class TestDetectionSchedulerService:
"queries": ["查询"], "queries": ["查询"],
} }
with pytest.raises(ValueError, match="frequency"): with pytest.raises(ValueError, match="frequency"):
await service.create_task(task_data, test_brand.id, test_user.id, async_session) await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_task_flow(self, async_session, test_brand, test_user): async def test_execute_task_flow(self, async_session, test_brand, test_user):
@ -316,7 +317,7 @@ class TestDetectionSchedulerService:
task = DetectionTask( task = DetectionTask(
brand_id=test_brand.id, brand_id=test_brand.id,
user_id=test_user.id, user_id=_to_uuid(test_user.id),
name="执行流程测试", name="执行流程测试",
frequency="daily", frequency="daily",
engines=["chatgpt"], engines=["chatgpt"],
@ -356,7 +357,7 @@ class TestDetectionSchedulerService:
from app.services.detection.detection_scheduler import DetectionSchedulerService from app.services.detection.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService() service = DetectionSchedulerService()
result = await service.delete_task(uuid.uuid4(), test_user.id, async_session) result = await service.delete_task(uuid.uuid4(), _to_uuid(test_user.id), async_session)
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -365,12 +366,12 @@ class TestDetectionSchedulerService:
service = DetectionSchedulerService() service = DetectionSchedulerService()
with pytest.raises(TaskNotFoundError): with pytest.raises(TaskNotFoundError):
await service.update_task(uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session) await service.update_task(uuid.uuid4(), {"name": "新名称"}, _to_uuid(test_user.id), async_session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_tasks_empty(self, async_session, test_brand, test_user): async def test_get_tasks_empty(self, async_session, test_brand, test_user):
from app.services.detection.detection_scheduler import DetectionSchedulerService from app.services.detection.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService() service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session) tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
assert tasks == [] assert tasks == []

View File

@ -0,0 +1,886 @@
---
title: "refactor: GEO Agent Framework — 统一架构重构为独立项目"
type: refactor
status: active
date: 2026-06-04
---
## Summary
将 GEO 项目的 8 个 Agent 从代码重复的独立类重构为独立 Python 包(`fischer-agentkit`),采用统一 Agent Core + 配置驱动 + 可插拔 Tool/Skill/Memory 架构。新增 MCP 协议支持、三层记忆系统Working/Episodic/Semantic、Level 3 自我进化闭环(经验积累 → Prompt 自动优化 → 策略调整)、多 Agent 协同增强并行执行、Handoff、动态 Pipeline并实现 Agent 与业务系统的完全解耦。
## Problem Frame
当前 GEO 项目的 Agent 体系存在以下结构性问题:
1. **代码重复**8 个 Agent 的 `execute()` 方法包含完全相同的计时、try/except、TaskResult 构建逻辑(~250 行模板代码),每新增一个 Agent 需复制 ~150 行
2. **架构耦合**Agent 直接依赖业务 ServiceCitationService、MonitorService 等),无法独立部署和复用
3. **能力缺失**无记忆系统每次执行无状态、无自我进化Prompt 硬编码不可调优)、无技能插件(能力静态绑定)、无 MCP 支持
4. **编排局限**Pipeline 仅支持串行 DAG、无 Agent 间 Handoff、无动态路由
5. **配置僵化**AgentType 硬编码枚举、Prompt 与 Agent 紧耦合、新增 Agent 需改 5+ 文件
---
## Requirements
### Agent Core 统一架构
R1. Agent 框架必须作为独立 Python 包(`fischer-agentkit`),可独立安装、版本管理、跨项目复用
R2. 所有 Agent 共享统一的 `BaseAgent` 生命周期start → listen → execute → reflect → evolve → stop子类只需实现 `handle_task()` 返回业务数据
R3. Agent 定义支持三种模式Python 声明式(`@task` 装饰器、YAML 配置驱动(零代码)、混合模式
R4. AgentType 从硬编码枚举改为动态注册,新增 Agent 类型无需修改框架代码
R5. Agent 的 input/output 支持 JSON Schema 声明与校验
### Tool/Skill 插件系统
R6. 实现 `Tool` 抽象基类,支持 `FunctionTool`(函数工具)、`AgentTool`Agent 即工具)、`MCPTool`MCP 协议工具)三种类型
R7. 实现 `ToolRegistry`,支持工具的注册、发现、版本管理、标签分类
R8. 实现 MCP Server将现有 Agent 能力暴露为 MCP 工具供外部调用
R9. 实现 MCP ClientAgent 可调用外部 MCP 工具服务器
R10. 工具支持组合:顺序链、并行扇出/扇入、动态选择
### 记忆系统
R11. 实现 Working MemoryRedis-based存储当前任务的上下文和中间状态生命周期为单次任务
R12. 实现 Episodic Memory向量+关系数据库),记录每次任务的输入/输出/效果/反思,支持语义检索相似历史案例
R13. 实现 Semantic Memory复用现有 RAG 知识库,所有 Agent 均可通过统一接口检索知识
R14. 记忆检索采用混合策略:向量语义 + 关键词 + 知识图谱 + 时间衰减 + RRF 融合排序
### 自我进化Level 3
R15. 实现经验积累:每次任务执行后自动记录成功/失败模式、效果指标、反思总结
R16. 实现 Prompt 自动优化:基于 DSPy 风格的编译器,从任务结果中自动生成/优化 Prompt 指令和 few-shot 示例
R17. 实现策略调整:根据历史效果数据自动调整 Agent 参数temperature、tool 选择权重、Pipeline 路径)
R18. 进化变更必须经过 A/B 测试验证后才可生效,支持回滚
### 多 Agent 协同与业务编排
R19. Pipeline Engine 支持同层并行执行(无依赖的 stages 使用 `asyncio.gather` 并行)
R20. 实现 Handoff 机制Agent 可在运行时将任务转交给另一个 Agent携带上下文
R21. 支持动态 Pipeline运行时根据条件选择子流程、嵌套 Pipeline
R22. Agent 间通信支持事件驱动Redis Pub/Sub替代轮询等待
### Agent 与业务解耦
R23. Agent 框架不依赖任何 GEO 业务代码Service、Model、Repository通过 Tool 接口调用业务能力
R24. 业务系统通过 Tool 注册将自身能力暴露给 AgentAgent 通过 ToolRegistry 发现和调用
R25. Agent 配置Prompt、Tool 绑定、Memory 策略)存储在数据库,支持热更新
---
## Key Technical Decisions
KTD1. **独立 Python 包架构**`fischer-agentkit` 作为独立包发布到私有 PyPIGEO 项目通过 `pip install` 引入。包结构遵循 `src/` layout支持独立测试和版本管理。理由解耦 Agent 基础设施与业务代码,支持跨项目复用。
KTD2. **统一 Agent Core 的 execute 模板上移**将计时、try/except、TaskResult 构建、进度上报等公共逻辑全部上移到 `BaseAgent.execute()`,子类只需实现 `handle_task(task) -> dict`。理由:消除 8 个 Agent 中 ~250 行重复代码。
KTD3. **Tool 抽象三层架构**`FunctionTool`(进程内函数调用)→ `MCPTool`(跨进程 MCP 协议调用)→ `AgentTool`(将 Agent 包装为 Tool。理由覆盖从简单到复杂的所有工具场景MCP 作为标准协议确保生态兼容性。
KTD4. **MCP 双向支持**:同时实现 MCP Server暴露 Agent 能力)和 MCP Client调用外部工具。传输层优先支持 Streamable HTTP2025 标准),兼容 SSE。理由MCP 是 2025 年工具协议事实标准,双向支持最大化生态连接能力。
KTD5. **三层记忆架构**Working MemoryRedis单任务生命周期→ Episodic Memorypgvector + PostgreSQL任务经验→ Semantic Memory复用现有 RAG + 知识图谱)。理由:不同记忆类型有不同生命周期和检索模式,分层实现职责清晰。
KTD6. **Level 3 进化采用 DSPy 风格编译器**:定义 `Signature`(输入/输出 schema`Module`(可组合 Prompt 策略)→ `Optimizer`(自动优化),从任务结果中自动构建 few-shot 示例和优化指令。理由DSPy 是目前最成熟的 Prompt 自动优化范式,其编译器模式与 Agent 的 execute-reflect-evolve 生命周期天然契合。
KTD7. **配置驱动的 Agent 定义**Agent 的元数据、Prompt、Tool 绑定、Memory 策略均通过 YAML/数据库配置,运行时由 `ConfigDrivenAgent` 自动组装。理由:将新增 Agent 从写 150 行代码降为 10-20 行配置。
KTD8. **Pipeline 并行执行**:将拓扑排序后的 stages 按依赖层级分组,同层内使用 `asyncio.gather` 并行执行。理由:引用检测和趋势分析等无依赖任务应并行,当前串行执行浪费时间。
KTD9. **Handoff 基于 Redis Pub/Sub**Agent 通过 `agent:{name}:handoff` 频道发送转交请求,目标 Agent 监听并接管。理由:与现有 Redis Queue 架构一致,无需引入新组件。
---
## High-Level Technical Design
### 整体架构
```mermaid
flowchart TB
subgraph FischerAgentKit["fischer-agentkit (独立包)"]
direction TB
Core["Agent Core"]
Tools["Tool System"]
Memory["Memory System"]
Evolution["Evolution Engine"]
Orchestrator["Orchestrator"]
MCP["MCP Layer"]
subgraph Core
BaseAgent["BaseAgent<br/>生命周期管理"]
ConfigDriven["ConfigDrivenAgent<br/>配置驱动"]
Registry["AgentRegistry<br/>注册发现"]
Dispatcher["TaskDispatcher<br/>任务调度"]
Protocol["Protocol<br/>通信协议"]
end
subgraph Tools
ToolRegistry["ToolRegistry<br/>注册发现"]
FunctionTool["FunctionTool<br/>函数工具"]
MCPTool["MCPTool<br/>MCP工具"]
AgentTool["AgentTool<br/>Agent即工具"]
end
subgraph Memory
Working["Working Memory<br/>Redis"]
Episodic["Episodic Memory<br/>pgvector+PG"]
Semantic["Semantic Memory<br/>RAG+Graph"]
Retriever["MemoryRetriever<br/>混合检索"]
end
subgraph Evolution
Reflector["Reflector<br/>执行反思"]
PromptOptimizer["PromptOptimizer<br/>DSPy编译器"]
StrategyTuner["StrategyTuner<br/>策略调优"]
ABTester["ABTester<br/>A/B测试"]
end
subgraph Orchestrator
PipelineEngine["PipelineEngine<br/>DAG+并行"]
Handoff["Handoff<br/>任务转交"]
DynamicPipeline["DynamicPipeline<br/>运行时组合"]
end
subgraph MCP
MCPServer["MCP Server<br/>暴露能力"]
MCPClient["MCP Client<br/>调用外部"]
end
end
subgraph GEOBusiness["GEO 业务系统"]
Services["业务 Services"]
Models["数据 Models"]
Repos["Repositories"]
end
Core --> Tools
Core --> Memory
Core --> Evolution
Core --> Orchestrator
Tools --> MCP
Services -.->|"注册为 FunctionTool"| ToolRegistry
Semantic -.->|"复用"| Services
```
### Agent 生命周期
```mermaid
stateDiagram-v2
[*] --> Offline
Offline --> Online: start()
Online --> Listening: listen_for_tasks()
Listening --> Executing: receive_task()
Executing --> Reflecting: handle_task() 完成
Reflecting --> Evolving: reflect() 有改进建议
Reflecting --> Listening: reflect() 无改进
Evolving --> Listening: evolve() 完成
Executing --> Listening: handle_task() 异常
Online --> Offline: stop()
Listening --> Offline: stop()
state Executing {
[*] --> LoadMemory: 加载记忆
LoadMemory --> PlanTask: 规划任务
PlanTask --> RunTools: 执行工具
RunTools --> StoreMemory: 存储记忆
StoreMemory --> BuildResult: 构建结果
BuildResult --> [*]
}
state Reflecting {
[*] --> EvaluateOutcome: 评估结果
EvaluateOutcome --> ExtractPatterns: 提取模式
ExtractPatterns --> GenerateInsights: 生成洞察
GenerateInsights --> [*]
}
state Evolving {
[*] --> OptimizePrompt: 优化Prompt
OptimizePrompt --> TuneStrategy: 调优策略
TuneStrategy --> ABTest: A/B测试
ABTest --> ApplyOrRollback: 应用或回滚
ApplyOrRollback --> [*]
}
```
### Tool 系统架构
```mermaid
flowchart LR
subgraph Agent["Agent"]
Executor["Task Executor"]
end
subgraph ToolRegistry["ToolRegistry"]
FT["FunctionTool<br/>name, description<br/>input_schema, output_schema<br/>execute(**kwargs)"]
MT["MCPTool<br/>server_url, tool_name<br/>input_schema, output_schema<br/>execute(**kwargs)"]
AT["AgentTool<br/>agent_name<br/>input_mapping, output_mapping<br/>execute(**kwargs)"]
end
subgraph MCPServer["MCP Server"]
MCPEndpoints["/tools/list<br/>/tools/call<br/>/resources/read"]
end
subgraph ExternalMCP["External MCP Servers"]
ExtTool1["File System"]
ExtTool2["GitHub"]
ExtTool3["Postgres"]
end
Executor -->|"select & call"| ToolRegistry
FT -->|"进程内调用"| BusinessLogic["业务函数"]
MT -->|"HTTP/SSE"| MCPServer
MT -->|"HTTP/SSE"| ExternalMCP
AT -->|"dispatch"| TargetAgent["目标 Agent"]
MCPServer -->|"暴露"| AgentCapabilities["Agent 能力"]
```
### 记忆系统数据流
```mermaid
flowchart TB
Task["新任务到达"] --> WM["Working Memory<br/>加载当前任务上下文"]
WM --> EM["Episodic Memory<br/>检索相似历史案例"]
EM --> SM["Semantic Memory<br/>检索知识库"]
SM --> Context["组装上下文"]
Context --> Execute["执行任务"]
Execute --> WM_Write["写入 Working Memory<br/>中间状态"]
WM_Write --> Execute
Execute --> EM_Write["写入 Episodic Memory<br/>任务经验"]
Execute --> Reflect["反思评估"]
Reflect --> EM_Update["更新 Episodic Memory<br/>反思结果"]
subgraph 检索策略
VS["向量语义检索<br/>权重 0.5"]
KS["关键词精确匹配<br/>权重 0.2"]
GT["知识图谱关联<br/>权重 0.3"]
RRF["RRF 融合排序"]
TD["时间衰减"]
VS --> RRF
KS --> RRF
GT --> RRF
RRF --> TD
end
```
### 进化闭环
```mermaid
flowchart TB
Execute["任务执行"] --> Result["任务结果"]
Result --> Evaluate["效果评估<br/>成功/失败/质量分"]
Evaluate --> Pattern["模式提取<br/>成功模式/失败模式"]
Pattern --> Optimize["优化生成"]
Optimize --> PromptOpt["Prompt 优化<br/>DSPy 编译器"]
Optimize --> StrategyOpt["策略调优<br/>参数/工具权重"]
Optimize --> PipelineOpt["Pipeline 优化<br/>路径选择"]
PromptOpt --> ABTest["A/B 测试"]
StrategyOpt --> ABTest
PipelineOpt --> ABTest
ABTest -->|"效果提升"| Apply["应用变更"]
ABTest -->|"效果下降"| Rollback["回滚"]
ABTest -->|"不确定"| Extend["延长测试"]
Apply --> Record["记录进化日志"]
Rollback --> Record
Record --> Execute
```
---
## Output Structure
```
fischer-agentkit/ # 独立 Python 包
├── src/
│ └── agentkit/
│ ├── __init__.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── base.py # BaseAgent 生命周期
│ │ ├── config_driven.py # ConfigDrivenAgent 配置驱动
│ │ ├── registry.py # AgentRegistry 注册发现
│ │ ├── dispatcher.py # TaskDispatcher 任务调度
│ │ ├── protocol.py # 通信协议与数据结构
│ │ ├── exceptions.py # 异常体系
│ │ └── standalone.py # 独立启动器(自动发现)
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── base.py # Tool 抽象基类
│ │ ├── function_tool.py # FunctionTool
│ │ ├── agent_tool.py # AgentTool
│ │ ├── mcp_tool.py # MCPTool
│ │ └── registry.py # ToolRegistry
│ ├── memory/
│ │ ├── __init__.py
│ │ ├── base.py # Memory 抽象基类
│ │ ├── working.py # Working Memory (Redis)
│ │ ├── episodic.py # Episodic Memory (pgvector+PG)
│ │ ├── semantic.py # Semantic Memory (RAG+Graph)
│ │ └── retriever.py # 混合检索器
│ ├── evolution/
│ │ ├── __init__.py
│ │ ├── reflector.py # 执行反思
│ │ ├── prompt_optimizer.py # DSPy 风格 Prompt 优化器
│ │ ├── strategy_tuner.py # 策略调优
│ │ ├── ab_tester.py # A/B 测试框架
│ │ └── evolution_store.py # 进化日志存储
│ ├── orchestrator/
│ │ ├── __init__.py
│ │ ├── pipeline_engine.py # DAG + 并行 Pipeline
│ │ ├── pipeline_schema.py # Pipeline 数据模型
│ │ ├── pipeline_loader.py # YAML 加载器
│ │ ├── handoff.py # Handoff 机制
│ │ └── dynamic_pipeline.py # 动态 Pipeline 组合
│ ├── mcp/
│ │ ├── __init__.py
│ │ ├── server.py # MCP Server
│ │ ├── client.py # MCP Client
│ │ └── transport.py # 传输层 (HTTP/SSE)
│ └── prompts/
│ ├── __init__.py
│ ├── template.py # PromptTemplate
│ └── section.py # PromptSection
├── tests/
│ ├── unit/
│ │ ├── test_base_agent.py
│ │ ├── test_config_driven.py
│ │ ├── test_tool_registry.py
│ │ ├── test_function_tool.py
│ │ ├── test_mcp_tool.py
│ │ ├── test_agent_tool.py
│ │ ├── test_working_memory.py
│ │ ├── test_episodic_memory.py
│ │ ├── test_semantic_memory.py
│ │ ├── test_memory_retriever.py
│ │ ├── test_reflector.py
│ │ ├── test_prompt_optimizer.py
│ │ ├── test_strategy_tuner.py
│ │ ├── test_ab_tester.py
│ │ ├── test_pipeline_parallel.py
│ │ ├── test_handoff.py
│ │ ├── test_dynamic_pipeline.py
│ │ ├── test_mcp_server.py
│ │ └── test_mcp_client.py
│ └── integration/
│ ├── test_agent_lifecycle.py
│ ├── test_tool_composition.py
│ ├── test_evolution_loop.py
│ └── test_mcp_roundtrip.py
├── pyproject.toml
└── README.md
geo/backend/ # GEO 业务系统(改造后)
├── app/
│ ├── agent_framework/ # 保留为适配层
│ │ ├── __init__.py
│ │ ├── agents/ # 改为配置驱动
│ │ │ ├── __init__.py
│ │ │ ├── configs/ # Agent YAML 配置
│ │ │ │ ├── citation_detector.yaml
│ │ │ │ ├── content_generator.yaml
│ │ │ │ ├── deai_agent.yaml
│ │ │ │ ├── geo_optimizer.yaml
│ │ │ │ ├── monitor.yaml
│ │ │ │ ├── schema_advisor.yaml
│ │ │ │ ├── competitor_analyzer.yaml
│ │ │ │ └── trend_agent.yaml
│ │ │ └── custom_handlers/ # 仅复杂 Agent 需自定义 handler
│ │ │ ├── citation_handler.py
│ │ │ └── monitor_handler.py
│ │ ├── tools/ # 业务 Tool 注册
│ │ │ ├── __init__.py
│ │ │ ├── citation_tools.py
│ │ │ ├── content_tools.py
│ │ │ ├── knowledge_tools.py
│ │ │ ├── monitor_tools.py
│ │ │ └── schema_tools.py
│ │ └── prompts/ # Prompt 模板(保留,可热更新)
│ ├── models/
│ │ └── agent.py # 新增 evolution_logs, episodic_memories 表
│ └── ...
```
---
## Implementation Units
### U1. 独立包脚手架与 BaseAgent 重构
**Goal:** 创建 `fischer-agentkit` 独立包,重构 BaseAgent 将 execute 模板代码上移,子类只需实现 `handle_task()`
**Dependencies:** 无
**Files:**
- `fischer-agentkit/src/agentkit/__init__.py`
- `fischer-agentkit/src/agentkit/core/__init__.py`
- `fischer-agentkit/src/agentkit/core/base.py`
- `fischer-agentkit/src/agentkit/core/protocol.py`
- `fischer-agentkit/src/agentkit/core/exceptions.py`
- `fischer-agentkit/src/agentkit/core/registry.py`
- `fischer-agentkit/src/agentkit/core/dispatcher.py`
- `fischer-agentkit/pyproject.toml`
- `fischer-agentkit/tests/unit/test_base_agent.py`
**Approach:**
1. 创建 `fischer-agentkit` 包,使用 `src/` layout`pyproject.toml` 声明依赖(`redis[hiredis]`, `pydantic>=2.0`, `sqlalchemy[asyncio]>=2.0`
2. 从 GEO 项目迁移 `protocol.py`AgentCapability, TaskMessage, TaskResult, TaskProgress, TaskStatus, AgentStatus移除 AgentType 硬编码枚举,改为动态注册
3. 重构 `BaseAgent`
- `execute()` 方法变为 final非抽象包含完整的计时、try/except、TaskResult 构建、进度上报
- 新增抽象方法 `handle_task(task) -> dict`,子类只需返回 output_data
- 新增生命周期钩子:`on_task_start()`, `on_task_complete()`, `on_task_failed()`
- 新增 `tools: list[Tool]` 属性默认空U3 实现)
- 新增 `memory: Memory` 属性默认空U4 实现)
4. 迁移 `registry.py`, `dispatcher.py`, `exceptions.py`,去除 GEO 业务依赖
5. `AgentRegistry.get_available_agent()` 增加负载均衡策略(轮询/最少任务)
**Patterns to follow:** 现有 `backend/app/agent_framework/base.py` 的双模式Redis/本地)设计
**Test scenarios:**
- BaseAgent 子类实现 handle_task 返回 dictexecute 自动包装为 TaskResult
- handle_task 抛异常时 execute 自动构建 FAILED TaskResult
- on_task_start/complete/failed 钩子按序调用
- AgentRegistry 动态注册新 AgentType 不报错
- AgentRegistry.get_available_agent 轮询策略返回不同 Agent
**Verification:** `fischer-agentkit` 可独立 `pip install -e .`,单元测试全部通过
---
### U2. Protocol 扩展与配置驱动 Agent
**Goal:** 扩展 Protocol 支持 JSON Schema 校验、Handoff 消息;实现 ConfigDrivenAgent支持 YAML/数据库配置驱动 Agent 定义。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/core/protocol.py`(扩展)
- `fischer-agentkit/src/agentkit/core/config_driven.py`
- `fischer-agentkit/src/agentkit/core/standalone.py`
- `fischer-agentkit/src/agentkit/prompts/template.py`
- `fischer-agentkit/src/agentkit/prompts/section.py`
- `fischer-agentkit/tests/unit/test_config_driven.py`
- `fischer-agentkit/tests/unit/test_protocol.py`
**Approach:**
1. Protocol 扩展:
- `AgentCapability` 增加 `input_schema: dict | None`、`output_schema: dict | None`JSON Schema
- 新增 `HandoffMessage` 数据类source_agent, target_agent, task_context, reason
- 新增 `EvolutionEvent` 数据类agent_name, change_type, before, after, metrics
- `TaskMessage` 增加 `conversation_id: str | None` 支持多轮对话
2. ConfigDrivenAgent
- 接受 YAML 配置或数据库配置自动组装Prompt 模板 + LLM 参数 + Tool 绑定 + Memory 策略
- 内置 LLM 调用 + JSON 输出解析 + 降级兜底的标准流程
- 支持三种任务模式:`llm_generate`Prompt → LLM → 解析)、`tool_call`(调用指定 Tool、`custom`(自定义 handler
3. PromptTemplate 迁移并增强:支持动态 section 组合、版本标记
4. Standalone runner 改为自动发现:扫描 `agent_configs/` 目录下的 YAML 文件,自动注册和启动
**Patterns to follow:** 现有 `prompts/base_template.py` 的 PromptSection 设计
**Test scenarios:**
- ConfigDrivenAgent 从 YAML 加载配置并正确组装
- llm_generate 模式:渲染 Prompt → 调用 Mock LLM → 解析 JSON 输出
- tool_call 模式:调用注册的 FunctionTool 并返回结果
- input_schema 校验:缺少必填字段时返回校验错误
- HandoffMessage 序列化/反序列化正确
- 自动发现:在 configs/ 目录放置 YAML 后 standalone 可自动加载
**Verification:** 通过 YAML 配置定义一个新 Agent无需写 Python 代码即可运行
---
### U3. Tool/Skill 插件系统
**Goal:** 实现 Tool 抽象基类、FunctionTool、AgentTool、ToolRegistry支持工具注册、发现、组合。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/tools/__init__.py`
- `fischer-agentkit/src/agentkit/tools/base.py`
- `fischer-agentkit/src/agentkit/tools/function_tool.py`
- `fischer-agentkit/src/agentkit/tools/agent_tool.py`
- `fischer-agentkit/src/agentkit/tools/registry.py`
- `fischer-agentkit/tests/unit/test_tool_registry.py`
- `fischer-agentkit/tests/unit/test_function_tool.py`
- `fischer-agentkit/tests/unit/test_agent_tool.py`
- `fischer-agentkit/tests/integration/test_tool_composition.py`
**Approach:**
1. `Tool` 抽象基类:
- 属性:`name`, `description`, `input_schema`JSON Schema, `output_schema`JSON Schema, `version`, `tags`
- 抽象方法:`async execute(**kwargs) -> dict`
- 生命周期钩子:`before_execute()`, `after_execute()`, `on_error()`
2. `FunctionTool`:包装普通 Python 函数为 Tool自动从函数签名推断 input_schema
3. `AgentTool`:包装另一个 Agent 为 Tool输入/输出通过 mapping 适配,内部通过 Dispatcher 分发任务
4. `ToolRegistry`
- `register(tool)`, `unregister(name)`, `get(name)`, `list_tools(tag=None)`
- 支持版本管理:同一工具多版本共存,默认使用最新版
- 支持标签分类:`[citation, analysis, generation, optimization]`
5. 工具组合:
- `SequentialChain([tool_a, tool_b])`:顺序执行,前一个输出作为后一个输入
- `ParallelFanOut([tool_a, tool_b, tool_c])`:并行执行,结果合并
- `DynamicSelector(llm, tools)`LLM 根据任务动态选择工具
**Patterns to follow:** 现有 `LLMFactory` 的注册-创建模式
**Test scenarios:**
- FunctionTool 从函数自动推断 schema 并执行
- AgentTool 分发任务到目标 Agent 并返回结果
- ToolRegistry 注册/发现/按标签过滤
- SequentialChain 顺序执行两个工具
- ParallelFanOut 并行执行三个工具
- DynamicSelector 根据 LLM 判断选择合适工具
- 工具版本管理:注册 v1 和 v2默认返回 v2
**Verification:** 通过 ToolRegistry 注册 3 个 FunctionToolAgent 可声明式绑定并调用
---
### U4. 记忆系统
**Goal:** 实现三层记忆系统Working/Episodic/Semantic和混合检索器。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/memory/__init__.py`
- `fischer-agentkit/src/agentkit/memory/base.py`
- `fischer-agentkit/src/agentkit/memory/working.py`
- `fischer-agentkit/src/agentkit/memory/episodic.py`
- `fischer-agentkit/src/agentkit/memory/semantic.py`
- `fischer-agentkit/src/agentkit/memory/retriever.py`
- `fischer-agentkit/tests/unit/test_working_memory.py`
- `fischer-agentkit/tests/unit/test_episodic_memory.py`
- `fischer-agentkit/tests/unit/test_semantic_memory.py`
- `fischer-agentkit/tests/unit/test_memory_retriever.py`
**Approach:**
1. `Memory` 抽象基类:
- `async store(key, value, metadata)`, `async retrieve(query, top_k)`, `async delete(key)`
- `async search(query, top_k, filters)` — 语义检索
2. `WorkingMemory`Redis
- 以 `agent:{name}:working_memory:{task_id}` 为 key 前缀
- 支持自动过期TTL = 任务超时时间 × 2
- 提供 `get_context()` 方法,返回格式化的上下文字符串
3. `EpisodicMemory`pgvector + PostgreSQL
- 表结构:`id, agent_name, task_type, input_summary, output_summary, outcome(success/fail), quality_score, reflection, embedding, created_at`
- 写入时自动生成 embedding复用 Embedder 接口)
- 检索:向量语义 + 关键词 + 时间衰减 + RRF 融合
4. `SemanticMemory`
- 适配器模式,对接 GEO 项目的 `RAGService``GraphQuery`
- 提供统一的 `search(query, knowledge_base_ids, top_k)` 接口
5. `MemoryRetriever`(混合检索器):
- 并行查询三层记忆,按权重融合排序
- 时间衰减:`score *= exp(-0.01 * age_hours)`
- 上下文窗口管理:总 token 不超过预算
**Patterns to follow:** 现有 `HybridRetriever` 的 RRF 融合排序模式
**Test scenarios:**
- WorkingMemory 存取任务上下文TTL 过期后自动清除
- EpisodicMemory 写入任务经验并按语义检索相似案例
- EpisodicMemory 时间衰减:近期经验权重高于远期
- SemanticMemory 通过适配器调用 RAGService
- MemoryRetriever 混合检索三层记忆并按权重融合
- 上下文窗口管理:检索结果超过 token 预算时智能截断
**Verification:** Agent 执行任务后自动写入 EpisodicMemory后续相似任务可检索到历史经验
---
### U5. MCP Server 与 Client
**Goal:** 实现 MCP Server暴露 Agent 能力)和 MCP Client调用外部 MCP 工具),完成 MCPTool。
**Dependencies:** U3
**Files:**
- `fischer-agentkit/src/agentkit/mcp/__init__.py`
- `fischer-agentkit/src/agentkit/mcp/server.py`
- `fischer-agentkit/src/agentkit/mcp/client.py`
- `fischer-agentkit/src/agentkit/mcp/transport.py`
- `fischer-agentkit/src/agentkit/tools/mcp_tool.py`
- `fischer-agentkit/tests/unit/test_mcp_server.py`
- `fischer-agentkit/tests/unit/test_mcp_client.py`
- `fischer-agentkit/tests/unit/test_mcp_tool.py`
- `fischer-agentkit/tests/integration/test_mcp_roundtrip.py`
**Approach:**
1. MCP Server
- 基于 FastAPI 实现,支持 Streamable HTTP 传输
- 端点:`/tools/list`(列出可用工具)、`/tools/call`(调用工具)、`/resources/read`(读取资源)
- 自动将 ToolRegistry 中注册的工具暴露为 MCP 工具
- 支持 SSE 流式响应
2. MCP Client
- 连接外部 MCP ServerHTTP/SSE
- 自动发现远程工具并注册到本地 ToolRegistry
- 支持工具调用的流式响应
3. MCPTool
- 继承 Tool 基类,内部通过 MCP Client 调用远程工具
- 自动从 MCP Server 获取 input_schema
4. Transport 层:
- 抽象 `Transport` 接口(`send_request`, `receive_response`
- 实现 `HTTPTransport`Streamable HTTP`SSETransport`Server-Sent Events
**Patterns to follow:** MCP 官方 Python SDK 的 Server/Client 模式
**Test scenarios:**
- MCP Server 启动后 `/tools/list` 返回已注册工具列表
- MCP Client 连接 Server 并调用工具返回正确结果
- MCPTool 通过 Client 调用远程工具
- SSE 流式响应正确传输
- MCP Server 自动暴露 ToolRegistry 中的 FunctionTool
- 外部 MCP Server 的工具自动注册到本地 ToolRegistry
**Verification:** 启动 MCP Server通过 MCP Client 调用 Agent 能力,端到端成功
---
### U6. 自我进化引擎Level 3
**Goal:** 实现执行反思、DSPy 风格 Prompt 自动优化、策略调优、A/B 测试框架。
**Dependencies:** U1, U4
**Files:**
- `fischer-agentkit/src/agentkit/evolution/__init__.py`
- `fischer-agentkit/src/agentkit/evolution/reflector.py`
- `fischer-agentkit/src/agentkit/evolution/prompt_optimizer.py`
- `fischer-agentkit/src/agentkit/evolution/strategy_tuner.py`
- `fischer-agentkit/src/agentkit/evolution/ab_tester.py`
- `fischer-agentkit/src/agentkit/evolution/evolution_store.py`
- `fischer-agentkit/tests/unit/test_reflector.py`
- `fischer-agentkit/tests/unit/test_prompt_optimizer.py`
- `fischer-agentkit/tests/unit/test_strategy_tuner.py`
- `fischer-agentkit/tests/unit/test_ab_tester.py`
- `fischer-agentkit/tests/integration/test_evolution_loop.py`
**Approach:**
1. `Reflector`(执行反思):
- 每次任务完成后自动评估:成功/失败、质量评分(基于 output_schema 约束和业务指标)
- 提取模式:常见失败原因、高效策略、低效路径
- 生成反思总结存入 EpisodicMemory
2. `PromptOptimizer`DSPy 风格编译器):
- `Signature`:定义 `input_fields -> output_fields` 的结构化签名
- `Module`:可组合的 Prompt 策略ChainOfThought, ReAct, FewShot
- `Optimizer`
- `BootstrapFewShot`:从成功案例中自动构建 few-shot 示例
- `MIPROv2`:多目标 Prompt 优化(指令 + few-shot 联合优化)
- 基于历史任务数据编译,产出优化后的 Prompt
3. `StrategyTuner`(策略调优):
- 可调参数temperature, tool_selection_weights, pipeline_path
- 基于 Bayesian Optimization 搜索最优参数组合
- 每次调优记录 before/after 指标
4. `ABTester`A/B 测试框架):
- 支持配置分流比例(如 80% 原版 / 20% 实验版)
- 自动收集实验组和对照组的效果指标
- 统计显著性检验t-test达到置信度后自动决策
5. `EvolutionStore`(进化日志):
- 表结构:`id, agent_name, change_type(prompt/strategy/pipeline), before, after, ab_test_id, status(active/rolled_back), created_at`
- 支持回滚:`rollback(evolution_id)`
**Patterns to follow:** DSPy 的 Signature/Module/Optimizer 三层架构
**Test scenarios:**
- Reflector 评估成功任务生成正面反思
- Reflector 评估失败任务提取失败模式
- PromptOptimizer 从 10 个成功案例自动生成 few-shot 示例
- PromptOptimizer 优化后的 Prompt 在测试集上效果提升
- StrategyTuner 调整 temperature 后效果改善
- ABTester 80/20 分流,实验组效果显著后自动应用
- EvolutionStore 记录变更并支持回滚
**Verification:** Agent 执行 20 次任务后PromptOptimizer 自动优化 PromptABTest 验证效果提升后自动应用
---
### U7. 多 Agent 协同增强
**Goal:** Pipeline Engine 支持并行执行、Handoff 机制、动态 Pipeline 组合。
**Dependencies:** U1, U2
**Files:**
- `fischer-agentkit/src/agentkit/orchestrator/__init__.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_engine.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_schema.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_loader.py`
- `fischer-agentkit/src/agentkit/orchestrator/handoff.py`
- `fischer-agentkit/src/agentkit/orchestrator/dynamic_pipeline.py`
- `fischer-agentkit/tests/unit/test_pipeline_parallel.py`
- `fischer-agentkit/tests/unit/test_handoff.py`
- `fischer-agentkit/tests/unit/test_dynamic_pipeline.py`
**Approach:**
1. Pipeline 并行执行:
- 拓扑排序后按依赖层级分组(同层无依赖)
- 同层内使用 `asyncio.gather` 并行执行
- 并行 stage 的结果合并后传给下一层
2. Handoff 机制:
- `HandoffManager` 管理转交请求
- Agent 调用 `self.handoff(target_agent, context, reason)` 发起转交
- 通过 Redis Pub/Sub `agent:{target}:handoff` 频道通知目标 Agent
- 目标 Agent 接收后创建新任务,携带源 Agent 的上下文
- 源 Agent 的任务状态变为 `HANDOFF`
3. 动态 Pipeline
- 支持 `sub_pipeline` stage 类型:引用另一个 YAML Pipeline
- 支持 `conditional_pipeline`:根据运行时条件选择子流程
- 支持 `loop_pipeline`:循环执行直到条件满足
4. 事件驱动替代轮询:
- Pipeline Engine 通过 Redis Pub/Sub 订阅 `agent:{name}:result` 频道
- 任务完成后自动触发下一 stage无需轮询
**Patterns to follow:** 现有 `PipelineEngine` 的 DAG 拓扑排序和变量解析
**Test scenarios:**
- 3 个无依赖 stage 并行执行,总耗时约等于最长单个 stage
- HandoffAgent A 转交任务给 Agent BB 接收并执行
- 动态 Pipeline根据条件选择不同的子流程
- 子 Pipeline 嵌套执行并正确传递变量
- 事件驱动:任务完成后自动触发下一 stage无轮询间隔
- 循环 Pipeline条件满足后退出循环
**Verification:** 内容生产 Pipeline 的 deai_processing 和 geo_optimization 并行执行,总耗时减少约 40%
---
### U8. GEO 业务系统适配与迁移
**Goal:** 将 GEO 项目的 8 个 Agent 迁移到新框架,业务 Service 注册为 Tool实现完全解耦。
**Dependencies:** U1, U2, U3, U4, U5, U6, U7
**Files:**
- `geo/backend/app/agent_framework/agents/configs/citation_detector.yaml`
- `geo/backend/app/agent_framework/agents/configs/content_generator.yaml`
- `geo/backend/app/agent_framework/agents/configs/deai_agent.yaml`
- `geo/backend/app/agent_framework/agents/configs/geo_optimizer.yaml`
- `geo/backend/app/agent_framework/agents/configs/monitor.yaml`
- `geo/backend/app/agent_framework/agents/configs/schema_advisor.yaml`
- `geo/backend/app/agent_framework/agents/configs/competitor_analyzer.yaml`
- `geo/backend/app/agent_framework/agents/configs/trend_agent.yaml`
- `geo/backend/app/agent_framework/agents/custom_handlers/citation_handler.py`
- `geo/backend/app/agent_framework/agents/custom_handlers/monitor_handler.py`
- `geo/backend/app/agent_framework/tools/__init__.py`
- `geo/backend/app/agent_framework/tools/citation_tools.py`
- `geo/backend/app/agent_framework/tools/content_tools.py`
- `geo/backend/app/agent_framework/tools/knowledge_tools.py`
- `geo/backend/app/agent_framework/tools/monitor_tools.py`
- `geo/backend/app/agent_framework/tools/schema_tools.py`
- `geo/backend/app/agent_framework/tools/competitor_tools.py`
- `geo/backend/app/agent_framework/tools/trend_tools.py`
- `geo/backend/app/models/agent.py`(新增表)
- `geo/backend/requirements.txt`(添加 fischer-agentkit 依赖)
- `geo/backend/app/agent_framework/__init__.py`(适配层)
**Approach:**
1. 业务 Tool 注册:
- `CitationTools``execute_single_platform`, `get_or_create_task`, `calculate_next_query_at`
- `ContentTools``retrieve_knowledge`(包装 RAGService
- `MonitorTools``check_and_compare`, `generate_change_report`, `create_monitoring_record`
- `SchemaTools``identify_missing_dimensions`, `match_templates`, `validate_json_ld`
- `CompetitorTools``analyze_competitor`
- `TrendTools``analyze_trends`, `get_hotspots`
2. Agent 配置迁移:
- LLM 驱动型 AgentContentGenerator, DeAI, GEOOptimizer, SchemaAdvisor→ YAML 配置 + `llm_generate` 模式,零代码
- Service 代理型 AgentCitationDetector, Monitor→ YAML 配置 + `custom` handler仅保留业务逻辑方法
- CompetitorAnalyzer, TrendAgent → YAML 配置 + `tool_call` 模式
3. 数据库迁移:
- 新增 `episodic_memories`
- 新增 `evolution_logs`
- 新增 `ab_test_configs`
4. 适配层:
- `geo/backend/app/agent_framework/` 保留为适配层import from `agentkit`
- 现有 API 端点(`/agents/`, `/agents/tasks/`)保持不变
- 现有 Pipeline YAML 保持兼容
**Patterns to follow:** 现有 Agent 的业务逻辑实现,迁移时保持行为一致
**Test scenarios:**
- 8 个 Agent 迁移后行为与原版一致(回归测试)
- 新增 Agent 只需 YAML 配置,无需写 Python 代码
- 业务 Tool 注册后 Agent 可通过 ToolRegistry 发现和调用
- EpisodicMemory 自动记录任务经验
- EvolutionStore 记录进化变更
- 现有 API 端点功能不变
- 现有 Pipeline YAML 正常执行
**Verification:** 运行现有 Agent 集成测试全部通过,新增一个 YAML-only Agent 成功执行任务
---
## Scope Boundaries
### In Scope
- `fischer-agentkit` 独立包的完整实现Core, Tools, Memory, Evolution, Orchestrator, MCP
- GEO 项目 8 个 Agent 的迁移和适配
- MCP Server/Client 双向支持
- Level 3 自我进化(反思 + Prompt 优化 + 策略调优 + A/B 测试)
- Pipeline 并行执行、Handoff、动态 Pipeline
- 三层记忆系统
- 数据库迁移(新增表)
### Deferred for Later
- A2A Protocol 支持(跨组织 Agent 协作,待 MCP 稳定后再评估)
- Agent 可视化编排 UI前端拖拽式 Pipeline 编辑器)
- Agent 市场/共享机制(跨组织共享 Agent 配置和 Tool
- 多租户隔离增强Agent 级别的资源隔离和配额)
- Agent 安全沙箱Tool 执行的权限控制和审计)
### Outside This Project's Identity
- 通用 LLM Gateway不属于 Agent 框架范畴)
- 替代现有 LLMFactory保持现有 LLM 调用体系不变)
- 前端 Agent 管理界面重构(本次聚焦后端架构)
---
## Risks & Dependencies
| Risk | Impact | Mitigation |
|------|--------|------------|
| MCP 协议规范变动2025-2026 仍在演进) | MCPTool/MCPClient 需要跟进修改 | 抽象 Transport 层,协议变更只影响 Transport 实现 |
| DSPy 风格 Prompt 优化可能需要大量训练数据 | 优化效果不明显 | BootstrapFewShot 从少量数据开始MIPROv2 需 50+ 案例时才启用 |
| 独立包与 GEO 项目的版本同步 | GEO 升级 agentkit 版本可能引入 breaking change | 语义化版本管理GEO 项目 pin 版本CI 自动测试兼容性 |
| Episodic Memory 数据量增长 | pgvector 查询性能下降 | 设置 TTL 自动清理、定期归档、HNSW 索引优化 |
| 迁移期间 8 个 Agent 行为不一致 | 业务回归 | 逐个迁移 + 回归测试,保留旧代码直到新代码验证通过 |
| Pipeline 并行执行引入并发问题 | 结果不一致 | 同层 stage 只读共享变量,写入隔离 |
---
## System-Wide Impact
- **数据库**:新增 3 张表episodic_memories, evolution_logs, ab_test_configs需 Alembic 迁移
- **Redis**:新增 Working Memory key 空间和 Handoff 频道,内存使用增加约 20%
- **API**:现有 `/agents/` 端点保持兼容,新增 `/agents/{name}/evolution` 查看进化历史
- **部署**`fischer-agentkit` 需发布到私有 PyPIGEO 项目 Dockerfile 需添加安装步骤
- **监控**:新增进化相关 Prometheus 指标evolution_attempts_total, evolution_success_rate, ab_test_decisions
- **依赖**:新增 `mcp` Python 包依赖MCP SDK
---
## Sources & Research
- 现有 Agent 框架代码:`backend/app/agent_framework/` 全部文件
- 现有 RAG 服务:`backend/app/services/knowledge/rag_service.py`, `retriever.py`, `enhanced_rag.py`
- 现有 LLM 工厂:`backend/app/services/llm/factory.py`, `base.py`
- MCP 官方规范Model Context Protocol (Anthropic, 2024-2025)
- DSPy 框架Stanford NLP, Prompt 自动优化范式
- Google A2A ProtocolAgent-to-Agent HTTP 协议 (2025.4)
- LangGraph 0.2.xStateGraph + Checkpoint 架构
- OpenAI Agents SDK 1.0Handoff 模式
- Google ADK 0.1Agent 组合模式Sequential/Parallel/Loop

File diff suppressed because one or more lines are too long

View File

@ -1,35 +1,4 @@
{ {
"status": "failed", "status": "passed",
"failedTests": [ "failedTests": []
"49bb04a247fb3a93c942-35c9c801dc12bf7d35f2",
"49bb04a247fb3a93c942-698fa46448ec3c4fcc54",
"49bb04a247fb3a93c942-1e4abb88d57e30e27b50",
"49bb04a247fb3a93c942-e8f0348aed6f07ff875b",
"49bb04a247fb3a93c942-425d1f1370bc2355e391",
"49bb04a247fb3a93c942-256ee0a55d33f5229d15",
"49bb04a247fb3a93c942-5f9ad201eca7dc8ab295",
"49bb04a247fb3a93c942-ef50deff2e29416bfdcb",
"49bb04a247fb3a93c942-ea929ccac5306e953e14",
"49bb04a247fb3a93c942-fa6e92367349adbf7ce6",
"49bb04a247fb3a93c942-66abb2794b5dc22a4d2c",
"49bb04a247fb3a93c942-c8f2635d09bd86abde8c",
"49bb04a247fb3a93c942-495ca225b90357a78c73",
"49bb04a247fb3a93c942-9542aeae6bc1bec1de36",
"49bb04a247fb3a93c942-6d102a4fe6c78e5adea3",
"49bb04a247fb3a93c942-16202bb8288239216298",
"49bb04a247fb3a93c942-103e4c5c525b115390ab",
"49bb04a247fb3a93c942-8ff6f91ccaf60e9cf222",
"49bb04a247fb3a93c942-594d166a39b1685af15e",
"c05d59a3df11e753f1a9-ebd92d58fe79ae1246f9",
"30f4937bf66a39abf0e1-654253a85d12a575b58b",
"30f4937bf66a39abf0e1-531589fe51a02f263015",
"655e7220bbb90e50e0e9-a2f8fcdf1070b299cd54",
"655e7220bbb90e50e0e9-49b697c7e1ba346c76b7",
"655e7220bbb90e50e0e9-c5cd58b8cddad9907a64",
"655e7220bbb90e50e0e9-696b9830da3889bddb0f",
"655e7220bbb90e50e0e9-e8bb52f185bd86dce2da",
"fe534f0825407f213faa-71fec5acbf34a381ec37",
"fe534f0825407f213faa-c40f01c5980780ee6589",
"fe534f0825407f213faa-a6a9b88ab9323419b7d3"
]
} }