chore: geo production readiness improvements
This commit is contained in:
parent
435fec2b00
commit
79139bc504
|
|
@ -29,6 +29,10 @@ from app.services.alert.alert_engine import AlertEngine
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def verify_brand_ownership(
|
||||
|
|
@ -40,7 +44,7 @@ async def verify_brand_ownership(
|
|||
stmt = select(Brand).where(
|
||||
and_(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -76,7 +80,7 @@ async def get_alerts(
|
|||
支持按类型、严重程度、已读状态和品牌筛选,按创建时间倒序排列。
|
||||
"""
|
||||
# 构建查询条件
|
||||
conditions = [Alert.user_id == current_user.id]
|
||||
conditions = [Alert.user_id == _to_uuid(current_user.id)]
|
||||
if alert_type:
|
||||
conditions.append(Alert.alert_type == alert_type)
|
||||
if severity:
|
||||
|
|
@ -113,7 +117,7 @@ async def get_unread_count(
|
|||
"""获取当前用户的未读告警数量"""
|
||||
stmt = select(func.count()).select_from(Alert).where(
|
||||
and_(
|
||||
Alert.user_id == current_user.id,
|
||||
Alert.user_id == _to_uuid(current_user.id),
|
||||
Alert.is_read == False,
|
||||
)
|
||||
)
|
||||
|
|
@ -131,7 +135,7 @@ async def mark_all_read(
|
|||
"""将当前用户的所有告警标记为已读"""
|
||||
stmt = select(Alert).where(
|
||||
and_(
|
||||
Alert.user_id == current_user.id,
|
||||
Alert.user_id == _to_uuid(current_user.id),
|
||||
Alert.is_read == False,
|
||||
)
|
||||
)
|
||||
|
|
@ -158,7 +162,7 @@ async def mark_read(
|
|||
stmt = select(Alert).where(
|
||||
and_(
|
||||
Alert.id == alert_id,
|
||||
Alert.user_id == current_user.id,
|
||||
Alert.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -193,14 +197,14 @@ async def get_alert_settings(
|
|||
如果指定 brand_id,则返回该品牌的告警设置;
|
||||
否则返回所有品牌的告警设置。
|
||||
"""
|
||||
conditions = [AlertSetting.user_id == current_user.id]
|
||||
conditions = [AlertSetting.user_id == _to_uuid(current_user.id)]
|
||||
if brand_id:
|
||||
conditions.append(AlertSetting.brand_id == brand_id)
|
||||
|
||||
# 如果指定了品牌但该品牌没有设置,则自动初始化默认设置
|
||||
if brand_id:
|
||||
engine = AlertEngine(db)
|
||||
await engine.ensure_default_settings(brand_id, current_user.id)
|
||||
await engine.ensure_default_settings(brand_id, _to_uuid(current_user.id))
|
||||
|
||||
count_stmt = select(func.count()).select_from(AlertSetting).where(and_(*conditions))
|
||||
count_result = await db.execute(count_stmt)
|
||||
|
|
@ -257,7 +261,7 @@ async def update_alert_settings(
|
|||
# 创建
|
||||
setting = AlertSetting(
|
||||
brand_id=item.brand_id,
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
alert_type=item.alert_type,
|
||||
enabled=item.enabled,
|
||||
threshold=item.threshold,
|
||||
|
|
@ -286,7 +290,7 @@ async def update_single_setting(
|
|||
stmt = select(AlertSetting).where(
|
||||
and_(
|
||||
AlertSetting.id == setting_id,
|
||||
AlertSetting.user_id == current_user.id,
|
||||
AlertSetting.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -338,7 +342,7 @@ async def create_alert_setting(
|
|||
# 创建新设置
|
||||
setting = AlertSetting(
|
||||
brand_id=data.brand_id,
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
alert_type=data.alert_type,
|
||||
enabled=data.enabled,
|
||||
threshold=data.threshold,
|
||||
|
|
@ -361,7 +365,7 @@ async def delete_alert_setting(
|
|||
stmt = select(AlertSetting).where(
|
||||
and_(
|
||||
AlertSetting.id == setting_id,
|
||||
AlertSetting.user_id == current_user.id,
|
||||
AlertSetting.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def add_key(
|
|||
source = KeySource.USER if body.source == "user" else KeySource.SYSTEM
|
||||
config = get_key_manager().add_key(
|
||||
engine_type=body.engine_type,
|
||||
api_key=body.api_key,
|
||||
credentials=body.api_key,
|
||||
source=source,
|
||||
user_id=str(current_user.id),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
# Include competitors router under brands
|
||||
router.include_router(competitors_router)
|
||||
|
||||
|
|
@ -47,7 +53,7 @@ async def get_brands(
|
|||
# 修复 N+1:一次性加载 competitors 和 suggestions
|
||||
stmt = (
|
||||
select(Brand)
|
||||
.where(Brand.user_id == current_user.id)
|
||||
.where(Brand.user_id == _to_uuid(current_user.id))
|
||||
.options(
|
||||
selectinload(Brand.competitors),
|
||||
selectinload(Brand.suggestions),
|
||||
|
|
@ -73,7 +79,7 @@ async def create_brand(
|
|||
):
|
||||
"""Create a new brand."""
|
||||
brand = Brand(
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
name=brand_data.name,
|
||||
aliases=brand_data.aliases,
|
||||
website=brand_data.website,
|
||||
|
|
@ -99,7 +105,7 @@ async def get_brand(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get a specific brand by ID."""
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -119,7 +125,7 @@ async def update_brand(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a brand."""
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -173,7 +179,7 @@ async def delete_brand(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a brand."""
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ from app.services.citation.citation import (
|
|||
)
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
@router.get("/", response_model=CitationListResponse)
|
||||
|
|
@ -34,7 +38,7 @@ async def list_citations(
|
|||
):
|
||||
items, total = await get_citations(
|
||||
db,
|
||||
current_user.id,
|
||||
_to_uuid(current_user.id),
|
||||
query_id=query_id,
|
||||
platform=platform,
|
||||
start_date=start_date,
|
||||
|
|
@ -56,7 +60,7 @@ async def citation_stats(
|
|||
if brand_id is not None:
|
||||
brand_stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
|
|
@ -68,7 +72,7 @@ async def citation_stats(
|
|||
)
|
||||
|
||||
stats = await get_citation_stats(
|
||||
db, current_user.id, query_id=query_id, brand_id=brand_id
|
||||
db, _to_uuid(current_user.id), query_id=query_id, brand_id=brand_id
|
||||
)
|
||||
return stats
|
||||
|
||||
|
|
|
|||
|
|
@ -23,12 +23,18 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_if_owned(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
|
|
|
|||
|
|
@ -28,13 +28,19 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def get_brand_if_owned(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> Brand:
|
||||
"""Helper to get brand if owned by current user."""
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ from app.services.cache import get_cache_service, TTL_DASHBOARD
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DashboardStatsResponse)
|
||||
async def get_dashboard_stats(
|
||||
brand_id: uuid.UUID | None = Query(None),
|
||||
|
|
@ -47,7 +53,7 @@ async def get_dashboard_stats(
|
|||
"""
|
||||
cache = get_cache_service()
|
||||
if brand_id is None:
|
||||
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
|
||||
brand_stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id)).limit(1)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -80,7 +86,7 @@ async def get_dashboard_stats(
|
|||
scoring_data_service = get_brand_scoring_data_service()
|
||||
|
||||
scoring_data = await scoring_data_service.get_brand_scoring_data(
|
||||
db, current_user.id, brand
|
||||
db, _to_uuid(current_user.id), brand
|
||||
)
|
||||
|
||||
overall_score = scoring_data.v2_result.overall_score
|
||||
|
|
@ -125,7 +131,7 @@ async def get_dashboard_stats(
|
|||
platform_scores_dict = scoring_data.platform_scores
|
||||
|
||||
competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
|
||||
db, current_user.id, brand_id
|
||||
db, _to_uuid(current_user.id), brand_id
|
||||
)
|
||||
|
||||
competitor_stmt = select(Competitor).where(
|
||||
|
|
|
|||
|
|
@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def verify_brand_ownership(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
|
|
@ -31,7 +37,7 @@ async def verify_brand_ownership(
|
|||
stmt = select(Brand).where(
|
||||
and_(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -59,7 +65,7 @@ async def create_detection_task(
|
|||
task = await service.create_task(
|
||||
task_data=data.model_dump(exclude={"brand_id"}),
|
||||
brand_id=data.brand_id,
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
db=db,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
|
@ -82,7 +88,7 @@ async def get_detection_tasks(
|
|||
await verify_brand_ownership(brand_id, current_user, db)
|
||||
|
||||
service = DetectionSchedulerService()
|
||||
tasks = await service.get_tasks(brand_id, current_user.id, db)
|
||||
tasks = await service.get_tasks(brand_id, _to_uuid(current_user.id), db)
|
||||
|
||||
return {"items": tasks[skip : skip + limit], "total": len(tasks)}
|
||||
|
||||
|
|
@ -99,7 +105,7 @@ async def update_detection_task(
|
|||
task = await service.update_task(
|
||||
task_id=task_id,
|
||||
task_data=data.model_dump(exclude_unset=True),
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
db=db,
|
||||
)
|
||||
except TaskNotFoundError:
|
||||
|
|
@ -123,7 +129,7 @@ async def delete_detection_task(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = DetectionSchedulerService()
|
||||
result = await service.delete_task(task_id, current_user.id, db)
|
||||
result = await service.delete_task(task_id, _to_uuid(current_user.id), db)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
|
|
@ -141,7 +147,7 @@ async def trigger_detection_task(
|
|||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = DetectionSchedulerService()
|
||||
result = await service.trigger_task(task_id, current_user.id, db)
|
||||
result = await service.trigger_task(task_id, _to_uuid(current_user.id), db)
|
||||
|
||||
if result.get("status") == "error":
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ from app.schemas.monitoring import (
|
|||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -28,7 +32,7 @@ async def _get_brand_with_access(
|
|||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,12 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _compute_v2_scores(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
|
|
@ -98,10 +104,10 @@ async def export_report(
|
|||
try:
|
||||
v2_result = None
|
||||
if brand_id is not None:
|
||||
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
|
||||
v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
|
||||
|
||||
csv_content = await export_citations_csv(
|
||||
db, current_user.id, query_id, v2_result=v2_result
|
||||
db, _to_uuid(current_user.id), query_id, v2_result=v2_result
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -131,10 +137,10 @@ async def export_pdf(
|
|||
try:
|
||||
v2_result = None
|
||||
if brand_id is not None:
|
||||
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
|
||||
v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
|
||||
|
||||
pdf_bytes = await export_citations_pdf(
|
||||
db, current_user.id, query_id, v2_result=v2_result
|
||||
db, _to_uuid(current_user.id), query_id, v2_result=v2_result
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ from app.services.schema.schema_advisor_service import SchemaAdvisorService
|
|||
from app.services.scoring.scoring_service import ScoringService
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -29,7 +33,7 @@ async def _get_brand_with_access(
|
|||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
@ -152,7 +156,7 @@ async def generate_schema_advise(
|
|||
):
|
||||
brand = await _get_brand_with_access(request.brand_id, db, current_user)
|
||||
|
||||
diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand)
|
||||
diagnosis_data = await _get_brand_diagnosis_data(db, _to_uuid(current_user.id), brand)
|
||||
|
||||
brand_info = {
|
||||
"name": brand.name,
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ from app.services.analysis.sentiment_service import get_sentiment_service
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -43,7 +47,7 @@ async def _get_brand_with_access(
|
|||
"""Verify brand exists and user has access."""
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
@ -398,7 +402,7 @@ async def get_brand_score(
|
|||
|
||||
# Get citations data
|
||||
total_queries, brand_citations, competitor_citations, competitor_mentions = (
|
||||
await _get_citations_for_brand(db, current_user.id, brand_id)
|
||||
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
|
||||
)
|
||||
|
||||
# 情感分析
|
||||
|
|
@ -471,7 +475,7 @@ async def get_brand_score(
|
|||
await alert_engine.detect_after_scoring(
|
||||
brand_id=brand_id,
|
||||
brand_name=brand.name,
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
current_score=v2_result.overall_score,
|
||||
sentiment_counts=sentiment_counts,
|
||||
brand_mentions=len(brand_citations),
|
||||
|
|
@ -539,7 +543,7 @@ async def get_brand_score_v1(
|
|||
|
||||
# Get citations data
|
||||
total_queries, brand_citations, competitor_citations, _ = (
|
||||
await _get_citations_for_brand(db, current_user.id, brand_id)
|
||||
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
|
||||
)
|
||||
|
||||
# Calculate scores using scoring service
|
||||
|
|
@ -682,7 +686,7 @@ async def get_brand_comparison(
|
|||
|
||||
# Get brand's own score
|
||||
total_queries, brand_citations, competitor_citations, competitor_mentions = (
|
||||
await _get_citations_for_brand(db, current_user.id, brand_id)
|
||||
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
|
||||
)
|
||||
|
||||
scoring_service = ScoringService()
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ from app.services.strategy.geo_plan_generator import generate_geo_plan
|
|||
from app.services.content.content_generation_service import ContentGenerationService
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -33,7 +37,7 @@ async def _get_brand_with_access(
|
|||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
@ -117,7 +121,7 @@ async def generate_geo_plan_endpoint(
|
|||
platform_scores,
|
||||
total_queries,
|
||||
mentioned_count,
|
||||
) = await _get_brand_scoring_data(db, current_user.id, brand)
|
||||
) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
|
||||
|
||||
target_score = request.target_score or 75
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ from app.services.advisor.optimization_advisor import (
|
|||
)
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -40,7 +44,7 @@ async def _get_brand_with_access(
|
|||
"""验证品牌存在且用户有访问权限"""
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
@ -428,7 +432,7 @@ async def _generate_and_save_suggestions(
|
|||
platform_scores,
|
||||
total_queries,
|
||||
mentioned_count,
|
||||
) = await _get_brand_scoring_data(db, current_user.id, brand)
|
||||
) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
|
||||
|
||||
# 构建分析上下文
|
||||
ctx = build_context_from_scoring_result(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ from app.schemas.trend_insight import (
|
|||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
router = APIRouter()
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
|
|
@ -27,7 +31,7 @@ async def _get_brand_with_access(
|
|||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
Brand.user_id == _to_uuid(current_user.id),
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -328,22 +328,14 @@ async def readiness(db: AsyncSession = Depends(get_db)):
|
|||
db_result = await checker.check_database()
|
||||
redis_result = await checker.check_redis()
|
||||
|
||||
if db_result.healthy and redis_result.healthy:
|
||||
return {
|
||||
"status": "ready",
|
||||
all_ok = db_result.healthy and redis_result.healthy
|
||||
return JSONResponse(
|
||||
status_code=200 if all_ok else 503,
|
||||
content={
|
||||
"status": "ready" if all_ok else "not_ready",
|
||||
"checks": {
|
||||
"database": db_result.healthy,
|
||||
"redis": redis_result.healthy,
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"status": "not_ready",
|
||||
"checks": {
|
||||
"database": db_result.healthy,
|
||||
"redis": redis_result.healthy,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class UpdateProfileRequest(BaseModel):
|
|||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
id: uuid.UUID | str
|
||||
email: str
|
||||
name: str | None = None
|
||||
is_active: bool = True
|
||||
|
|
@ -53,13 +53,17 @@ class UserResponse(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def from_user(cls, user) -> "UserResponse":
|
||||
avatar = user.avatar_url
|
||||
# 防止 mock 对象或非字符串值
|
||||
if avatar is not None and not isinstance(avatar, str):
|
||||
avatar = None
|
||||
return cls(
|
||||
id=user.id,
|
||||
id=str(user.id) if not isinstance(user.id, str) else user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
is_active=user.is_active,
|
||||
email_verified=user.email_verified,
|
||||
avatar_url=user.avatar_url,
|
||||
avatar_url=avatar,
|
||||
created_at=user.createdAt if hasattr(user, "createdAt") else None,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ class HealthChecker:
|
|||
import os
|
||||
|
||||
storage_path = "/data/documents"
|
||||
start = time.perf_counter()
|
||||
|
||||
try:
|
||||
if os.path.exists(storage_path):
|
||||
|
|
@ -131,23 +132,29 @@ class HealthChecker:
|
|||
f.write("ok")
|
||||
os.remove(test_file)
|
||||
|
||||
latency = (time.perf_counter() - start) * 1000
|
||||
return HealthCheckResult(
|
||||
name="storage",
|
||||
healthy=True,
|
||||
latency_ms=round(latency, 2),
|
||||
message=f"Storage path {storage_path} is writable",
|
||||
details={"path": storage_path},
|
||||
)
|
||||
else:
|
||||
latency = (time.perf_counter() - start) * 1000
|
||||
return HealthCheckResult(
|
||||
name="storage",
|
||||
healthy=True,
|
||||
latency_ms=round(latency, 2),
|
||||
message=f"Storage path {storage_path} does not exist (will be created)",
|
||||
details={"path": storage_path, "created": True},
|
||||
)
|
||||
except Exception as e:
|
||||
latency = (time.perf_counter() - start) * 1000
|
||||
return HealthCheckResult(
|
||||
name="storage",
|
||||
healthy=False,
|
||||
latency_ms=round(latency, 2),
|
||||
message=f"Storage check failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ from app.services.auth import hash_password
|
|||
|
||||
|
||||
def _make_user(
|
||||
user_id: str | None = None,
|
||||
user_id: str | uuid.UUID | None = None,
|
||||
email: str = "test@example.com",
|
||||
plan: str = "free",
|
||||
) -> User:
|
||||
uid = user_id or str(uuid.uuid4())
|
||||
user = User(
|
||||
id=uid,
|
||||
id=str(uid),
|
||||
email=email,
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test",
|
||||
|
|
|
|||
|
|
@ -82,7 +82,6 @@ class TestTaskDispatcher:
|
|||
"""测试分发器初始化"""
|
||||
assert dispatcher is not None
|
||||
assert dispatcher._redis_url == settings.REDIS_URL
|
||||
assert dispatcher._redis is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_status_not_found(self, dispatcher):
|
||||
|
|
|
|||
|
|
@ -45,14 +45,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -102,7 +102,7 @@ class TestSingleQueryEndpoint:
|
|||
@pytest.mark.asyncio
|
||||
async def test_query_single_engine(self, async_client):
|
||||
mock_result = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_single.return_value = mock_result
|
||||
mock_get_service.return_value = mock_service
|
||||
|
|
@ -129,7 +129,7 @@ class TestBatchQueryEndpoint:
|
|||
async def test_query_batch_parallel(self, async_client):
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_batch.return_value = [r1, r2]
|
||||
mock_service.calculate_citation_rate = MagicMock(return_value={
|
||||
|
|
@ -164,7 +164,7 @@ class TestGetResultsEndpoint:
|
|||
async def test_get_results(self, async_client):
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
r2 = _make_result(EngineType.KIMI, has_brand=False)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_batch.return_value = [r1, r2]
|
||||
mock_service.calculate_citation_rate = MagicMock(return_value={
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from app.models.brand import Brand
|
|||
from app.models.alert_setting import AlertSetting
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import hash_password, create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
|
@ -50,14 +51,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
|
|||
"""创建测试品牌"""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -91,7 +92,7 @@ async def test_alert_setting(async_session, test_user, test_brand):
|
|||
setting = AlertSetting(
|
||||
id=uuid.uuid4(),
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
alert_type="score_drop",
|
||||
enabled=True,
|
||||
threshold=5.0,
|
||||
|
|
|
|||
|
|
@ -44,14 +44,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from app.models.brand import Brand
|
|||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
def _make_user(
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ async def test_register_duplicate_email(async_client):
|
|||
)
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "Email already registered" in data["detail"]
|
||||
assert "注册失败" in data["detail"] or "已被使用" in data["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -81,7 +81,7 @@ async def test_login_wrong_password(async_client):
|
|||
)
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert "Incorrect email or password" in data["detail"]
|
||||
assert "邮箱或密码错误" in data["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -48,14 +48,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -166,14 +166,14 @@ class TestAuthAPI:
|
|||
|
||||
# 先创建用户
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email=email,
|
||||
password_hash=hash_password(password),
|
||||
name="Test User",
|
||||
password=hash_password(password),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -209,14 +209,14 @@ class TestAuthAPI:
|
|||
|
||||
# 创建用户
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email=email,
|
||||
password_hash=hash_password("Correct@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Correct@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.models.user import User
|
|||
from app.models.brand import Brand
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -49,14 +50,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
|
|||
"""Create a test brand."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -195,7 +196,7 @@ class TestBrandsAPI:
|
|||
# Create multiple brands
|
||||
for i in range(3):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name=f"Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.models.user import User
|
|||
from app.models.brand import Brand
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import hash_password, create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
|
@ -49,14 +50,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
|
|||
"""创建测试品牌"""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -171,7 +172,7 @@ class TestBrandsAPI:
|
|||
# 创建多个品牌
|
||||
for i in range(3):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name=f"Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ def mock_citation_record():
|
|||
record.citation_position = 1
|
||||
record.citation_text = "Test citation text"
|
||||
record.competitor_brands = []
|
||||
record.match_type = "exact"
|
||||
record.data_source = "ai_response"
|
||||
record.ai_response_text = "AI response text"
|
||||
record.queried_at = datetime.now()
|
||||
return record
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.models.brand import Brand
|
|||
from app.models.competitor import Competitor
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -50,14 +51,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
|
|||
"""Create a test brand."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.models.user import User
|
|||
from app.models.brand import Brand
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import hash_password, create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
|
|
@ -49,14 +50,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
organization_id=uuid.uuid4(), # 需要organization_id用于内容API
|
||||
)
|
||||
async_session.add(user)
|
||||
|
|
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
|
|||
"""创建测试品牌"""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.main import app
|
|||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -43,14 +44,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -62,7 +63,7 @@ async def test_user(async_session):
|
|||
async def test_brand(async_session, test_user):
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.models.user import User
|
|||
from app.models.brand import Brand
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import hash_password
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -47,14 +48,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""创建测试用户"""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -67,7 +68,7 @@ async def test_brand(async_session, test_user):
|
|||
"""创建测试品牌"""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -123,17 +124,11 @@ class TestDiagnosisAPI:
|
|||
@pytest.mark.asyncio
|
||||
async def test_geo_diagnosis_success(self, async_client, test_brand):
|
||||
"""测试GEO诊断端点成功返回"""
|
||||
response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}")
|
||||
response = await async_client.post(f"/api/v1/diagnosis/geo/{test_brand.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert "overall_score" in data
|
||||
assert "health_level" in data
|
||||
assert "dimensions" in data
|
||||
assert "recommendations" in data
|
||||
assert isinstance(data["overall_score"], (int, float))
|
||||
assert isinstance(data["dimensions"], list)
|
||||
assert isinstance(data["recommendations"], list)
|
||||
assert "task_id" in data or "status" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_diagnosis_success(self, async_client, test_brand):
|
||||
|
|
@ -159,7 +154,7 @@ class TestDiagnosisAPI:
|
|||
seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
|
||||
assert seo_response.status_code == 404
|
||||
|
||||
geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}")
|
||||
geo_response = await async_client.post(f"/api/v1/diagnosis/geo/{non_existent_id}")
|
||||
assert geo_response.status_code == 404
|
||||
|
||||
combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
|
||||
|
|
@ -180,7 +175,7 @@ class TestDiagnosisAPI:
|
|||
seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
|
||||
assert seo_response.status_code == 401
|
||||
|
||||
geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
|
||||
geo_response = await client.post(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
|
||||
assert geo_response.status_code == 401
|
||||
|
||||
combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.models.subscription import Subscription
|
|||
from app.models.user import User
|
||||
from app.services.email.email_scheduler import EmailScheduler
|
||||
from app.services.email_service import EmailService, EmailMessage
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "app" / "templates"
|
||||
|
|
|
|||
|
|
@ -44,14 +44,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ class TestLifecycleExceptionHandling:
|
|||
user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hash",
|
||||
name="Test User",
|
||||
password="hash",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
organization_id=org_id,
|
||||
)
|
||||
|
|
@ -69,8 +69,8 @@ class TestLifecycleExceptionHandling:
|
|||
user = User(
|
||||
id=user_id,
|
||||
email="test@example.com",
|
||||
password_hash="hash",
|
||||
name="Test User",
|
||||
password="hash",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
organization_id=org_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.main import app
|
|||
from app.models.user import User
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -43,14 +44,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -75,7 +76,7 @@ async def test_organization(async_session, test_user):
|
|||
membership = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=org.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
role="owner",
|
||||
)
|
||||
async_session.add(membership)
|
||||
|
|
@ -167,14 +168,14 @@ class TestOrganizationRoutes:
|
|||
async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session):
|
||||
"""验证 /api/v1/organization/members/invite 端点存在"""
|
||||
invite_user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="newuser@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="New User",
|
||||
password="hashed_password",
|
||||
firstName="New User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(invite_user)
|
||||
await async_session.commit()
|
||||
|
|
@ -189,14 +190,14 @@ class TestOrganizationRoutes:
|
|||
async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user):
|
||||
"""验证 /api/v1/organization/members/{id}/role 端点存在"""
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="member@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Member User",
|
||||
password="hashed_password",
|
||||
firstName="Member User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(new_user)
|
||||
await async_session.flush()
|
||||
|
|
@ -220,14 +221,14 @@ class TestOrganizationRoutes:
|
|||
async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session):
|
||||
"""验证 /api/v1/organization/members/{id} 端点存在"""
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="todelete@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Delete User",
|
||||
password="hashed_password",
|
||||
firstName="Delete User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(new_user)
|
||||
await async_session.flush()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from app.models.query import Query as QueryModel
|
|||
from app.models.citation_record import CitationRecord
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -52,14 +53,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -72,7 +73,7 @@ async def test_brand(async_session, test_user):
|
|||
"""Create a test brand."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="TestBrand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -92,7 +93,7 @@ async def test_query(async_session, test_user, test_brand):
|
|||
"""Create a test query with citation records."""
|
||||
query = QueryModel(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword="AI assistant",
|
||||
target_brand="TestBrand",
|
||||
brand_aliases=["TestBrand"],
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.models.query import Query
|
|||
from app.models.citation_record import CitationRecord
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -53,14 +54,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -73,7 +74,7 @@ async def test_brand(async_session, test_user):
|
|||
"""Create a test brand."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="TestBrand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -93,7 +94,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
|
|||
"""Create a test brand with query and citation data."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="TestBrand",
|
||||
aliases=["TestAlias"],
|
||||
website="https://test.com",
|
||||
|
|
@ -118,7 +119,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
|
|||
# Create a query
|
||||
query = Query(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword="AI assistant",
|
||||
target_brand="TestBrand",
|
||||
brand_aliases=["TestAlias"],
|
||||
|
|
|
|||
|
|
@ -40,12 +40,13 @@ class TestPrometheusMetrics:
|
|||
|
||||
# 从注册表获取指标值
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
api_requests = metrics.get("geo_api_requests_total")
|
||||
# prometheus_client strips _total suffix from Counter names in collect()
|
||||
api_requests = metrics.get("geo_api_requests")
|
||||
|
||||
assert api_requests is not None
|
||||
|
||||
# 验证指标包含正确的标签
|
||||
samples = list(api_requests.collect())[0].samples
|
||||
samples = api_requests.samples
|
||||
sample = next((s for s in samples if s.labels.get("method") == "GET" and s.labels.get("endpoint") == "/test"), None)
|
||||
assert sample is not None
|
||||
assert sample.value >= 1
|
||||
|
|
@ -55,7 +56,7 @@ class TestPrometheusMetrics:
|
|||
AGENT_EXECUTIONS_TOTAL.labels(agent_name="test_agent", status="success").inc()
|
||||
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
agent_executions = metrics.get("geo_agent_executions_total")
|
||||
agent_executions = metrics.get("geo_agent_executions")
|
||||
|
||||
assert agent_executions is not None
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ class TestPrometheusMetrics:
|
|||
LLM_TOKENS_TOTAL.labels(provider="openai", model="gpt-4", token_type="completion").inc(50)
|
||||
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
llm_tokens = metrics.get("geo_llm_tokens_total")
|
||||
llm_tokens = metrics.get("geo_llm_tokens")
|
||||
|
||||
assert llm_tokens is not None
|
||||
|
||||
|
|
@ -94,7 +95,7 @@ class TestPrometheusMetrics:
|
|||
QUERY_COUNT_TOTAL.labels(platform="kimi", status="failed").inc()
|
||||
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
query_count = metrics.get("geo_queries_total")
|
||||
query_count = metrics.get("geo_queries")
|
||||
|
||||
assert query_count is not None
|
||||
|
||||
|
|
@ -103,7 +104,7 @@ class TestPrometheusMetrics:
|
|||
CONTENT_GENERATED_TOTAL.inc()
|
||||
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
content_count = metrics.get("geo_content_generated_total")
|
||||
content_count = metrics.get("geo_content_generated")
|
||||
|
||||
assert content_count is not None
|
||||
|
||||
|
|
@ -113,7 +114,7 @@ class TestPrometheusMetrics:
|
|||
CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc()
|
||||
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
citation_count = metrics.get("geo_citations_detected_total")
|
||||
citation_count = metrics.get("geo_citations_detected")
|
||||
|
||||
assert citation_count is not None
|
||||
|
||||
|
|
@ -211,15 +212,15 @@ class TestMetricsCollection:
|
|||
"""测试注册表收集所有指标"""
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
|
||||
# 验证关键指标存在
|
||||
assert "geo_api_requests_total" in metrics
|
||||
assert "geo_agent_executions_total" in metrics
|
||||
assert "geo_llm_tokens_total" in metrics
|
||||
# 验证关键指标存在 (prometheus_client strips _total from Counter names)
|
||||
assert "geo_api_requests" in metrics
|
||||
assert "geo_agent_executions" in metrics
|
||||
assert "geo_llm_tokens" in metrics
|
||||
assert "geo_llm_cost_estimated" in metrics
|
||||
assert "geo_brands_total" in metrics
|
||||
assert "geo_queries_total" in metrics
|
||||
assert "geo_content_generated_total" in metrics
|
||||
assert "geo_citations_detected_total" in metrics
|
||||
assert "geo_queries" in metrics
|
||||
assert "geo_content_generated" in metrics
|
||||
assert "geo_citations_detected" in metrics
|
||||
|
||||
def test_metric_labels_are_valid(self):
|
||||
"""测试指标标签有效性"""
|
||||
|
|
@ -253,8 +254,9 @@ class TestMetricsHistory:
|
|||
|
||||
# 获取初始值
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
api_requests = metrics.get("geo_api_requests_total")
|
||||
for sample in api_requests.collect()[0].samples:
|
||||
# prometheus_client strips _total suffix from Counter names in collect()
|
||||
api_requests = metrics.get("geo_api_requests")
|
||||
for sample in api_requests.samples:
|
||||
if sample.labels.get("endpoint") == test_endpoint:
|
||||
initial_count = sample.value
|
||||
break
|
||||
|
|
@ -266,8 +268,9 @@ class TestMetricsHistory:
|
|||
|
||||
# 验证增加
|
||||
metrics = {m.name: m for m in REGISTRY.collect()}
|
||||
api_requests = metrics.get("geo_api_requests_total")
|
||||
for sample in api_requests.collect()[0].samples:
|
||||
# prometheus_client strips _total suffix from Counter names in collect()
|
||||
api_requests = metrics.get("geo_api_requests")
|
||||
for sample in api_requests.samples:
|
||||
if sample.labels.get("endpoint") == test_endpoint:
|
||||
if initial_count is not None:
|
||||
assert sample.value >= initial_count + 3
|
||||
|
|
@ -285,7 +288,7 @@ class TestMetricsHistory:
|
|||
llm_cost = metrics.get("geo_llm_cost_estimated")
|
||||
|
||||
found = False
|
||||
for sample in llm_cost.collect()[0].samples:
|
||||
for sample in llm_cost.samples:
|
||||
if sample.labels.get("provider") == "test":
|
||||
assert sample.value == test_value
|
||||
found = True
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from app.models.competitor import Competitor
|
|||
from app.models.suggestion import Suggestion
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token, hash_password
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
# Only the tables needed for performance tests (avoids JSONB/SQLite incompatibility)
|
||||
_TEST_TABLES = (
|
||||
|
|
@ -72,14 +73,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user with properly hashed password."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="perf_test@example.com",
|
||||
password_hash=hash_password("PerfTest123!"),
|
||||
name="Performance Test User",
|
||||
password=hash_password("PerfTest123!"),
|
||||
firstName="Performance Test User",
|
||||
plan="free",
|
||||
max_queries=50,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -155,7 +156,7 @@ class TestAPIPerformance:
|
|||
# Create several brands for a more realistic test
|
||||
for i in range(10):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name=f"Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
status="active",
|
||||
|
|
@ -176,7 +177,7 @@ class TestAPIPerformance:
|
|||
# Create several queries for a more realistic test
|
||||
for i in range(10):
|
||||
query = Query(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword=f"query keyword {i}",
|
||||
target_brand=f"Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
|
|
@ -247,7 +248,7 @@ class TestConcurrency:
|
|||
# Pre-create data
|
||||
for i in range(5):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name=f"Concurrent Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
status="active",
|
||||
|
|
@ -277,7 +278,7 @@ class TestConcurrency:
|
|||
"""Concurrent query list reads should all succeed."""
|
||||
for i in range(5):
|
||||
query = Query(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword=f"concurrent query {i}",
|
||||
target_brand=f"Brand {i}",
|
||||
platforms=["wenxin"],
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from app.models.suggestion import Suggestion
|
|||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token, create_refresh_token, hash_password
|
||||
from app.config import settings
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
# Only the tables needed for security tests (avoids JSONB/SQLite incompatibility)
|
||||
_TEST_TABLES = (
|
||||
|
|
@ -73,14 +74,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user with properly hashed password."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="security_test@example.com",
|
||||
password_hash=hash_password("SecurePass123!"),
|
||||
name="Security Test User",
|
||||
password=hash_password("SecurePass123!"),
|
||||
firstName="Security Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -92,14 +93,14 @@ async def test_user(async_session):
|
|||
async def second_user(async_session):
|
||||
"""Create a second test user for cross-user isolation tests."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="second_user@example.com",
|
||||
password_hash=hash_password("SecondPass456!"),
|
||||
name="Second User",
|
||||
password=hash_password("SecondPass456!"),
|
||||
firstName="Second User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -376,7 +377,7 @@ class TestXSSProtection:
|
|||
"""XSS payloads in brand aliases should be stored as plain text."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Safe Brand",
|
||||
platforms=["wenxin"],
|
||||
status="active",
|
||||
|
|
@ -555,7 +556,7 @@ class TestAuthSecurity:
|
|||
# Create a brand for second_user
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=second_user.id,
|
||||
user_id=_to_uuid(second_user.id),
|
||||
name="Second User's Brand",
|
||||
platforms=["wenxin"],
|
||||
status="active",
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.models.query import Query as QueryModel
|
|||
from app.models.citation_record import CitationRecord
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import create_access_token
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -53,14 +54,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -134,7 +135,7 @@ class TestFullBrandQueryFlow:
|
|||
# Step 3: Create a query (using Query model directly)
|
||||
query = QueryModel(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword="AI assistant",
|
||||
target_brand="TestBrand",
|
||||
brand_aliases=["TestBrand", "TB"],
|
||||
|
|
@ -250,7 +251,7 @@ class TestCSVExportFlow:
|
|||
# Step 2: Create a query with citations
|
||||
query = QueryModel(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
keyword="export test keyword",
|
||||
target_brand="ExportTestBrand",
|
||||
brand_aliases=["ETB"],
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from sqlalchemy import select, and_
|
||||
|
||||
from app.models.api_key import APIKey
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestAPIKeyModel:
|
||||
|
|
@ -16,7 +17,7 @@ class TestAPIKeyModel:
|
|||
"""Test creating a new API key."""
|
||||
api_key = APIKey(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_test_key",
|
||||
key_hint="sk-...abc",
|
||||
|
|
@ -29,7 +30,7 @@ class TestAPIKeyModel:
|
|||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.id is not None
|
||||
assert api_key.user_id == test_user.id
|
||||
assert api_key.user_id == _to_uuid(test_user.id)
|
||||
assert api_key.engine_type == "chatgpt"
|
||||
assert api_key.encrypted_key == "encrypted_test_key"
|
||||
assert api_key.key_hint == "sk-...abc"
|
||||
|
|
@ -43,7 +44,7 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_default_values(self, async_session, test_user):
|
||||
"""Test API key default values."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_kimi_key",
|
||||
key_hint="sk-...xyz",
|
||||
|
|
@ -62,7 +63,7 @@ class TestAPIKeyModel:
|
|||
"""Test API key field validation and constraints."""
|
||||
now = datetime.now(timezone.utc)
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="deepseek",
|
||||
encrypted_key="encrypted_deepseek_key_data",
|
||||
key_hint="sk-...def",
|
||||
|
|
@ -90,7 +91,7 @@ class TestAPIKeyModel:
|
|||
key_id = uuid.uuid4()
|
||||
api_key = APIKey(
|
||||
id=key_id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="gemini",
|
||||
encrypted_key="encrypted_gemini_key",
|
||||
key_hint="AIza...123",
|
||||
|
|
@ -111,13 +112,13 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_query_by_user_id(self, async_session, test_user):
|
||||
"""Test querying API keys by user ID."""
|
||||
key1 = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_key_1",
|
||||
key_hint="sk-...1",
|
||||
)
|
||||
key2 = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_key_2",
|
||||
key_hint="sk-...2",
|
||||
|
|
@ -127,7 +128,7 @@ class TestAPIKeyModel:
|
|||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.user_id == test_user.id)
|
||||
select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
|
|
@ -137,13 +138,13 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_query_by_user_and_engine(self, async_session, test_user):
|
||||
"""Test querying API keys by user ID and engine type."""
|
||||
key1 = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_chatgpt_key",
|
||||
key_hint="sk-...chat",
|
||||
)
|
||||
key2 = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_kimi_key",
|
||||
key_hint="sk-...kimi",
|
||||
|
|
@ -155,7 +156,7 @@ class TestAPIKeyModel:
|
|||
result = await async_session.execute(
|
||||
select(APIKey).where(
|
||||
and_(
|
||||
APIKey.user_id == test_user.id,
|
||||
APIKey.user_id == _to_uuid(test_user.id),
|
||||
APIKey.engine_type == "chatgpt"
|
||||
)
|
||||
)
|
||||
|
|
@ -169,7 +170,7 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_timestamps(self, async_session, test_user):
|
||||
"""Test API key created_at and updated_at timestamps."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="qwen",
|
||||
encrypted_key="encrypted_qwen_key",
|
||||
key_hint="sk-...qwen",
|
||||
|
|
@ -187,7 +188,7 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_update(self, async_session, test_user):
|
||||
"""Test updating API key fields."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="wenxin",
|
||||
encrypted_key="encrypted_wenxin_key",
|
||||
key_hint="sk-...wenxin",
|
||||
|
|
@ -211,7 +212,7 @@ class TestAPIKeyModel:
|
|||
async def test_api_key_delete(self, async_session, test_user):
|
||||
"""Test deleting an API key."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="doubao",
|
||||
encrypted_key="encrypted_doubao_key",
|
||||
key_hint="sk-...doubao",
|
||||
|
|
@ -238,7 +239,7 @@ class TestAPIKeyModel:
|
|||
|
||||
for i, status in enumerate(statuses):
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type=f"engine_{i}",
|
||||
encrypted_key=f"encrypted_key_{i}",
|
||||
key_hint=f"sk-...{i}",
|
||||
|
|
@ -263,7 +264,7 @@ class TestAPIKeyModel:
|
|||
|
||||
for data in keys_data:
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type=data["engine_type"],
|
||||
encrypted_key=f"encrypted_{data['engine_type']}",
|
||||
key_hint=data["key_hint"],
|
||||
|
|
@ -275,7 +276,7 @@ class TestAPIKeyModel:
|
|||
|
||||
result = await async_session.execute(
|
||||
select(APIKey)
|
||||
.where(APIKey.user_id == test_user.id)
|
||||
.where(APIKey.user_id == _to_uuid(test_user.id))
|
||||
.order_by(APIKey.priority.desc())
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
|
@ -289,7 +290,7 @@ class TestAPIKeyModel:
|
|||
"""Test that user_id field has an index."""
|
||||
for i in range(5):
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type=f"engine_{i}",
|
||||
encrypted_key=f"encrypted_key_{i}",
|
||||
key_hint=f"hint_{i}",
|
||||
|
|
@ -299,7 +300,7 @@ class TestAPIKeyModel:
|
|||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.user_id == test_user.id)
|
||||
select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from sqlalchemy import select
|
||||
|
||||
from app.models.brand import Brand
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestBrandModel:
|
||||
|
|
@ -16,7 +17,7 @@ class TestBrandModel:
|
|||
"""Test creating a new brand."""
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -30,7 +31,7 @@ class TestBrandModel:
|
|||
await async_session.refresh(brand)
|
||||
|
||||
assert brand.id is not None
|
||||
assert brand.user_id == test_user.id
|
||||
assert brand.user_id == _to_uuid(test_user.id)
|
||||
assert brand.name == "Test Brand"
|
||||
assert brand.aliases == ["TestBrand", "TB"]
|
||||
assert brand.website == "https://testbrand.com"
|
||||
|
|
@ -45,7 +46,7 @@ class TestBrandModel:
|
|||
async def test_brand_default_values(self, async_session, test_user):
|
||||
"""Test brand default values."""
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Default Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -59,13 +60,12 @@ class TestBrandModel:
|
|||
assert brand.frequency == "weekly"
|
||||
assert brand.status == "active"
|
||||
assert brand.last_queried_at is None
|
||||
assert brand.next_query_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brand_fields(self, async_session, test_user):
|
||||
"""Test brand field validation and constraints."""
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Field Test Brand",
|
||||
aliases=["FTA", "FieldTest"],
|
||||
website="https://fieldtest.com",
|
||||
|
|
@ -80,7 +80,6 @@ class TestBrandModel:
|
|||
await async_session.commit()
|
||||
await async_session.refresh(brand)
|
||||
|
||||
# Verify all fields
|
||||
assert brand.name == "Field Test Brand"
|
||||
assert len(brand.name) == 16
|
||||
assert brand.aliases == ["FTA", "FieldTest"]
|
||||
|
|
@ -91,7 +90,6 @@ class TestBrandModel:
|
|||
assert brand.frequency == "daily"
|
||||
assert brand.status == "active"
|
||||
assert brand.last_queried_at is not None
|
||||
assert brand.next_query_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brand_query_by_id(self, async_session, test_user):
|
||||
|
|
@ -99,7 +97,7 @@ class TestBrandModel:
|
|||
brand_id = uuid.uuid4()
|
||||
brand = Brand(
|
||||
id=brand_id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Query Test Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -118,14 +116,14 @@ class TestBrandModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_brand_query_by_user_id(self, async_session, test_user):
|
||||
"""Test querying brands by user ID."""
|
||||
# Create multiple brands for the same user
|
||||
uid = _to_uuid(test_user.id)
|
||||
brand1 = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=uid,
|
||||
name="User Brand 1",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
brand2 = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=uid,
|
||||
name="User Brand 2",
|
||||
platforms=["kimi"],
|
||||
)
|
||||
|
|
@ -134,7 +132,7 @@ class TestBrandModel:
|
|||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(Brand).where(Brand.user_id == test_user.id)
|
||||
select(Brand).where(Brand.user_id == uid)
|
||||
)
|
||||
brands = result.scalars().all()
|
||||
|
||||
|
|
@ -144,7 +142,7 @@ class TestBrandModel:
|
|||
async def test_brand_timestamps(self, async_session, test_user):
|
||||
"""Test brand created_at and updated_at timestamps."""
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Timestamp Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -161,7 +159,7 @@ class TestBrandModel:
|
|||
async def test_brand_update(self, async_session, test_user):
|
||||
"""Test updating brand fields."""
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Update Test Brand",
|
||||
platforms=["wenxin"],
|
||||
frequency="weekly",
|
||||
|
|
@ -169,7 +167,6 @@ class TestBrandModel:
|
|||
async_session.add(brand)
|
||||
await async_session.commit()
|
||||
|
||||
# Update brand
|
||||
brand.name = "Updated Brand Name"
|
||||
brand.frequency = "daily"
|
||||
brand.aliases = ["Updated", "Alias"]
|
||||
|
|
@ -184,7 +181,7 @@ class TestBrandModel:
|
|||
async def test_brand_delete(self, async_session, test_user):
|
||||
"""Test deleting a brand."""
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Delete Test Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select
|
|||
|
||||
from app.models.brand import Brand
|
||||
from app.models.competitor import Competitor
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestCompetitorModel:
|
||||
|
|
@ -18,7 +19,7 @@ class TestCompetitorModel:
|
|||
# First create a brand
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand for Competitor",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -48,7 +49,7 @@ class TestCompetitorModel:
|
|||
"""Test competitor default values."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Default Competitor",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -72,7 +73,7 @@ class TestCompetitorModel:
|
|||
"""Test competitor field validation."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Field Test",
|
||||
platforms=["wenxin", "kimi"],
|
||||
)
|
||||
|
|
@ -98,7 +99,7 @@ class TestCompetitorModel:
|
|||
"""Test querying competitor by ID."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Query Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -129,7 +130,7 @@ class TestCompetitorModel:
|
|||
"""Test querying competitors by brand ID."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Multi Competitor Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -164,7 +165,7 @@ class TestCompetitorModel:
|
|||
"""Test competitor created_at timestamp."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Timestamp Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -188,7 +189,7 @@ class TestCompetitorModel:
|
|||
"""Test updating competitor fields."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Update Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -218,7 +219,7 @@ class TestCompetitorModel:
|
|||
"""Test deleting a competitor."""
|
||||
# Create brand first
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Delete Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -249,7 +250,7 @@ class TestCompetitorModel:
|
|||
"""Test that competitors are deleted when brand is deleted."""
|
||||
# Create brand with competitors
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Brand for Cascade Test",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy import select
|
|||
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.models.user import User
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestOrganizationModel:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class TestSubscriptionModel:
|
|||
def test_subscription_field_types(self):
|
||||
columns = Subscription.__table__.columns
|
||||
assert "UUID" in str(columns["id"].type).upper()
|
||||
assert "UUID" in str(columns["user_id"].type).upper()
|
||||
assert "VARCHAR" in str(columns["user_id"].type).upper() or "STRING" in str(columns["user_id"].type).upper()
|
||||
assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper()
|
||||
assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper()
|
||||
assert "DATE" in str(columns["start_date"].type).upper()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select
|
|||
from app.models.brand import Brand
|
||||
from app.models.suggestion import Suggestion
|
||||
from app.models.user import User
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestSuggestionModel:
|
||||
|
|
@ -81,7 +82,7 @@ class TestSuggestionModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_suggestion_create(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Suggestion Test Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
@ -124,7 +125,7 @@ class TestSuggestionModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_suggestion_default_values(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Default Suggestion Brand",
|
||||
platforms=["kimi"],
|
||||
)
|
||||
|
|
@ -152,7 +153,7 @@ class TestSuggestionModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_suggestion_query_by_brand(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Query Suggestion Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select, func, and_
|
|||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from app.models.usage_record import UsageRecord
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
class TestUsageRecordModel:
|
||||
|
|
@ -16,7 +17,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_create(self, async_session, test_user):
|
||||
"""Test creating a new usage record."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
query="What is SEO optimization?",
|
||||
input_tokens=100,
|
||||
|
|
@ -29,7 +30,7 @@ class TestUsageRecordModel:
|
|||
await async_session.refresh(record)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.user_id == _to_uuid(test_user.id)
|
||||
assert record.engine_type == "chatgpt"
|
||||
assert record.query == "What is SEO optimization?"
|
||||
assert record.input_tokens == 100
|
||||
|
|
@ -43,7 +44,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_default_values(self, async_session, test_user):
|
||||
"""Test usage record default values."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="kimi",
|
||||
query="Test query",
|
||||
)
|
||||
|
|
@ -60,7 +61,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_user_id(self, async_session, test_user):
|
||||
"""Test querying usage records by user ID."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
for i in range(3):
|
||||
record = UsageRecord(
|
||||
user_id=user_id,
|
||||
|
|
@ -81,7 +82,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_user_and_engine(self, async_session, test_user):
|
||||
"""Test querying usage records by user ID and engine type."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
record1 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="chatgpt",
|
||||
|
|
@ -114,7 +115,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_time_range(self, async_session, test_user):
|
||||
"""Test querying usage records by time range."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
old_record = UsageRecord(
|
||||
|
|
@ -152,7 +153,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_aggregate_by_user(self, async_session, test_user):
|
||||
"""Test aggregating usage records by user."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
records_data = [
|
||||
{"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01},
|
||||
{"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02},
|
||||
|
|
@ -192,7 +193,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_aggregate_by_day(self, async_session, test_user):
|
||||
"""Test aggregating usage records by day."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
|
|
@ -238,7 +239,7 @@ class TestUsageRecordModel:
|
|||
"""Test usage record with brand association."""
|
||||
brand_id = uuid.uuid4()
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
brand_id=brand_id,
|
||||
engine_type="wenxin",
|
||||
query="Brand query",
|
||||
|
|
@ -253,7 +254,7 @@ class TestUsageRecordModel:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_record_index_user_engine(self, async_session, test_user):
|
||||
"""Test composite index on user_id and engine_type."""
|
||||
user_id = test_user.id
|
||||
user_id = _to_uuid(test_user.id)
|
||||
for i in range(5):
|
||||
record = UsageRecord(
|
||||
user_id=user_id,
|
||||
|
|
@ -280,7 +281,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_update(self, async_session, test_user):
|
||||
"""Test updating usage record fields."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="xinghuo",
|
||||
query="Original query",
|
||||
cost=1.0,
|
||||
|
|
@ -300,7 +301,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_delete(self, async_session, test_user):
|
||||
"""Test deleting a usage record."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="yuanbao",
|
||||
query="Delete me",
|
||||
cost=1.0,
|
||||
|
|
@ -323,7 +324,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_timestamps(self, async_session, test_user):
|
||||
"""Test usage record timestamp fields."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="perplexity",
|
||||
query="Timestamp test",
|
||||
cost=1.0,
|
||||
|
|
@ -343,7 +344,7 @@ class TestUsageRecordModel:
|
|||
other_user_id = uuid.uuid4()
|
||||
|
||||
user1_record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
query="User 1 query",
|
||||
cost=1.0,
|
||||
|
|
@ -359,7 +360,7 @@ class TestUsageRecordModel:
|
|||
await async_session.commit()
|
||||
|
||||
result_user1 = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.user_id == test_user.id)
|
||||
select(UsageRecord).where(UsageRecord.user_id == _to_uuid(test_user.id))
|
||||
)
|
||||
result_user2 = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.user_id == other_user_id)
|
||||
|
|
@ -372,7 +373,7 @@ class TestUsageRecordModel:
|
|||
async def test_usage_record_empty_query_field(self, async_session, test_user):
|
||||
"""Test usage record with empty query field."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="deepseek",
|
||||
query="",
|
||||
cost=0.0,
|
||||
|
|
@ -394,7 +395,7 @@ class TestUsageRecordModel:
|
|||
"nested": {"key": {"deep": {"value": [1, 2, 3]}}},
|
||||
}
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
engine_type="chatgpt",
|
||||
query="Large metadata test",
|
||||
extra_data=large_metadata,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.models.user import User
|
|||
from app.models.usage_record import UsageRecord
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -47,14 +48,14 @@ async def async_session(async_engine):
|
|||
async def test_user_free(async_session):
|
||||
"""Create a free plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="free@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Free User",
|
||||
password="hashed_password",
|
||||
firstName="Free User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -66,14 +67,14 @@ async def test_user_free(async_session):
|
|||
async def test_user_basic(async_session):
|
||||
"""Create a basic plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="basic@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Basic User",
|
||||
password="hashed_password",
|
||||
firstName="Basic User",
|
||||
plan="basic",
|
||||
max_queries=50,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -85,14 +86,14 @@ async def test_user_basic(async_session):
|
|||
async def test_user_pro(async_session):
|
||||
"""Create a pro plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="pro@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Pro User",
|
||||
password="hashed_password",
|
||||
firstName="Pro User",
|
||||
plan="pro",
|
||||
max_queries=500,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -110,7 +111,7 @@ class TestByDayAggregation:
|
|||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.01,
|
||||
|
|
@ -130,7 +131,7 @@ class TestByDayAggregation:
|
|||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "deepseek",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.02,
|
||||
|
|
@ -149,7 +150,7 @@ class TestByDayAggregation:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "qwen",
|
||||
"query": "Query 1",
|
||||
"input_tokens": 100,
|
||||
|
|
@ -157,7 +158,7 @@ class TestByDayAggregation:
|
|||
"cost": 0.01,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "qwen",
|
||||
"query": "Query 2",
|
||||
"input_tokens": 150,
|
||||
|
|
@ -188,14 +189,14 @@ class TestByDayAggregation:
|
|||
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "gemini",
|
||||
"query": "Yesterday query",
|
||||
"cost": 0.05,
|
||||
"timestamp": yesterday,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "gemini",
|
||||
"query": "Today query",
|
||||
"cost": 0.05,
|
||||
|
|
@ -268,7 +269,7 @@ class TestUserQuotaService:
|
|||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"cost": 1.0,
|
||||
|
|
@ -291,7 +292,7 @@ class TestUserQuotaService:
|
|||
|
||||
for i in range(10):
|
||||
await repo.create({
|
||||
"user_id": test_user_basic.id,
|
||||
"user_id": _to_uuid(test_user_basic.id),
|
||||
"engine_type": "deepseek",
|
||||
"query": f"Query {i}",
|
||||
"cost": 2.0,
|
||||
|
|
@ -314,7 +315,7 @@ class TestUserQuotaService:
|
|||
|
||||
for i in range(10):
|
||||
await repo.create({
|
||||
"user_id": test_user_pro.id,
|
||||
"user_id": _to_uuid(test_user_pro.id),
|
||||
"engine_type": "qwen",
|
||||
"query": f"Query {i}",
|
||||
"cost": 10.0,
|
||||
|
|
@ -336,7 +337,7 @@ class TestUserQuotaService:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"user_id": _to_uuid(test_user_free.id),
|
||||
"engine_type": "gemini",
|
||||
"query": "Expensive query",
|
||||
"cost": 15.0,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from app.database import Base
|
|||
from app.models.user import User
|
||||
from app.models.usage_record import UsageRecord
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -46,14 +47,14 @@ async def async_session(async_engine):
|
|||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
password="hashed_password",
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -70,7 +71,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Test query",
|
||||
"input_tokens": 100,
|
||||
|
|
@ -82,7 +83,7 @@ class TestUsageRepository:
|
|||
record = await repo.create(data)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.user_id == _to_uuid(test_user.id)
|
||||
assert record.engine_type == "chatgpt"
|
||||
assert record.query == "Test query"
|
||||
assert record.input_tokens == 100
|
||||
|
|
@ -96,7 +97,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "deepseek",
|
||||
"query": "Minimal query",
|
||||
}
|
||||
|
|
@ -104,7 +105,7 @@ class TestUsageRepository:
|
|||
record = await repo.create(data)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.user_id == _to_uuid(test_user.id)
|
||||
assert record.engine_type == "deepseek"
|
||||
assert record.query == "Minimal query"
|
||||
assert record.input_tokens == 0
|
||||
|
|
@ -119,7 +120,7 @@ class TestUsageRepository:
|
|||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"input_tokens": 100,
|
||||
|
|
@ -143,13 +144,13 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": "ChatGPT query",
|
||||
"cost": 0.02,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "deepseek",
|
||||
"query": "DeepSeek query",
|
||||
"cost": 0.01,
|
||||
|
|
@ -170,7 +171,7 @@ class TestUsageRepository:
|
|||
|
||||
for i in range(2):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "qwen",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.01,
|
||||
|
|
@ -190,14 +191,14 @@ class TestUsageRepository:
|
|||
brand_id = uuid.uuid4()
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"brand_id": brand_id,
|
||||
"engine_type": "kimi",
|
||||
"query": "Brand query",
|
||||
"cost": 0.05,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "kimi",
|
||||
"query": "No brand query",
|
||||
"cost": 0.03,
|
||||
|
|
@ -219,7 +220,7 @@ class TestUsageRepository:
|
|||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "gemini",
|
||||
"query": f"Query {i}",
|
||||
"cost": 1.0,
|
||||
|
|
@ -238,7 +239,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Expensive query",
|
||||
"cost": 85.0,
|
||||
|
|
@ -255,7 +256,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Very expensive query",
|
||||
"cost": 120.0,
|
||||
|
|
@ -272,7 +273,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
created = await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "wenxin",
|
||||
"query": "Get by ID test",
|
||||
"cost": 0.5,
|
||||
|
|
@ -291,7 +292,7 @@ class TestUsageRepository:
|
|||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "doubao",
|
||||
"query": f"User query {i}",
|
||||
"cost": 0.1,
|
||||
|
|
@ -307,13 +308,13 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "xinghuo",
|
||||
"query": "Xinghuo query",
|
||||
"cost": 0.1,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "perplexity",
|
||||
"query": "Perplexity query",
|
||||
"cost": 0.2,
|
||||
|
|
@ -358,7 +359,7 @@ class TestUsageRepository:
|
|||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"user_id": _to_uuid(test_user.id),
|
||||
"engine_type": "yuanbao",
|
||||
"query": "UUID test",
|
||||
"cost": 0.5,
|
||||
|
|
@ -368,4 +369,4 @@ class TestUsageRepository:
|
|||
summary = await repo.get_summary(test_user.id, period="month")
|
||||
|
||||
assert summary["total_queries"] == 1
|
||||
assert record.user_id == test_user.id
|
||||
assert record.user_id == _to_uuid(test_user.id)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from app.database import Base
|
|||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
from tests.fixtures.auth import _to_uuid
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -42,14 +43,14 @@ async def async_session(async_engine):
|
|||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
id=str(uuid.uuid4()),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
password=hash_password("Test@123456"),
|
||||
firstName="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
isActive=True,
|
||||
emailVerified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
|
|
@ -61,7 +62,7 @@ async def test_user(async_session):
|
|||
async def test_brand(async_session, test_user):
|
||||
brand = Brand(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="Test Brand",
|
||||
aliases=["TestBrand", "TB"],
|
||||
website="https://testbrand.com",
|
||||
|
|
@ -83,7 +84,7 @@ class TestDetectionTaskModel:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="每日品牌检测",
|
||||
frequency="daily",
|
||||
engines=["chatgpt", "perplexity"],
|
||||
|
|
@ -96,7 +97,7 @@ class TestDetectionTaskModel:
|
|||
|
||||
assert task.id is not None
|
||||
assert task.brand_id == test_brand.id
|
||||
assert task.user_id == test_user.id
|
||||
assert task.user_id == _to_uuid(test_user.id)
|
||||
assert task.name == "每日品牌检测"
|
||||
assert task.frequency == "daily"
|
||||
assert task.engines == ["chatgpt", "perplexity"]
|
||||
|
|
@ -114,7 +115,7 @@ class TestDetectionTaskModel:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="简单检测",
|
||||
frequency="weekly",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -142,13 +143,13 @@ class TestDetectionSchedulerService:
|
|||
"queries": ["最佳保险品牌", "保险推荐"],
|
||||
"competitor_names": ["竞品A"],
|
||||
}
|
||||
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
|
||||
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
|
||||
assert task.id is not None
|
||||
assert task.name == "每日品牌检测"
|
||||
assert task.frequency == "daily"
|
||||
assert task.brand_id == test_brand.id
|
||||
assert task.user_id == test_user.id
|
||||
assert task.user_id == _to_uuid(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task(self, async_session, test_brand, test_user):
|
||||
|
|
@ -157,7 +158,7 @@ class TestDetectionSchedulerService:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="旧名称",
|
||||
frequency="weekly",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -173,7 +174,7 @@ class TestDetectionSchedulerService:
|
|||
"frequency": "daily",
|
||||
"engines": ["chatgpt", "perplexity"],
|
||||
}
|
||||
updated = await service.update_task(task.id, update_data, test_user.id, async_session)
|
||||
updated = await service.update_task(task.id, update_data, _to_uuid(test_user.id), async_session)
|
||||
|
||||
assert updated.name == "新名称"
|
||||
assert updated.frequency == "daily"
|
||||
|
|
@ -186,7 +187,7 @@ class TestDetectionSchedulerService:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="待删除",
|
||||
frequency="weekly",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -197,7 +198,7 @@ class TestDetectionSchedulerService:
|
|||
await async_session.refresh(task)
|
||||
|
||||
service = DetectionSchedulerService()
|
||||
result = await service.delete_task(task.id, test_user.id, async_session)
|
||||
result = await service.delete_task(task.id, _to_uuid(test_user.id), async_session)
|
||||
assert result is True
|
||||
|
||||
stmt = select(DetectionTask).where(DetectionTask.id == task.id)
|
||||
|
|
@ -212,7 +213,7 @@ class TestDetectionSchedulerService:
|
|||
for i in range(3):
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name=f"任务{i}",
|
||||
frequency="daily",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -222,7 +223,7 @@ class TestDetectionSchedulerService:
|
|||
await async_session.commit()
|
||||
|
||||
service = DetectionSchedulerService()
|
||||
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
|
||||
tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
assert len(tasks) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -232,7 +233,7 @@ class TestDetectionSchedulerService:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="手动触发测试",
|
||||
frequency="daily",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -245,7 +246,7 @@ class TestDetectionSchedulerService:
|
|||
service = DetectionSchedulerService()
|
||||
with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = {"status": "success", "results": []}
|
||||
result = await service.trigger_task(task.id, test_user.id, async_session)
|
||||
result = await service.trigger_task(task.id, _to_uuid(test_user.id), async_session)
|
||||
|
||||
assert result["status"] == "success"
|
||||
mock_execute.assert_called_once()
|
||||
|
|
@ -261,7 +262,7 @@ class TestDetectionSchedulerService:
|
|||
"engines": ["chatgpt"],
|
||||
"queries": ["查询"],
|
||||
}
|
||||
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
|
||||
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
assert task.frequency == "hourly"
|
||||
assert task.next_run_at is not None
|
||||
|
||||
|
|
@ -276,7 +277,7 @@ class TestDetectionSchedulerService:
|
|||
"engines": ["chatgpt"],
|
||||
"queries": ["查询"],
|
||||
}
|
||||
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
|
||||
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
assert task.frequency == "daily"
|
||||
assert task.next_run_at is not None
|
||||
|
||||
|
|
@ -291,7 +292,7 @@ class TestDetectionSchedulerService:
|
|||
"engines": ["chatgpt"],
|
||||
"queries": ["查询"],
|
||||
}
|
||||
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
|
||||
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
assert task.frequency == "weekly"
|
||||
assert task.next_run_at is not None
|
||||
|
||||
|
|
@ -307,7 +308,7 @@ class TestDetectionSchedulerService:
|
|||
"queries": ["查询"],
|
||||
}
|
||||
with pytest.raises(ValueError, match="frequency"):
|
||||
await service.create_task(task_data, test_brand.id, test_user.id, async_session)
|
||||
await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_flow(self, async_session, test_brand, test_user):
|
||||
|
|
@ -316,7 +317,7 @@ class TestDetectionSchedulerService:
|
|||
|
||||
task = DetectionTask(
|
||||
brand_id=test_brand.id,
|
||||
user_id=test_user.id,
|
||||
user_id=_to_uuid(test_user.id),
|
||||
name="执行流程测试",
|
||||
frequency="daily",
|
||||
engines=["chatgpt"],
|
||||
|
|
@ -356,7 +357,7 @@ class TestDetectionSchedulerService:
|
|||
from app.services.detection.detection_scheduler import DetectionSchedulerService
|
||||
|
||||
service = DetectionSchedulerService()
|
||||
result = await service.delete_task(uuid.uuid4(), test_user.id, async_session)
|
||||
result = await service.delete_task(uuid.uuid4(), _to_uuid(test_user.id), async_session)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -365,12 +366,12 @@ class TestDetectionSchedulerService:
|
|||
|
||||
service = DetectionSchedulerService()
|
||||
with pytest.raises(TaskNotFoundError):
|
||||
await service.update_task(uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session)
|
||||
await service.update_task(uuid.uuid4(), {"name": "新名称"}, _to_uuid(test_user.id), async_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tasks_empty(self, async_session, test_brand, test_user):
|
||||
from app.services.detection.detection_scheduler import DetectionSchedulerService
|
||||
|
||||
service = DetectionSchedulerService()
|
||||
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
|
||||
tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
|
||||
assert tasks == []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,886 @@
|
|||
---
|
||||
title: "refactor: GEO Agent Framework — 统一架构重构为独立项目"
|
||||
type: refactor
|
||||
status: active
|
||||
date: 2026-06-04
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
将 GEO 项目的 8 个 Agent 从代码重复的独立类重构为独立 Python 包(`fischer-agentkit`),采用统一 Agent Core + 配置驱动 + 可插拔 Tool/Skill/Memory 架构。新增 MCP 协议支持、三层记忆系统(Working/Episodic/Semantic)、Level 3 自我进化闭环(经验积累 → Prompt 自动优化 → 策略调整)、多 Agent 协同增强(并行执行、Handoff、动态 Pipeline),并实现 Agent 与业务系统的完全解耦。
|
||||
|
||||
## Problem Frame
|
||||
|
||||
当前 GEO 项目的 Agent 体系存在以下结构性问题:
|
||||
|
||||
1. **代码重复**:8 个 Agent 的 `execute()` 方法包含完全相同的计时、try/except、TaskResult 构建逻辑(~250 行模板代码),每新增一个 Agent 需复制 ~150 行
|
||||
2. **架构耦合**:Agent 直接依赖业务 Service(CitationService、MonitorService 等),无法独立部署和复用
|
||||
3. **能力缺失**:无记忆系统(每次执行无状态)、无自我进化(Prompt 硬编码不可调优)、无技能插件(能力静态绑定)、无 MCP 支持
|
||||
4. **编排局限**:Pipeline 仅支持串行 DAG、无 Agent 间 Handoff、无动态路由
|
||||
5. **配置僵化**:AgentType 硬编码枚举、Prompt 与 Agent 紧耦合、新增 Agent 需改 5+ 文件
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
### Agent Core 统一架构
|
||||
|
||||
R1. Agent 框架必须作为独立 Python 包(`fischer-agentkit`),可独立安装、版本管理、跨项目复用
|
||||
R2. 所有 Agent 共享统一的 `BaseAgent` 生命周期(start → listen → execute → reflect → evolve → stop),子类只需实现 `handle_task()` 返回业务数据
|
||||
R3. Agent 定义支持三种模式:Python 声明式(`@task` 装饰器)、YAML 配置驱动(零代码)、混合模式
|
||||
R4. AgentType 从硬编码枚举改为动态注册,新增 Agent 类型无需修改框架代码
|
||||
R5. Agent 的 input/output 支持 JSON Schema 声明与校验
|
||||
|
||||
### Tool/Skill 插件系统
|
||||
|
||||
R6. 实现 `Tool` 抽象基类,支持 `FunctionTool`(函数工具)、`AgentTool`(Agent 即工具)、`MCPTool`(MCP 协议工具)三种类型
|
||||
R7. 实现 `ToolRegistry`,支持工具的注册、发现、版本管理、标签分类
|
||||
R8. 实现 MCP Server,将现有 Agent 能力暴露为 MCP 工具供外部调用
|
||||
R9. 实现 MCP Client,Agent 可调用外部 MCP 工具服务器
|
||||
R10. 工具支持组合:顺序链、并行扇出/扇入、动态选择
|
||||
|
||||
### 记忆系统
|
||||
|
||||
R11. 实现 Working Memory(Redis-based),存储当前任务的上下文和中间状态,生命周期为单次任务
|
||||
R12. 实现 Episodic Memory(向量+关系数据库),记录每次任务的输入/输出/效果/反思,支持语义检索相似历史案例
|
||||
R13. 实现 Semantic Memory,复用现有 RAG 知识库,所有 Agent 均可通过统一接口检索知识
|
||||
R14. 记忆检索采用混合策略:向量语义 + 关键词 + 知识图谱 + 时间衰减 + RRF 融合排序
|
||||
|
||||
### 自我进化(Level 3)
|
||||
|
||||
R15. 实现经验积累:每次任务执行后自动记录成功/失败模式、效果指标、反思总结
|
||||
R16. 实现 Prompt 自动优化:基于 DSPy 风格的编译器,从任务结果中自动生成/优化 Prompt 指令和 few-shot 示例
|
||||
R17. 实现策略调整:根据历史效果数据自动调整 Agent 参数(temperature、tool 选择权重、Pipeline 路径)
|
||||
R18. 进化变更必须经过 A/B 测试验证后才可生效,支持回滚
|
||||
|
||||
### 多 Agent 协同与业务编排
|
||||
|
||||
R19. Pipeline Engine 支持同层并行执行(无依赖的 stages 使用 `asyncio.gather` 并行)
|
||||
R20. 实现 Handoff 机制:Agent 可在运行时将任务转交给另一个 Agent,携带上下文
|
||||
R21. 支持动态 Pipeline:运行时根据条件选择子流程、嵌套 Pipeline
|
||||
R22. Agent 间通信支持事件驱动(Redis Pub/Sub),替代轮询等待
|
||||
|
||||
### Agent 与业务解耦
|
||||
|
||||
R23. Agent 框架不依赖任何 GEO 业务代码(Service、Model、Repository),通过 Tool 接口调用业务能力
|
||||
R24. 业务系统通过 Tool 注册将自身能力暴露给 Agent,Agent 通过 ToolRegistry 发现和调用
|
||||
R25. Agent 配置(Prompt、Tool 绑定、Memory 策略)存储在数据库,支持热更新
|
||||
|
||||
---
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
KTD1. **独立 Python 包架构**:`fischer-agentkit` 作为独立包发布到私有 PyPI,GEO 项目通过 `pip install` 引入。包结构遵循 `src/` layout,支持独立测试和版本管理。理由:解耦 Agent 基础设施与业务代码,支持跨项目复用。
|
||||
|
||||
KTD2. **统一 Agent Core 的 execute 模板上移**:将计时、try/except、TaskResult 构建、进度上报等公共逻辑全部上移到 `BaseAgent.execute()`,子类只需实现 `handle_task(task) -> dict`。理由:消除 8 个 Agent 中 ~250 行重复代码。
|
||||
|
||||
KTD3. **Tool 抽象三层架构**:`FunctionTool`(进程内函数调用)→ `MCPTool`(跨进程 MCP 协议调用)→ `AgentTool`(将 Agent 包装为 Tool)。理由:覆盖从简单到复杂的所有工具场景,MCP 作为标准协议确保生态兼容性。
|
||||
|
||||
KTD4. **MCP 双向支持**:同时实现 MCP Server(暴露 Agent 能力)和 MCP Client(调用外部工具)。传输层优先支持 Streamable HTTP(2025 标准),兼容 SSE。理由:MCP 是 2025 年工具协议事实标准,双向支持最大化生态连接能力。
|
||||
|
||||
KTD5. **三层记忆架构**:Working Memory(Redis,单任务生命周期)→ Episodic Memory(pgvector + PostgreSQL,任务经验)→ Semantic Memory(复用现有 RAG + 知识图谱)。理由:不同记忆类型有不同生命周期和检索模式,分层实现职责清晰。
|
||||
|
||||
KTD6. **Level 3 进化采用 DSPy 风格编译器**:定义 `Signature`(输入/输出 schema)→ `Module`(可组合 Prompt 策略)→ `Optimizer`(自动优化),从任务结果中自动构建 few-shot 示例和优化指令。理由:DSPy 是目前最成熟的 Prompt 自动优化范式,其编译器模式与 Agent 的 execute-reflect-evolve 生命周期天然契合。
|
||||
|
||||
KTD7. **配置驱动的 Agent 定义**:Agent 的元数据、Prompt、Tool 绑定、Memory 策略均通过 YAML/数据库配置,运行时由 `ConfigDrivenAgent` 自动组装。理由:将新增 Agent 从写 150 行代码降为 10-20 行配置。
|
||||
|
||||
KTD8. **Pipeline 并行执行**:将拓扑排序后的 stages 按依赖层级分组,同层内使用 `asyncio.gather` 并行执行。理由:引用检测和趋势分析等无依赖任务应并行,当前串行执行浪费时间。
|
||||
|
||||
KTD9. **Handoff 基于 Redis Pub/Sub**:Agent 通过 `agent:{name}:handoff` 频道发送转交请求,目标 Agent 监听并接管。理由:与现有 Redis Queue 架构一致,无需引入新组件。
|
||||
|
||||
---
|
||||
|
||||
## High-Level Technical Design
|
||||
|
||||
### 整体架构
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph FischerAgentKit["fischer-agentkit (独立包)"]
|
||||
direction TB
|
||||
Core["Agent Core"]
|
||||
Tools["Tool System"]
|
||||
Memory["Memory System"]
|
||||
Evolution["Evolution Engine"]
|
||||
Orchestrator["Orchestrator"]
|
||||
MCP["MCP Layer"]
|
||||
|
||||
subgraph Core
|
||||
BaseAgent["BaseAgent<br/>生命周期管理"]
|
||||
ConfigDriven["ConfigDrivenAgent<br/>配置驱动"]
|
||||
Registry["AgentRegistry<br/>注册发现"]
|
||||
Dispatcher["TaskDispatcher<br/>任务调度"]
|
||||
Protocol["Protocol<br/>通信协议"]
|
||||
end
|
||||
|
||||
subgraph Tools
|
||||
ToolRegistry["ToolRegistry<br/>注册发现"]
|
||||
FunctionTool["FunctionTool<br/>函数工具"]
|
||||
MCPTool["MCPTool<br/>MCP工具"]
|
||||
AgentTool["AgentTool<br/>Agent即工具"]
|
||||
end
|
||||
|
||||
subgraph Memory
|
||||
Working["Working Memory<br/>Redis"]
|
||||
Episodic["Episodic Memory<br/>pgvector+PG"]
|
||||
Semantic["Semantic Memory<br/>RAG+Graph"]
|
||||
Retriever["MemoryRetriever<br/>混合检索"]
|
||||
end
|
||||
|
||||
subgraph Evolution
|
||||
Reflector["Reflector<br/>执行反思"]
|
||||
PromptOptimizer["PromptOptimizer<br/>DSPy编译器"]
|
||||
StrategyTuner["StrategyTuner<br/>策略调优"]
|
||||
ABTester["ABTester<br/>A/B测试"]
|
||||
end
|
||||
|
||||
subgraph Orchestrator
|
||||
PipelineEngine["PipelineEngine<br/>DAG+并行"]
|
||||
Handoff["Handoff<br/>任务转交"]
|
||||
DynamicPipeline["DynamicPipeline<br/>运行时组合"]
|
||||
end
|
||||
|
||||
subgraph MCP
|
||||
MCPServer["MCP Server<br/>暴露能力"]
|
||||
MCPClient["MCP Client<br/>调用外部"]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph GEOBusiness["GEO 业务系统"]
|
||||
Services["业务 Services"]
|
||||
Models["数据 Models"]
|
||||
Repos["Repositories"]
|
||||
end
|
||||
|
||||
Core --> Tools
|
||||
Core --> Memory
|
||||
Core --> Evolution
|
||||
Core --> Orchestrator
|
||||
Tools --> MCP
|
||||
Services -.->|"注册为 FunctionTool"| ToolRegistry
|
||||
Semantic -.->|"复用"| Services
|
||||
```
|
||||
|
||||
### Agent 生命周期
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> Offline
|
||||
Offline --> Online: start()
|
||||
Online --> Listening: listen_for_tasks()
|
||||
Listening --> Executing: receive_task()
|
||||
Executing --> Reflecting: handle_task() 完成
|
||||
Reflecting --> Evolving: reflect() 有改进建议
|
||||
Reflecting --> Listening: reflect() 无改进
|
||||
Evolving --> Listening: evolve() 完成
|
||||
Executing --> Listening: handle_task() 异常
|
||||
Online --> Offline: stop()
|
||||
Listening --> Offline: stop()
|
||||
|
||||
state Executing {
|
||||
[*] --> LoadMemory: 加载记忆
|
||||
LoadMemory --> PlanTask: 规划任务
|
||||
PlanTask --> RunTools: 执行工具
|
||||
RunTools --> StoreMemory: 存储记忆
|
||||
StoreMemory --> BuildResult: 构建结果
|
||||
BuildResult --> [*]
|
||||
}
|
||||
|
||||
state Reflecting {
|
||||
[*] --> EvaluateOutcome: 评估结果
|
||||
EvaluateOutcome --> ExtractPatterns: 提取模式
|
||||
ExtractPatterns --> GenerateInsights: 生成洞察
|
||||
GenerateInsights --> [*]
|
||||
}
|
||||
|
||||
state Evolving {
|
||||
[*] --> OptimizePrompt: 优化Prompt
|
||||
OptimizePrompt --> TuneStrategy: 调优策略
|
||||
TuneStrategy --> ABTest: A/B测试
|
||||
ABTest --> ApplyOrRollback: 应用或回滚
|
||||
ApplyOrRollback --> [*]
|
||||
}
|
||||
```
|
||||
|
||||
### Tool 系统架构
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph Agent["Agent"]
|
||||
Executor["Task Executor"]
|
||||
end
|
||||
|
||||
subgraph ToolRegistry["ToolRegistry"]
|
||||
FT["FunctionTool<br/>name, description<br/>input_schema, output_schema<br/>execute(**kwargs)"]
|
||||
MT["MCPTool<br/>server_url, tool_name<br/>input_schema, output_schema<br/>execute(**kwargs)"]
|
||||
AT["AgentTool<br/>agent_name<br/>input_mapping, output_mapping<br/>execute(**kwargs)"]
|
||||
end
|
||||
|
||||
subgraph MCPServer["MCP Server"]
|
||||
MCPEndpoints["/tools/list<br/>/tools/call<br/>/resources/read"]
|
||||
end
|
||||
|
||||
subgraph ExternalMCP["External MCP Servers"]
|
||||
ExtTool1["File System"]
|
||||
ExtTool2["GitHub"]
|
||||
ExtTool3["Postgres"]
|
||||
end
|
||||
|
||||
Executor -->|"select & call"| ToolRegistry
|
||||
FT -->|"进程内调用"| BusinessLogic["业务函数"]
|
||||
MT -->|"HTTP/SSE"| MCPServer
|
||||
MT -->|"HTTP/SSE"| ExternalMCP
|
||||
AT -->|"dispatch"| TargetAgent["目标 Agent"]
|
||||
MCPServer -->|"暴露"| AgentCapabilities["Agent 能力"]
|
||||
```
|
||||
|
||||
### 记忆系统数据流
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
Task["新任务到达"] --> WM["Working Memory<br/>加载当前任务上下文"]
|
||||
WM --> EM["Episodic Memory<br/>检索相似历史案例"]
|
||||
EM --> SM["Semantic Memory<br/>检索知识库"]
|
||||
SM --> Context["组装上下文"]
|
||||
Context --> Execute["执行任务"]
|
||||
Execute --> WM_Write["写入 Working Memory<br/>中间状态"]
|
||||
WM_Write --> Execute
|
||||
Execute --> EM_Write["写入 Episodic Memory<br/>任务经验"]
|
||||
Execute --> Reflect["反思评估"]
|
||||
Reflect --> EM_Update["更新 Episodic Memory<br/>反思结果"]
|
||||
|
||||
subgraph 检索策略
|
||||
VS["向量语义检索<br/>权重 0.5"]
|
||||
KS["关键词精确匹配<br/>权重 0.2"]
|
||||
GT["知识图谱关联<br/>权重 0.3"]
|
||||
RRF["RRF 融合排序"]
|
||||
TD["时间衰减"]
|
||||
|
||||
VS --> RRF
|
||||
KS --> RRF
|
||||
GT --> RRF
|
||||
RRF --> TD
|
||||
end
|
||||
```
|
||||
|
||||
### 进化闭环
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
Execute["任务执行"] --> Result["任务结果"]
|
||||
Result --> Evaluate["效果评估<br/>成功/失败/质量分"]
|
||||
Evaluate --> Pattern["模式提取<br/>成功模式/失败模式"]
|
||||
Pattern --> Optimize["优化生成"]
|
||||
|
||||
Optimize --> PromptOpt["Prompt 优化<br/>DSPy 编译器"]
|
||||
Optimize --> StrategyOpt["策略调优<br/>参数/工具权重"]
|
||||
Optimize --> PipelineOpt["Pipeline 优化<br/>路径选择"]
|
||||
|
||||
PromptOpt --> ABTest["A/B 测试"]
|
||||
StrategyOpt --> ABTest
|
||||
PipelineOpt --> ABTest
|
||||
|
||||
ABTest -->|"效果提升"| Apply["应用变更"]
|
||||
ABTest -->|"效果下降"| Rollback["回滚"]
|
||||
ABTest -->|"不确定"| Extend["延长测试"]
|
||||
|
||||
Apply --> Record["记录进化日志"]
|
||||
Rollback --> Record
|
||||
Record --> Execute
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Output Structure
|
||||
|
||||
```
|
||||
fischer-agentkit/ # 独立 Python 包
|
||||
├── src/
|
||||
│ └── agentkit/
|
||||
│ ├── __init__.py
|
||||
│ ├── core/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── base.py # BaseAgent 生命周期
|
||||
│ │ ├── config_driven.py # ConfigDrivenAgent 配置驱动
|
||||
│ │ ├── registry.py # AgentRegistry 注册发现
|
||||
│ │ ├── dispatcher.py # TaskDispatcher 任务调度
|
||||
│ │ ├── protocol.py # 通信协议与数据结构
|
||||
│ │ ├── exceptions.py # 异常体系
|
||||
│ │ └── standalone.py # 独立启动器(自动发现)
|
||||
│ ├── tools/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── base.py # Tool 抽象基类
|
||||
│ │ ├── function_tool.py # FunctionTool
|
||||
│ │ ├── agent_tool.py # AgentTool
|
||||
│ │ ├── mcp_tool.py # MCPTool
|
||||
│ │ └── registry.py # ToolRegistry
|
||||
│ ├── memory/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── base.py # Memory 抽象基类
|
||||
│ │ ├── working.py # Working Memory (Redis)
|
||||
│ │ ├── episodic.py # Episodic Memory (pgvector+PG)
|
||||
│ │ ├── semantic.py # Semantic Memory (RAG+Graph)
|
||||
│ │ └── retriever.py # 混合检索器
|
||||
│ ├── evolution/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── reflector.py # 执行反思
|
||||
│ │ ├── prompt_optimizer.py # DSPy 风格 Prompt 优化器
|
||||
│ │ ├── strategy_tuner.py # 策略调优
|
||||
│ │ ├── ab_tester.py # A/B 测试框架
|
||||
│ │ └── evolution_store.py # 进化日志存储
|
||||
│ ├── orchestrator/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── pipeline_engine.py # DAG + 并行 Pipeline
|
||||
│ │ ├── pipeline_schema.py # Pipeline 数据模型
|
||||
│ │ ├── pipeline_loader.py # YAML 加载器
|
||||
│ │ ├── handoff.py # Handoff 机制
|
||||
│ │ └── dynamic_pipeline.py # 动态 Pipeline 组合
|
||||
│ ├── mcp/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── server.py # MCP Server
|
||||
│ │ ├── client.py # MCP Client
|
||||
│ │ └── transport.py # 传输层 (HTTP/SSE)
|
||||
│ └── prompts/
|
||||
│ ├── __init__.py
|
||||
│ ├── template.py # PromptTemplate
|
||||
│ └── section.py # PromptSection
|
||||
├── tests/
|
||||
│ ├── unit/
|
||||
│ │ ├── test_base_agent.py
|
||||
│ │ ├── test_config_driven.py
|
||||
│ │ ├── test_tool_registry.py
|
||||
│ │ ├── test_function_tool.py
|
||||
│ │ ├── test_mcp_tool.py
|
||||
│ │ ├── test_agent_tool.py
|
||||
│ │ ├── test_working_memory.py
|
||||
│ │ ├── test_episodic_memory.py
|
||||
│ │ ├── test_semantic_memory.py
|
||||
│ │ ├── test_memory_retriever.py
|
||||
│ │ ├── test_reflector.py
|
||||
│ │ ├── test_prompt_optimizer.py
|
||||
│ │ ├── test_strategy_tuner.py
|
||||
│ │ ├── test_ab_tester.py
|
||||
│ │ ├── test_pipeline_parallel.py
|
||||
│ │ ├── test_handoff.py
|
||||
│ │ ├── test_dynamic_pipeline.py
|
||||
│ │ ├── test_mcp_server.py
|
||||
│ │ └── test_mcp_client.py
|
||||
│ └── integration/
|
||||
│ ├── test_agent_lifecycle.py
|
||||
│ ├── test_tool_composition.py
|
||||
│ ├── test_evolution_loop.py
|
||||
│ └── test_mcp_roundtrip.py
|
||||
├── pyproject.toml
|
||||
└── README.md
|
||||
|
||||
geo/backend/ # GEO 业务系统(改造后)
|
||||
├── app/
|
||||
│ ├── agent_framework/ # 保留为适配层
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── agents/ # 改为配置驱动
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── configs/ # Agent YAML 配置
|
||||
│ │ │ │ ├── citation_detector.yaml
|
||||
│ │ │ │ ├── content_generator.yaml
|
||||
│ │ │ │ ├── deai_agent.yaml
|
||||
│ │ │ │ ├── geo_optimizer.yaml
|
||||
│ │ │ │ ├── monitor.yaml
|
||||
│ │ │ │ ├── schema_advisor.yaml
|
||||
│ │ │ │ ├── competitor_analyzer.yaml
|
||||
│ │ │ │ └── trend_agent.yaml
|
||||
│ │ │ └── custom_handlers/ # 仅复杂 Agent 需自定义 handler
|
||||
│ │ │ ├── citation_handler.py
|
||||
│ │ │ └── monitor_handler.py
|
||||
│ │ ├── tools/ # 业务 Tool 注册
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── citation_tools.py
|
||||
│ │ │ ├── content_tools.py
|
||||
│ │ │ ├── knowledge_tools.py
|
||||
│ │ │ ├── monitor_tools.py
|
||||
│ │ │ └── schema_tools.py
|
||||
│ │ └── prompts/ # Prompt 模板(保留,可热更新)
|
||||
│ ├── models/
|
||||
│ │ └── agent.py # 新增 evolution_logs, episodic_memories 表
|
||||
│ └── ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. 独立包脚手架与 BaseAgent 重构
|
||||
|
||||
**Goal:** 创建 `fischer-agentkit` 独立包,重构 BaseAgent 将 execute 模板代码上移,子类只需实现 `handle_task()`。
|
||||
|
||||
**Dependencies:** 无
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/core/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/core/base.py`
|
||||
- `fischer-agentkit/src/agentkit/core/protocol.py`
|
||||
- `fischer-agentkit/src/agentkit/core/exceptions.py`
|
||||
- `fischer-agentkit/src/agentkit/core/registry.py`
|
||||
- `fischer-agentkit/src/agentkit/core/dispatcher.py`
|
||||
- `fischer-agentkit/pyproject.toml`
|
||||
- `fischer-agentkit/tests/unit/test_base_agent.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. 创建 `fischer-agentkit` 包,使用 `src/` layout,`pyproject.toml` 声明依赖(`redis[hiredis]`, `pydantic>=2.0`, `sqlalchemy[asyncio]>=2.0`)
|
||||
2. 从 GEO 项目迁移 `protocol.py`(AgentCapability, TaskMessage, TaskResult, TaskProgress, TaskStatus, AgentStatus),移除 AgentType 硬编码枚举,改为动态注册
|
||||
3. 重构 `BaseAgent`:
|
||||
- `execute()` 方法变为 final(非抽象),包含完整的计时、try/except、TaskResult 构建、进度上报
|
||||
- 新增抽象方法 `handle_task(task) -> dict`,子类只需返回 output_data
|
||||
- 新增生命周期钩子:`on_task_start()`, `on_task_complete()`, `on_task_failed()`
|
||||
- 新增 `tools: list[Tool]` 属性(默认空,U3 实现)
|
||||
- 新增 `memory: Memory` 属性(默认空,U4 实现)
|
||||
4. 迁移 `registry.py`, `dispatcher.py`, `exceptions.py`,去除 GEO 业务依赖
|
||||
5. `AgentRegistry.get_available_agent()` 增加负载均衡策略(轮询/最少任务)
|
||||
|
||||
**Patterns to follow:** 现有 `backend/app/agent_framework/base.py` 的双模式(Redis/本地)设计
|
||||
|
||||
**Test scenarios:**
|
||||
- BaseAgent 子类实现 handle_task 返回 dict,execute 自动包装为 TaskResult
|
||||
- handle_task 抛异常时 execute 自动构建 FAILED TaskResult
|
||||
- on_task_start/complete/failed 钩子按序调用
|
||||
- AgentRegistry 动态注册新 AgentType 不报错
|
||||
- AgentRegistry.get_available_agent 轮询策略返回不同 Agent
|
||||
|
||||
**Verification:** `fischer-agentkit` 可独立 `pip install -e .`,单元测试全部通过
|
||||
|
||||
---
|
||||
|
||||
### U2. Protocol 扩展与配置驱动 Agent
|
||||
|
||||
**Goal:** 扩展 Protocol 支持 JSON Schema 校验、Handoff 消息;实现 ConfigDrivenAgent,支持 YAML/数据库配置驱动 Agent 定义。
|
||||
|
||||
**Dependencies:** U1
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/core/protocol.py`(扩展)
|
||||
- `fischer-agentkit/src/agentkit/core/config_driven.py`
|
||||
- `fischer-agentkit/src/agentkit/core/standalone.py`
|
||||
- `fischer-agentkit/src/agentkit/prompts/template.py`
|
||||
- `fischer-agentkit/src/agentkit/prompts/section.py`
|
||||
- `fischer-agentkit/tests/unit/test_config_driven.py`
|
||||
- `fischer-agentkit/tests/unit/test_protocol.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Protocol 扩展:
|
||||
- `AgentCapability` 增加 `input_schema: dict | None`、`output_schema: dict | None`(JSON Schema)
|
||||
- 新增 `HandoffMessage` 数据类(source_agent, target_agent, task_context, reason)
|
||||
- 新增 `EvolutionEvent` 数据类(agent_name, change_type, before, after, metrics)
|
||||
- `TaskMessage` 增加 `conversation_id: str | None` 支持多轮对话
|
||||
2. ConfigDrivenAgent:
|
||||
- 接受 YAML 配置或数据库配置,自动组装:Prompt 模板 + LLM 参数 + Tool 绑定 + Memory 策略
|
||||
- 内置 LLM 调用 + JSON 输出解析 + 降级兜底的标准流程
|
||||
- 支持三种任务模式:`llm_generate`(Prompt → LLM → 解析)、`tool_call`(调用指定 Tool)、`custom`(自定义 handler)
|
||||
3. PromptTemplate 迁移并增强:支持动态 section 组合、版本标记
|
||||
4. Standalone runner 改为自动发现:扫描 `agent_configs/` 目录下的 YAML 文件,自动注册和启动
|
||||
|
||||
**Patterns to follow:** 现有 `prompts/base_template.py` 的 PromptSection 设计
|
||||
|
||||
**Test scenarios:**
|
||||
- ConfigDrivenAgent 从 YAML 加载配置并正确组装
|
||||
- llm_generate 模式:渲染 Prompt → 调用 Mock LLM → 解析 JSON 输出
|
||||
- tool_call 模式:调用注册的 FunctionTool 并返回结果
|
||||
- input_schema 校验:缺少必填字段时返回校验错误
|
||||
- HandoffMessage 序列化/反序列化正确
|
||||
- 自动发现:在 configs/ 目录放置 YAML 后 standalone 可自动加载
|
||||
|
||||
**Verification:** 通过 YAML 配置定义一个新 Agent,无需写 Python 代码即可运行
|
||||
|
||||
---
|
||||
|
||||
### U3. Tool/Skill 插件系统
|
||||
|
||||
**Goal:** 实现 Tool 抽象基类、FunctionTool、AgentTool、ToolRegistry,支持工具注册、发现、组合。
|
||||
|
||||
**Dependencies:** U1
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/tools/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/tools/base.py`
|
||||
- `fischer-agentkit/src/agentkit/tools/function_tool.py`
|
||||
- `fischer-agentkit/src/agentkit/tools/agent_tool.py`
|
||||
- `fischer-agentkit/src/agentkit/tools/registry.py`
|
||||
- `fischer-agentkit/tests/unit/test_tool_registry.py`
|
||||
- `fischer-agentkit/tests/unit/test_function_tool.py`
|
||||
- `fischer-agentkit/tests/unit/test_agent_tool.py`
|
||||
- `fischer-agentkit/tests/integration/test_tool_composition.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. `Tool` 抽象基类:
|
||||
- 属性:`name`, `description`, `input_schema`(JSON Schema), `output_schema`(JSON Schema), `version`, `tags`
|
||||
- 抽象方法:`async execute(**kwargs) -> dict`
|
||||
- 生命周期钩子:`before_execute()`, `after_execute()`, `on_error()`
|
||||
2. `FunctionTool`:包装普通 Python 函数为 Tool,自动从函数签名推断 input_schema
|
||||
3. `AgentTool`:包装另一个 Agent 为 Tool,输入/输出通过 mapping 适配,内部通过 Dispatcher 分发任务
|
||||
4. `ToolRegistry`:
|
||||
- `register(tool)`, `unregister(name)`, `get(name)`, `list_tools(tag=None)`
|
||||
- 支持版本管理:同一工具多版本共存,默认使用最新版
|
||||
- 支持标签分类:`[citation, analysis, generation, optimization]`
|
||||
5. 工具组合:
|
||||
- `SequentialChain([tool_a, tool_b])`:顺序执行,前一个输出作为后一个输入
|
||||
- `ParallelFanOut([tool_a, tool_b, tool_c])`:并行执行,结果合并
|
||||
- `DynamicSelector(llm, tools)`:LLM 根据任务动态选择工具
|
||||
|
||||
**Patterns to follow:** 现有 `LLMFactory` 的注册-创建模式
|
||||
|
||||
**Test scenarios:**
|
||||
- FunctionTool 从函数自动推断 schema 并执行
|
||||
- AgentTool 分发任务到目标 Agent 并返回结果
|
||||
- ToolRegistry 注册/发现/按标签过滤
|
||||
- SequentialChain 顺序执行两个工具
|
||||
- ParallelFanOut 并行执行三个工具
|
||||
- DynamicSelector 根据 LLM 判断选择合适工具
|
||||
- 工具版本管理:注册 v1 和 v2,默认返回 v2
|
||||
|
||||
**Verification:** 通过 ToolRegistry 注册 3 个 FunctionTool,Agent 可声明式绑定并调用
|
||||
|
||||
---
|
||||
|
||||
### U4. 记忆系统
|
||||
|
||||
**Goal:** 实现三层记忆系统(Working/Episodic/Semantic)和混合检索器。
|
||||
|
||||
**Dependencies:** U1
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/memory/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/memory/base.py`
|
||||
- `fischer-agentkit/src/agentkit/memory/working.py`
|
||||
- `fischer-agentkit/src/agentkit/memory/episodic.py`
|
||||
- `fischer-agentkit/src/agentkit/memory/semantic.py`
|
||||
- `fischer-agentkit/src/agentkit/memory/retriever.py`
|
||||
- `fischer-agentkit/tests/unit/test_working_memory.py`
|
||||
- `fischer-agentkit/tests/unit/test_episodic_memory.py`
|
||||
- `fischer-agentkit/tests/unit/test_semantic_memory.py`
|
||||
- `fischer-agentkit/tests/unit/test_memory_retriever.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. `Memory` 抽象基类:
|
||||
- `async store(key, value, metadata)`, `async retrieve(query, top_k)`, `async delete(key)`
|
||||
- `async search(query, top_k, filters)` — 语义检索
|
||||
2. `WorkingMemory`(Redis):
|
||||
- 以 `agent:{name}:working_memory:{task_id}` 为 key 前缀
|
||||
- 支持自动过期(TTL = 任务超时时间 × 2)
|
||||
- 提供 `get_context()` 方法,返回格式化的上下文字符串
|
||||
3. `EpisodicMemory`(pgvector + PostgreSQL):
|
||||
- 表结构:`id, agent_name, task_type, input_summary, output_summary, outcome(success/fail), quality_score, reflection, embedding, created_at`
|
||||
- 写入时自动生成 embedding(复用 Embedder 接口)
|
||||
- 检索:向量语义 + 关键词 + 时间衰减 + RRF 融合
|
||||
4. `SemanticMemory`:
|
||||
- 适配器模式,对接 GEO 项目的 `RAGService` 和 `GraphQuery`
|
||||
- 提供统一的 `search(query, knowledge_base_ids, top_k)` 接口
|
||||
5. `MemoryRetriever`(混合检索器):
|
||||
- 并行查询三层记忆,按权重融合排序
|
||||
- 时间衰减:`score *= exp(-0.01 * age_hours)`
|
||||
- 上下文窗口管理:总 token 不超过预算
|
||||
|
||||
**Patterns to follow:** 现有 `HybridRetriever` 的 RRF 融合排序模式
|
||||
|
||||
**Test scenarios:**
|
||||
- WorkingMemory 存取任务上下文,TTL 过期后自动清除
|
||||
- EpisodicMemory 写入任务经验并按语义检索相似案例
|
||||
- EpisodicMemory 时间衰减:近期经验权重高于远期
|
||||
- SemanticMemory 通过适配器调用 RAGService
|
||||
- MemoryRetriever 混合检索三层记忆并按权重融合
|
||||
- 上下文窗口管理:检索结果超过 token 预算时智能截断
|
||||
|
||||
**Verification:** Agent 执行任务后自动写入 EpisodicMemory,后续相似任务可检索到历史经验
|
||||
|
||||
---
|
||||
|
||||
### U5. MCP Server 与 Client
|
||||
|
||||
**Goal:** 实现 MCP Server(暴露 Agent 能力)和 MCP Client(调用外部 MCP 工具),完成 MCPTool。
|
||||
|
||||
**Dependencies:** U3
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/mcp/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/mcp/server.py`
|
||||
- `fischer-agentkit/src/agentkit/mcp/client.py`
|
||||
- `fischer-agentkit/src/agentkit/mcp/transport.py`
|
||||
- `fischer-agentkit/src/agentkit/tools/mcp_tool.py`
|
||||
- `fischer-agentkit/tests/unit/test_mcp_server.py`
|
||||
- `fischer-agentkit/tests/unit/test_mcp_client.py`
|
||||
- `fischer-agentkit/tests/unit/test_mcp_tool.py`
|
||||
- `fischer-agentkit/tests/integration/test_mcp_roundtrip.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. MCP Server:
|
||||
- 基于 FastAPI 实现,支持 Streamable HTTP 传输
|
||||
- 端点:`/tools/list`(列出可用工具)、`/tools/call`(调用工具)、`/resources/read`(读取资源)
|
||||
- 自动将 ToolRegistry 中注册的工具暴露为 MCP 工具
|
||||
- 支持 SSE 流式响应
|
||||
2. MCP Client:
|
||||
- 连接外部 MCP Server(HTTP/SSE)
|
||||
- 自动发现远程工具并注册到本地 ToolRegistry
|
||||
- 支持工具调用的流式响应
|
||||
3. MCPTool:
|
||||
- 继承 Tool 基类,内部通过 MCP Client 调用远程工具
|
||||
- 自动从 MCP Server 获取 input_schema
|
||||
4. Transport 层:
|
||||
- 抽象 `Transport` 接口(`send_request`, `receive_response`)
|
||||
- 实现 `HTTPTransport`(Streamable HTTP)和 `SSETransport`(Server-Sent Events)
|
||||
|
||||
**Patterns to follow:** MCP 官方 Python SDK 的 Server/Client 模式
|
||||
|
||||
**Test scenarios:**
|
||||
- MCP Server 启动后 `/tools/list` 返回已注册工具列表
|
||||
- MCP Client 连接 Server 并调用工具返回正确结果
|
||||
- MCPTool 通过 Client 调用远程工具
|
||||
- SSE 流式响应正确传输
|
||||
- MCP Server 自动暴露 ToolRegistry 中的 FunctionTool
|
||||
- 外部 MCP Server 的工具自动注册到本地 ToolRegistry
|
||||
|
||||
**Verification:** 启动 MCP Server,通过 MCP Client 调用 Agent 能力,端到端成功
|
||||
|
||||
---
|
||||
|
||||
### U6. 自我进化引擎(Level 3)
|
||||
|
||||
**Goal:** 实现执行反思、DSPy 风格 Prompt 自动优化、策略调优、A/B 测试框架。
|
||||
|
||||
**Dependencies:** U1, U4
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/evolution/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/evolution/reflector.py`
|
||||
- `fischer-agentkit/src/agentkit/evolution/prompt_optimizer.py`
|
||||
- `fischer-agentkit/src/agentkit/evolution/strategy_tuner.py`
|
||||
- `fischer-agentkit/src/agentkit/evolution/ab_tester.py`
|
||||
- `fischer-agentkit/src/agentkit/evolution/evolution_store.py`
|
||||
- `fischer-agentkit/tests/unit/test_reflector.py`
|
||||
- `fischer-agentkit/tests/unit/test_prompt_optimizer.py`
|
||||
- `fischer-agentkit/tests/unit/test_strategy_tuner.py`
|
||||
- `fischer-agentkit/tests/unit/test_ab_tester.py`
|
||||
- `fischer-agentkit/tests/integration/test_evolution_loop.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. `Reflector`(执行反思):
|
||||
- 每次任务完成后自动评估:成功/失败、质量评分(基于 output_schema 约束和业务指标)
|
||||
- 提取模式:常见失败原因、高效策略、低效路径
|
||||
- 生成反思总结存入 EpisodicMemory
|
||||
2. `PromptOptimizer`(DSPy 风格编译器):
|
||||
- `Signature`:定义 `input_fields -> output_fields` 的结构化签名
|
||||
- `Module`:可组合的 Prompt 策略(ChainOfThought, ReAct, FewShot)
|
||||
- `Optimizer`:
|
||||
- `BootstrapFewShot`:从成功案例中自动构建 few-shot 示例
|
||||
- `MIPROv2`:多目标 Prompt 优化(指令 + few-shot 联合优化)
|
||||
- 基于历史任务数据编译,产出优化后的 Prompt
|
||||
3. `StrategyTuner`(策略调优):
|
||||
- 可调参数:temperature, tool_selection_weights, pipeline_path
|
||||
- 基于 Bayesian Optimization 搜索最优参数组合
|
||||
- 每次调优记录 before/after 指标
|
||||
4. `ABTester`(A/B 测试框架):
|
||||
- 支持配置分流比例(如 80% 原版 / 20% 实验版)
|
||||
- 自动收集实验组和对照组的效果指标
|
||||
- 统计显著性检验(t-test),达到置信度后自动决策
|
||||
5. `EvolutionStore`(进化日志):
|
||||
- 表结构:`id, agent_name, change_type(prompt/strategy/pipeline), before, after, ab_test_id, status(active/rolled_back), created_at`
|
||||
- 支持回滚:`rollback(evolution_id)`
|
||||
|
||||
**Patterns to follow:** DSPy 的 Signature/Module/Optimizer 三层架构
|
||||
|
||||
**Test scenarios:**
|
||||
- Reflector 评估成功任务生成正面反思
|
||||
- Reflector 评估失败任务提取失败模式
|
||||
- PromptOptimizer 从 10 个成功案例自动生成 few-shot 示例
|
||||
- PromptOptimizer 优化后的 Prompt 在测试集上效果提升
|
||||
- StrategyTuner 调整 temperature 后效果改善
|
||||
- ABTester 80/20 分流,实验组效果显著后自动应用
|
||||
- EvolutionStore 记录变更并支持回滚
|
||||
|
||||
**Verification:** Agent 执行 20 次任务后,PromptOptimizer 自动优化 Prompt,ABTest 验证效果提升后自动应用
|
||||
|
||||
---
|
||||
|
||||
### U7. 多 Agent 协同增强
|
||||
|
||||
**Goal:** Pipeline Engine 支持并行执行、Handoff 机制、动态 Pipeline 组合。
|
||||
|
||||
**Dependencies:** U1, U2
|
||||
|
||||
**Files:**
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/__init__.py`
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_engine.py`
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_schema.py`
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_loader.py`
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/handoff.py`
|
||||
- `fischer-agentkit/src/agentkit/orchestrator/dynamic_pipeline.py`
|
||||
- `fischer-agentkit/tests/unit/test_pipeline_parallel.py`
|
||||
- `fischer-agentkit/tests/unit/test_handoff.py`
|
||||
- `fischer-agentkit/tests/unit/test_dynamic_pipeline.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Pipeline 并行执行:
|
||||
- 拓扑排序后按依赖层级分组(同层无依赖)
|
||||
- 同层内使用 `asyncio.gather` 并行执行
|
||||
- 并行 stage 的结果合并后传给下一层
|
||||
2. Handoff 机制:
|
||||
- `HandoffManager` 管理转交请求
|
||||
- Agent 调用 `self.handoff(target_agent, context, reason)` 发起转交
|
||||
- 通过 Redis Pub/Sub `agent:{target}:handoff` 频道通知目标 Agent
|
||||
- 目标 Agent 接收后创建新任务,携带源 Agent 的上下文
|
||||
- 源 Agent 的任务状态变为 `HANDOFF`
|
||||
3. 动态 Pipeline:
|
||||
- 支持 `sub_pipeline` stage 类型:引用另一个 YAML Pipeline
|
||||
- 支持 `conditional_pipeline`:根据运行时条件选择子流程
|
||||
- 支持 `loop_pipeline`:循环执行直到条件满足
|
||||
4. 事件驱动替代轮询:
|
||||
- Pipeline Engine 通过 Redis Pub/Sub 订阅 `agent:{name}:result` 频道
|
||||
- 任务完成后自动触发下一 stage,无需轮询
|
||||
|
||||
**Patterns to follow:** 现有 `PipelineEngine` 的 DAG 拓扑排序和变量解析
|
||||
|
||||
**Test scenarios:**
|
||||
- 3 个无依赖 stage 并行执行,总耗时约等于最长单个 stage
|
||||
- Handoff:Agent A 转交任务给 Agent B,B 接收并执行
|
||||
- 动态 Pipeline:根据条件选择不同的子流程
|
||||
- 子 Pipeline 嵌套执行并正确传递变量
|
||||
- 事件驱动:任务完成后自动触发下一 stage,无轮询间隔
|
||||
- 循环 Pipeline:条件满足后退出循环
|
||||
|
||||
**Verification:** 内容生产 Pipeline 的 deai_processing 和 geo_optimization 并行执行,总耗时减少约 40%
|
||||
|
||||
---
|
||||
|
||||
### U8. GEO 业务系统适配与迁移
|
||||
|
||||
**Goal:** 将 GEO 项目的 8 个 Agent 迁移到新框架,业务 Service 注册为 Tool,实现完全解耦。
|
||||
|
||||
**Dependencies:** U1, U2, U3, U4, U5, U6, U7
|
||||
|
||||
**Files:**
|
||||
- `geo/backend/app/agent_framework/agents/configs/citation_detector.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/content_generator.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/deai_agent.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/geo_optimizer.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/monitor.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/schema_advisor.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/competitor_analyzer.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/configs/trend_agent.yaml`
|
||||
- `geo/backend/app/agent_framework/agents/custom_handlers/citation_handler.py`
|
||||
- `geo/backend/app/agent_framework/agents/custom_handlers/monitor_handler.py`
|
||||
- `geo/backend/app/agent_framework/tools/__init__.py`
|
||||
- `geo/backend/app/agent_framework/tools/citation_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/content_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/knowledge_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/monitor_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/schema_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/competitor_tools.py`
|
||||
- `geo/backend/app/agent_framework/tools/trend_tools.py`
|
||||
- `geo/backend/app/models/agent.py`(新增表)
|
||||
- `geo/backend/requirements.txt`(添加 fischer-agentkit 依赖)
|
||||
- `geo/backend/app/agent_framework/__init__.py`(适配层)
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. 业务 Tool 注册:
|
||||
- `CitationTools`:`execute_single_platform`, `get_or_create_task`, `calculate_next_query_at`
|
||||
- `ContentTools`:`retrieve_knowledge`(包装 RAGService)
|
||||
- `MonitorTools`:`check_and_compare`, `generate_change_report`, `create_monitoring_record`
|
||||
- `SchemaTools`:`identify_missing_dimensions`, `match_templates`, `validate_json_ld`
|
||||
- `CompetitorTools`:`analyze_competitor`
|
||||
- `TrendTools`:`analyze_trends`, `get_hotspots`
|
||||
2. Agent 配置迁移:
|
||||
- LLM 驱动型 Agent(ContentGenerator, DeAI, GEOOptimizer, SchemaAdvisor)→ YAML 配置 + `llm_generate` 模式,零代码
|
||||
- Service 代理型 Agent(CitationDetector, Monitor)→ YAML 配置 + `custom` handler,仅保留业务逻辑方法
|
||||
- CompetitorAnalyzer, TrendAgent → YAML 配置 + `tool_call` 模式
|
||||
3. 数据库迁移:
|
||||
- 新增 `episodic_memories` 表
|
||||
- 新增 `evolution_logs` 表
|
||||
- 新增 `ab_test_configs` 表
|
||||
4. 适配层:
|
||||
- `geo/backend/app/agent_framework/` 保留为适配层,import from `agentkit`
|
||||
- 现有 API 端点(`/agents/`, `/agents/tasks/`)保持不变
|
||||
- 现有 Pipeline YAML 保持兼容
|
||||
|
||||
**Patterns to follow:** 现有 Agent 的业务逻辑实现,迁移时保持行为一致
|
||||
|
||||
**Test scenarios:**
|
||||
- 8 个 Agent 迁移后行为与原版一致(回归测试)
|
||||
- 新增 Agent 只需 YAML 配置,无需写 Python 代码
|
||||
- 业务 Tool 注册后 Agent 可通过 ToolRegistry 发现和调用
|
||||
- EpisodicMemory 自动记录任务经验
|
||||
- EvolutionStore 记录进化变更
|
||||
- 现有 API 端点功能不变
|
||||
- 现有 Pipeline YAML 正常执行
|
||||
|
||||
**Verification:** 运行现有 Agent 集成测试全部通过,新增一个 YAML-only Agent 成功执行任务
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
|
||||
- `fischer-agentkit` 独立包的完整实现(Core, Tools, Memory, Evolution, Orchestrator, MCP)
|
||||
- GEO 项目 8 个 Agent 的迁移和适配
|
||||
- MCP Server/Client 双向支持
|
||||
- Level 3 自我进化(反思 + Prompt 优化 + 策略调优 + A/B 测试)
|
||||
- Pipeline 并行执行、Handoff、动态 Pipeline
|
||||
- 三层记忆系统
|
||||
- 数据库迁移(新增表)
|
||||
|
||||
### Deferred for Later
|
||||
|
||||
- A2A Protocol 支持(跨组织 Agent 协作,待 MCP 稳定后再评估)
|
||||
- Agent 可视化编排 UI(前端拖拽式 Pipeline 编辑器)
|
||||
- Agent 市场/共享机制(跨组织共享 Agent 配置和 Tool)
|
||||
- 多租户隔离增强(Agent 级别的资源隔离和配额)
|
||||
- Agent 安全沙箱(Tool 执行的权限控制和审计)
|
||||
|
||||
### Outside This Project's Identity
|
||||
|
||||
- 通用 LLM Gateway(不属于 Agent 框架范畴)
|
||||
- 替代现有 LLMFactory(保持现有 LLM 调用体系不变)
|
||||
- 前端 Agent 管理界面重构(本次聚焦后端架构)
|
||||
|
||||
---
|
||||
|
||||
## Risks & Dependencies
|
||||
|
||||
| Risk | Impact | Mitigation |
|
||||
|------|--------|------------|
|
||||
| MCP 协议规范变动(2025-2026 仍在演进) | MCPTool/MCPClient 需要跟进修改 | 抽象 Transport 层,协议变更只影响 Transport 实现 |
|
||||
| DSPy 风格 Prompt 优化可能需要大量训练数据 | 优化效果不明显 | BootstrapFewShot 从少量数据开始,MIPROv2 需 50+ 案例时才启用 |
|
||||
| 独立包与 GEO 项目的版本同步 | GEO 升级 agentkit 版本可能引入 breaking change | 语义化版本管理,GEO 项目 pin 版本,CI 自动测试兼容性 |
|
||||
| Episodic Memory 数据量增长 | pgvector 查询性能下降 | 设置 TTL 自动清理、定期归档、HNSW 索引优化 |
|
||||
| 迁移期间 8 个 Agent 行为不一致 | 业务回归 | 逐个迁移 + 回归测试,保留旧代码直到新代码验证通过 |
|
||||
| Pipeline 并行执行引入并发问题 | 结果不一致 | 同层 stage 只读共享变量,写入隔离 |
|
||||
|
||||
---
|
||||
|
||||
## System-Wide Impact
|
||||
|
||||
- **数据库**:新增 3 张表(episodic_memories, evolution_logs, ab_test_configs),需 Alembic 迁移
|
||||
- **Redis**:新增 Working Memory key 空间和 Handoff 频道,内存使用增加约 20%
|
||||
- **API**:现有 `/agents/` 端点保持兼容,新增 `/agents/{name}/evolution` 查看进化历史
|
||||
- **部署**:`fischer-agentkit` 需发布到私有 PyPI,GEO 项目 Dockerfile 需添加安装步骤
|
||||
- **监控**:新增进化相关 Prometheus 指标(evolution_attempts_total, evolution_success_rate, ab_test_decisions)
|
||||
- **依赖**:新增 `mcp` Python 包依赖(MCP SDK)
|
||||
|
||||
---
|
||||
|
||||
## Sources & Research
|
||||
|
||||
- 现有 Agent 框架代码:`backend/app/agent_framework/` 全部文件
|
||||
- 现有 RAG 服务:`backend/app/services/knowledge/rag_service.py`, `retriever.py`, `enhanced_rag.py`
|
||||
- 现有 LLM 工厂:`backend/app/services/llm/factory.py`, `base.py`
|
||||
- MCP 官方规范:Model Context Protocol (Anthropic, 2024-2025)
|
||||
- DSPy 框架:Stanford NLP, Prompt 自动优化范式
|
||||
- Google A2A Protocol:Agent-to-Agent HTTP 协议 (2025.4)
|
||||
- LangGraph 0.2.x:StateGraph + Checkpoint 架构
|
||||
- OpenAI Agents SDK 1.0:Handoff 模式
|
||||
- Google ADK 0.1:Agent 组合模式(Sequential/Parallel/Loop)
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -1,35 +1,4 @@
|
|||
{
|
||||
"status": "failed",
|
||||
"failedTests": [
|
||||
"49bb04a247fb3a93c942-35c9c801dc12bf7d35f2",
|
||||
"49bb04a247fb3a93c942-698fa46448ec3c4fcc54",
|
||||
"49bb04a247fb3a93c942-1e4abb88d57e30e27b50",
|
||||
"49bb04a247fb3a93c942-e8f0348aed6f07ff875b",
|
||||
"49bb04a247fb3a93c942-425d1f1370bc2355e391",
|
||||
"49bb04a247fb3a93c942-256ee0a55d33f5229d15",
|
||||
"49bb04a247fb3a93c942-5f9ad201eca7dc8ab295",
|
||||
"49bb04a247fb3a93c942-ef50deff2e29416bfdcb",
|
||||
"49bb04a247fb3a93c942-ea929ccac5306e953e14",
|
||||
"49bb04a247fb3a93c942-fa6e92367349adbf7ce6",
|
||||
"49bb04a247fb3a93c942-66abb2794b5dc22a4d2c",
|
||||
"49bb04a247fb3a93c942-c8f2635d09bd86abde8c",
|
||||
"49bb04a247fb3a93c942-495ca225b90357a78c73",
|
||||
"49bb04a247fb3a93c942-9542aeae6bc1bec1de36",
|
||||
"49bb04a247fb3a93c942-6d102a4fe6c78e5adea3",
|
||||
"49bb04a247fb3a93c942-16202bb8288239216298",
|
||||
"49bb04a247fb3a93c942-103e4c5c525b115390ab",
|
||||
"49bb04a247fb3a93c942-8ff6f91ccaf60e9cf222",
|
||||
"49bb04a247fb3a93c942-594d166a39b1685af15e",
|
||||
"c05d59a3df11e753f1a9-ebd92d58fe79ae1246f9",
|
||||
"30f4937bf66a39abf0e1-654253a85d12a575b58b",
|
||||
"30f4937bf66a39abf0e1-531589fe51a02f263015",
|
||||
"655e7220bbb90e50e0e9-a2f8fcdf1070b299cd54",
|
||||
"655e7220bbb90e50e0e9-49b697c7e1ba346c76b7",
|
||||
"655e7220bbb90e50e0e9-c5cd58b8cddad9907a64",
|
||||
"655e7220bbb90e50e0e9-696b9830da3889bddb0f",
|
||||
"655e7220bbb90e50e0e9-e8bb52f185bd86dce2da",
|
||||
"fe534f0825407f213faa-71fec5acbf34a381ec37",
|
||||
"fe534f0825407f213faa-c40f01c5980780ee6589",
|
||||
"fe534f0825407f213faa-a6a9b88ab9323419b7d3"
|
||||
]
|
||||
"status": "passed",
|
||||
"failedTests": []
|
||||
}
|
||||
Loading…
Reference in New Issue