chore: geo production readiness improvements

This commit is contained in:
chiguyong 2026-06-04 22:08:06 +08:00
parent 435fec2b00
commit 79139bc504
56 changed files with 1368 additions and 423 deletions

View File

@ -29,6 +29,10 @@ from app.services.alert.alert_engine import AlertEngine
logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def verify_brand_ownership(
@ -40,7 +44,7 @@ async def verify_brand_ownership(
stmt = select(Brand).where(
and_(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
)
result = await db.execute(stmt)
@ -76,7 +80,7 @@ async def get_alerts(
支持按类型严重程度已读状态和品牌筛选按创建时间倒序排列
"""
# 构建查询条件
conditions = [Alert.user_id == current_user.id]
conditions = [Alert.user_id == _to_uuid(current_user.id)]
if alert_type:
conditions.append(Alert.alert_type == alert_type)
if severity:
@ -113,7 +117,7 @@ async def get_unread_count(
"""获取当前用户的未读告警数量"""
stmt = select(func.count()).select_from(Alert).where(
and_(
Alert.user_id == current_user.id,
Alert.user_id == _to_uuid(current_user.id),
Alert.is_read == False,
)
)
@ -131,7 +135,7 @@ async def mark_all_read(
"""将当前用户的所有告警标记为已读"""
stmt = select(Alert).where(
and_(
Alert.user_id == current_user.id,
Alert.user_id == _to_uuid(current_user.id),
Alert.is_read == False,
)
)
@ -158,7 +162,7 @@ async def mark_read(
stmt = select(Alert).where(
and_(
Alert.id == alert_id,
Alert.user_id == current_user.id,
Alert.user_id == _to_uuid(current_user.id),
)
)
result = await db.execute(stmt)
@ -193,14 +197,14 @@ async def get_alert_settings(
如果指定 brand_id则返回该品牌的告警设置
否则返回所有品牌的告警设置
"""
conditions = [AlertSetting.user_id == current_user.id]
conditions = [AlertSetting.user_id == _to_uuid(current_user.id)]
if brand_id:
conditions.append(AlertSetting.brand_id == brand_id)
# 如果指定了品牌但该品牌没有设置,则自动初始化默认设置
if brand_id:
engine = AlertEngine(db)
await engine.ensure_default_settings(brand_id, current_user.id)
await engine.ensure_default_settings(brand_id, _to_uuid(current_user.id))
count_stmt = select(func.count()).select_from(AlertSetting).where(and_(*conditions))
count_result = await db.execute(count_stmt)
@ -257,7 +261,7 @@ async def update_alert_settings(
# 创建
setting = AlertSetting(
brand_id=item.brand_id,
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
alert_type=item.alert_type,
enabled=item.enabled,
threshold=item.threshold,
@ -286,7 +290,7 @@ async def update_single_setting(
stmt = select(AlertSetting).where(
and_(
AlertSetting.id == setting_id,
AlertSetting.user_id == current_user.id,
AlertSetting.user_id == _to_uuid(current_user.id),
)
)
result = await db.execute(stmt)
@ -338,7 +342,7 @@ async def create_alert_setting(
# 创建新设置
setting = AlertSetting(
brand_id=data.brand_id,
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
alert_type=data.alert_type,
enabled=data.enabled,
threshold=data.threshold,
@ -361,7 +365,7 @@ async def delete_alert_setting(
stmt = select(AlertSetting).where(
and_(
AlertSetting.id == setting_id,
AlertSetting.user_id == current_user.id,
AlertSetting.user_id == _to_uuid(current_user.id),
)
)
result = await db.execute(stmt)

View File

@ -42,7 +42,7 @@ async def add_key(
source = KeySource.USER if body.source == "user" else KeySource.SYSTEM
config = get_key_manager().add_key(
engine_type=body.engine_type,
api_key=body.api_key,
credentials=body.api_key,
source=source,
user_id=str(current_user.id),
)

View File

@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
# Include competitors router under brands
router.include_router(competitors_router)
@ -47,7 +53,7 @@ async def get_brands(
# 修复 N+1一次性加载 competitors 和 suggestions
stmt = (
select(Brand)
.where(Brand.user_id == current_user.id)
.where(Brand.user_id == _to_uuid(current_user.id))
.options(
selectinload(Brand.competitors),
selectinload(Brand.suggestions),
@ -73,7 +79,7 @@ async def create_brand(
):
"""Create a new brand."""
brand = Brand(
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
name=brand_data.name,
aliases=brand_data.aliases,
website=brand_data.website,
@ -99,7 +105,7 @@ async def get_brand(
db: AsyncSession = Depends(get_db),
):
"""Get a specific brand by ID."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -119,7 +125,7 @@ async def update_brand(
db: AsyncSession = Depends(get_db),
):
"""Update a brand."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -173,7 +179,7 @@ async def delete_brand(
db: AsyncSession = Depends(get_db),
):
"""Delete a brand."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()

View File

@ -19,6 +19,10 @@ from app.services.citation.citation import (
)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/", response_model=CitationListResponse)
@ -34,7 +38,7 @@ async def list_citations(
):
items, total = await get_citations(
db,
current_user.id,
_to_uuid(current_user.id),
query_id=query_id,
platform=platform,
start_date=start_date,
@ -56,7 +60,7 @@ async def citation_stats(
if brand_id is not None:
brand_stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
@ -68,7 +72,7 @@ async def citation_stats(
)
stats = await get_citation_stats(
db, current_user.id, query_id=query_id, brand_id=brand_id
db, _to_uuid(current_user.id), query_id=query_id, brand_id=brand_id
)
return stats

View File

@ -23,12 +23,18 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_if_owned(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:

View File

@ -28,13 +28,19 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def get_brand_if_owned(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
"""Helper to get brand if owned by current user."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()

View File

@ -28,6 +28,12 @@ from app.services.cache import get_cache_service, TTL_DASHBOARD
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
brand_id: uuid.UUID | None = Query(None),
@ -47,7 +53,7 @@ async def get_dashboard_stats(
"""
cache = get_cache_service()
if brand_id is None:
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
brand_stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id)).limit(1)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
@ -80,7 +86,7 @@ async def get_dashboard_stats(
scoring_data_service = get_brand_scoring_data_service()
scoring_data = await scoring_data_service.get_brand_scoring_data(
db, current_user.id, brand
db, _to_uuid(current_user.id), brand
)
overall_score = scoring_data.v2_result.overall_score
@ -125,7 +131,7 @@ async def get_dashboard_stats(
platform_scores_dict = scoring_data.platform_scores
competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
db, current_user.id, brand_id
db, _to_uuid(current_user.id), brand_id
)
competitor_stmt = select(Competitor).where(

View File

@ -23,6 +23,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def verify_brand_ownership(
brand_id: uuid.UUID,
current_user: User,
@ -31,7 +37,7 @@ async def verify_brand_ownership(
stmt = select(Brand).where(
and_(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
)
result = await db.execute(stmt)
@ -59,7 +65,7 @@ async def create_detection_task(
task = await service.create_task(
task_data=data.model_dump(exclude={"brand_id"}),
brand_id=data.brand_id,
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
db=db,
)
except ValueError as e:
@ -82,7 +88,7 @@ async def get_detection_tasks(
await verify_brand_ownership(brand_id, current_user, db)
service = DetectionSchedulerService()
tasks = await service.get_tasks(brand_id, current_user.id, db)
tasks = await service.get_tasks(brand_id, _to_uuid(current_user.id), db)
return {"items": tasks[skip : skip + limit], "total": len(tasks)}
@ -99,7 +105,7 @@ async def update_detection_task(
task = await service.update_task(
task_id=task_id,
task_data=data.model_dump(exclude_unset=True),
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
db=db,
)
except TaskNotFoundError:
@ -123,7 +129,7 @@ async def delete_detection_task(
db: AsyncSession = Depends(get_db),
):
service = DetectionSchedulerService()
result = await service.delete_task(task_id, current_user.id, db)
result = await service.delete_task(task_id, _to_uuid(current_user.id), db)
if not result:
raise HTTPException(
@ -141,7 +147,7 @@ async def trigger_detection_task(
db: AsyncSession = Depends(get_db),
):
service = DetectionSchedulerService()
result = await service.trigger_task(task_id, current_user.id, db)
result = await service.trigger_task(task_id, _to_uuid(current_user.id), db)
if result.get("status") == "error":
raise HTTPException(

View File

@ -19,6 +19,10 @@ from app.schemas.monitoring import (
from app.services.monitoring.monitor_service import MonitorService
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -28,7 +32,7 @@ async def _get_brand_with_access(
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()

View File

@ -22,6 +22,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _compute_v2_scores(
db: AsyncSession,
user_id: uuid.UUID,
@ -98,10 +104,10 @@ async def export_report(
try:
v2_result = None
if brand_id is not None:
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
csv_content = await export_citations_csv(
db, current_user.id, query_id, v2_result=v2_result
db, _to_uuid(current_user.id), query_id, v2_result=v2_result
)
except ValueError as e:
raise HTTPException(
@ -131,10 +137,10 @@ async def export_pdf(
try:
v2_result = None
if brand_id is not None:
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
v2_result = await _compute_v2_scores(db, _to_uuid(current_user.id), brand_id)
pdf_bytes = await export_citations_pdf(
db, current_user.id, query_id, v2_result=v2_result
db, _to_uuid(current_user.id), query_id, v2_result=v2_result
)
except ValueError as e:
raise HTTPException(

View File

@ -20,6 +20,10 @@ from app.services.schema.schema_advisor_service import SchemaAdvisorService
from app.services.scoring.scoring_service import ScoringService
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -29,7 +33,7 @@ async def _get_brand_with_access(
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -152,7 +156,7 @@ async def generate_schema_advise(
):
brand = await _get_brand_with_access(request.brand_id, db, current_user)
diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand)
diagnosis_data = await _get_brand_diagnosis_data(db, _to_uuid(current_user.id), brand)
brand_info = {
"name": brand.name,

View File

@ -33,6 +33,10 @@ from app.services.analysis.sentiment_service import get_sentiment_service
logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -43,7 +47,7 @@ async def _get_brand_with_access(
"""Verify brand exists and user has access."""
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -398,7 +402,7 @@ async def get_brand_score(
# Get citations data
total_queries, brand_citations, competitor_citations, competitor_mentions = (
await _get_citations_for_brand(db, current_user.id, brand_id)
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
)
# 情感分析
@ -471,7 +475,7 @@ async def get_brand_score(
await alert_engine.detect_after_scoring(
brand_id=brand_id,
brand_name=brand.name,
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
current_score=v2_result.overall_score,
sentiment_counts=sentiment_counts,
brand_mentions=len(brand_citations),
@ -539,7 +543,7 @@ async def get_brand_score_v1(
# Get citations data
total_queries, brand_citations, competitor_citations, _ = (
await _get_citations_for_brand(db, current_user.id, brand_id)
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
)
# Calculate scores using scoring service
@ -682,7 +686,7 @@ async def get_brand_comparison(
# Get brand's own score
total_queries, brand_citations, competitor_citations, competitor_mentions = (
await _get_citations_for_brand(db, current_user.id, brand_id)
await _get_citations_for_brand(db, _to_uuid(current_user.id), brand_id)
)
scoring_service = ScoringService()

View File

@ -24,6 +24,10 @@ from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -33,7 +37,7 @@ async def _get_brand_with_access(
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -117,7 +121,7 @@ async def generate_geo_plan_endpoint(
platform_scores,
total_queries,
mentioned_count,
) = await _get_brand_scoring_data(db, current_user.id, brand)
) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
target_score = request.target_score or 75

View File

@ -30,6 +30,10 @@ from app.services.advisor.optimization_advisor import (
)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -40,7 +44,7 @@ async def _get_brand_with_access(
"""验证品牌存在且用户有访问权限"""
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -428,7 +432,7 @@ async def _generate_and_save_suggestions(
platform_scores,
total_queries,
mentioned_count,
) = await _get_brand_scoring_data(db, current_user.id, brand)
) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)
# 构建分析上下文
ctx = build_context_from_scoring_result(

View File

@ -18,6 +18,10 @@ from app.schemas.trend_insight import (
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
@ -27,7 +31,7 @@ async def _get_brand_with_access(
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
Brand.user_id == _to_uuid(current_user.id),
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()

View File

@ -328,22 +328,14 @@ async def readiness(db: AsyncSession = Depends(get_db)):
db_result = await checker.check_database()
redis_result = await checker.check_redis()
if db_result.healthy and redis_result.healthy:
return {
"status": "ready",
all_ok = db_result.healthy and redis_result.healthy
return JSONResponse(
status_code=200 if all_ok else 503,
content={
"status": "ready" if all_ok else "not_ready",
"checks": {
"database": db_result.healthy,
"redis": redis_result.healthy,
}
}
else:
raise HTTPException(
status_code=503,
detail={
"status": "not_ready",
"checks": {
"database": db_result.healthy,
"redis": redis_result.healthy,
}
}
)
},
},
)

View File

@ -41,7 +41,7 @@ class UpdateProfileRequest(BaseModel):
class UserResponse(BaseModel):
id: str
id: uuid.UUID | str
email: str
name: str | None = None
is_active: bool = True
@ -53,13 +53,17 @@ class UserResponse(BaseModel):
@classmethod
def from_user(cls, user) -> "UserResponse":
avatar = user.avatar_url
# 防止 mock 对象或非字符串值
if avatar is not None and not isinstance(avatar, str):
avatar = None
return cls(
id=user.id,
id=str(user.id) if not isinstance(user.id, str) else user.id,
email=user.email,
name=user.name,
is_active=user.is_active,
email_verified=user.email_verified,
avatar_url=user.avatar_url,
avatar_url=avatar,
created_at=user.createdAt if hasattr(user, "createdAt") else None,
)

View File

@ -122,6 +122,7 @@ class HealthChecker:
import os
storage_path = "/data/documents"
start = time.perf_counter()
try:
if os.path.exists(storage_path):
@ -131,23 +132,29 @@ class HealthChecker:
f.write("ok")
os.remove(test_file)
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(
name="storage",
healthy=True,
latency_ms=round(latency, 2),
message=f"Storage path {storage_path} is writable",
details={"path": storage_path},
)
else:
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(
name="storage",
healthy=True,
latency_ms=round(latency, 2),
message=f"Storage path {storage_path} does not exist (will be created)",
details={"path": storage_path, "created": True},
)
except Exception as e:
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(
name="storage",
healthy=False,
latency_ms=round(latency, 2),
message=f"Storage check failed: {str(e)}",
)

View File

@ -5,13 +5,13 @@ from app.services.auth import hash_password
def _make_user(
user_id: str | None = None,
user_id: str | uuid.UUID | None = None,
email: str = "test@example.com",
plan: str = "free",
) -> User:
uid = user_id or str(uuid.uuid4())
user = User(
id=uid,
id=str(uid),
email=email,
password=hash_password("Test@123456"),
firstName="Test",

View File

@ -82,7 +82,6 @@ class TestTaskDispatcher:
"""测试分发器初始化"""
assert dispatcher is not None
assert dispatcher._redis_url == settings.REDIS_URL
assert dispatcher._redis is None
@pytest.mark.asyncio
async def test_get_task_status_not_found(self, dispatcher):

View File

@ -45,14 +45,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -102,7 +102,7 @@ class TestSingleQueryEndpoint:
@pytest.mark.asyncio
async def test_query_single_engine(self, async_client):
mock_result = _make_result(EngineType.CHATGPT, has_brand=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_single.return_value = mock_result
mock_get_service.return_value = mock_service
@ -129,7 +129,7 @@ class TestBatchQueryEndpoint:
async def test_query_batch_parallel(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={
@ -164,7 +164,7 @@ class TestGetResultsEndpoint:
async def test_get_results(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.KIMI, has_brand=False)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
with patch("app.api.ai_engines._get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={

View File

@ -14,6 +14,7 @@ from app.models.brand import Brand
from app.models.alert_setting import AlertSetting
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@ -50,14 +51,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""创建测试用户"""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌"""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -91,7 +92,7 @@ async def test_alert_setting(async_session, test_user, test_brand):
setting = AlertSetting(
id=uuid.uuid4(),
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
alert_type="score_drop",
enabled=True,
threshold=5.0,

View File

@ -44,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()

View File

@ -15,6 +15,7 @@ from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
def _make_user(

View File

@ -55,7 +55,7 @@ async def test_register_duplicate_email(async_client):
)
assert response.status_code == 400
data = response.json()
assert "Email already registered" in data["detail"]
assert "注册失败" in data["detail"] or "已被使用" in data["detail"]
@pytest.mark.asyncio
@ -81,7 +81,7 @@ async def test_login_wrong_password(async_client):
)
assert response.status_code == 401
data = response.json()
assert "Incorrect email or password" in data["detail"]
assert "邮箱或密码错误" in data["detail"]
@pytest.mark.asyncio

View File

@ -48,14 +48,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""创建测试用户"""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -166,14 +166,14 @@ class TestAuthAPI:
# 先创建用户
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email=email,
password_hash=hash_password(password),
name="Test User",
password=hash_password(password),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -209,14 +209,14 @@ class TestAuthAPI:
# 创建用户
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email=email,
password_hash=hash_password("Correct@123456"),
name="Test User",
password=hash_password("Correct@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()

View File

@ -12,6 +12,7 @@ from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -195,7 +196,7 @@ class TestBrandsAPI:
# Create multiple brands
for i in range(3):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name=f"Brand {i}",
platforms=["wenxin"],
)

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""创建测试用户"""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -69,7 +70,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌"""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -171,7 +172,7 @@ class TestBrandsAPI:
# 创建多个品牌
for i in range(3):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name=f"Brand {i}",
platforms=["wenxin"],
)

View File

@ -16,6 +16,9 @@ def mock_citation_record():
record.citation_position = 1
record.citation_text = "Test citation text"
record.competitor_brands = []
record.match_type = "exact"
record.data_source = "ai_response"
record.ai_response_text = "AI response text"
record.queried_at = datetime.now()
return record

View File

@ -13,6 +13,7 @@ from app.models.brand import Brand
from app.models.competitor import Competitor
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -50,14 +51,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@ -49,14 +50,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""创建测试用户"""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
organization_id=uuid.uuid4(), # 需要organization_id用于内容API
)
async_session.add(user)
@ -70,7 +71,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌"""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",

View File

@ -12,6 +12,7 @@ from app.main import app
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -43,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -62,7 +63,7 @@ async def test_user(async_session):
async def test_brand(async_session, test_user):
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",

View File

@ -13,6 +13,7 @@ from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -47,14 +48,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""创建测试用户"""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -67,7 +68,7 @@ async def test_brand(async_session, test_user):
"""创建测试品牌"""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -123,17 +124,11 @@ class TestDiagnosisAPI:
@pytest.mark.asyncio
async def test_geo_diagnosis_success(self, async_client, test_brand):
"""测试GEO诊断端点成功返回"""
response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}")
response = await async_client.post(f"/api/v1/diagnosis/geo/{test_brand.id}")
assert response.status_code == 200
assert response.status_code == 202
data = response.json()
assert "overall_score" in data
assert "health_level" in data
assert "dimensions" in data
assert "recommendations" in data
assert isinstance(data["overall_score"], (int, float))
assert isinstance(data["dimensions"], list)
assert isinstance(data["recommendations"], list)
assert "task_id" in data or "status" in data
@pytest.mark.asyncio
async def test_combined_diagnosis_success(self, async_client, test_brand):
@ -159,7 +154,7 @@ class TestDiagnosisAPI:
seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
assert seo_response.status_code == 404
geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}")
geo_response = await async_client.post(f"/api/v1/diagnosis/geo/{non_existent_id}")
assert geo_response.status_code == 404
combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
@ -180,7 +175,7 @@ class TestDiagnosisAPI:
seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
assert seo_response.status_code == 401
geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
geo_response = await client.post(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
assert geo_response.status_code == 401
combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)

View File

@ -16,6 +16,7 @@ from app.models.subscription import Subscription
from app.models.user import User
from app.services.email.email_scheduler import EmailScheduler
from app.services.email_service import EmailService, EmailMessage
from tests.fixtures.auth import _to_uuid
TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "app" / "templates"

View File

@ -44,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()

View File

@ -20,8 +20,8 @@ class TestLifecycleExceptionHandling:
user = User(
id=user_id,
email="test@example.com",
password_hash="hash",
name="Test User",
password="hash",
firstName="Test User",
plan="free",
organization_id=org_id,
)
@ -69,8 +69,8 @@ class TestLifecycleExceptionHandling:
user = User(
id=user_id,
email="test@example.com",
password_hash="hash",
name="Test User",
password="hash",
firstName="Test User",
plan="free",
organization_id=org_id,
)

View File

@ -12,6 +12,7 @@ from app.main import app
from app.models.user import User
from app.models.organization import Organization, OrgMember
from app.api.deps import get_current_user, get_db
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -43,14 +44,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -75,7 +76,7 @@ async def test_organization(async_session, test_user):
membership = OrgMember(
id=uuid.uuid4(),
organization_id=org.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
role="owner",
)
async_session.add(membership)
@ -167,14 +168,14 @@ class TestOrganizationRoutes:
async def test_organization_members_invite_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/invite 端点存在"""
invite_user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="newuser@example.com",
password_hash="hashed_password",
name="New User",
password="hashed_password",
firstName="New User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(invite_user)
await async_session.commit()
@ -189,14 +190,14 @@ class TestOrganizationRoutes:
async def test_organization_member_role_endpoint_exists(self, async_client, test_organization, async_session, test_user):
"""验证 /api/v1/organization/members/{id}/role 端点存在"""
new_user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="member@example.com",
password_hash="hashed_password",
name="Member User",
password="hashed_password",
firstName="Member User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(new_user)
await async_session.flush()
@ -220,14 +221,14 @@ class TestOrganizationRoutes:
async def test_organization_member_delete_endpoint_exists(self, async_client, test_organization, async_session):
"""验证 /api/v1/organization/members/{id} 端点存在"""
new_user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="todelete@example.com",
password_hash="hashed_password",
name="Delete User",
password="hashed_password",
firstName="Delete User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(new_user)
await async_session.flush()

View File

@ -15,6 +15,7 @@ from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -52,14 +53,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -72,7 +73,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="TestBrand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -92,7 +93,7 @@ async def test_query(async_session, test_user, test_brand):
"""Create a test query with citation records."""
query = QueryModel(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword="AI assistant",
target_brand="TestBrand",
brand_aliases=["TestBrand"],

View File

@ -16,6 +16,7 @@ from app.models.query import Query
from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -53,14 +54,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -73,7 +74,7 @@ async def test_brand(async_session, test_user):
"""Create a test brand."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="TestBrand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -93,7 +94,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
"""Create a test brand with query and citation data."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="TestBrand",
aliases=["TestAlias"],
website="https://test.com",
@ -118,7 +119,7 @@ async def test_brand_with_data(async_session: AsyncSession, test_user):
# Create a query
query = Query(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword="AI assistant",
target_brand="TestBrand",
brand_aliases=["TestAlias"],

View File

@ -40,12 +40,13 @@ class TestPrometheusMetrics:
# 从注册表获取指标值
metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total")
# prometheus_client strips _total suffix from Counter names in collect()
api_requests = metrics.get("geo_api_requests")
assert api_requests is not None
# 验证指标包含正确的标签
samples = list(api_requests.collect())[0].samples
samples = api_requests.samples
sample = next((s for s in samples if s.labels.get("method") == "GET" and s.labels.get("endpoint") == "/test"), None)
assert sample is not None
assert sample.value >= 1
@ -55,7 +56,7 @@ class TestPrometheusMetrics:
AGENT_EXECUTIONS_TOTAL.labels(agent_name="test_agent", status="success").inc()
metrics = {m.name: m for m in REGISTRY.collect()}
agent_executions = metrics.get("geo_agent_executions_total")
agent_executions = metrics.get("geo_agent_executions")
assert agent_executions is not None
@ -65,7 +66,7 @@ class TestPrometheusMetrics:
LLM_TOKENS_TOTAL.labels(provider="openai", model="gpt-4", token_type="completion").inc(50)
metrics = {m.name: m for m in REGISTRY.collect()}
llm_tokens = metrics.get("geo_llm_tokens_total")
llm_tokens = metrics.get("geo_llm_tokens")
assert llm_tokens is not None
@ -94,7 +95,7 @@ class TestPrometheusMetrics:
QUERY_COUNT_TOTAL.labels(platform="kimi", status="failed").inc()
metrics = {m.name: m for m in REGISTRY.collect()}
query_count = metrics.get("geo_queries_total")
query_count = metrics.get("geo_queries")
assert query_count is not None
@ -103,7 +104,7 @@ class TestPrometheusMetrics:
CONTENT_GENERATED_TOTAL.inc()
metrics = {m.name: m for m in REGISTRY.collect()}
content_count = metrics.get("geo_content_generated_total")
content_count = metrics.get("geo_content_generated")
assert content_count is not None
@ -113,7 +114,7 @@ class TestPrometheusMetrics:
CITATION_DETECTED_TOTAL.labels(platform="wenxin").inc()
metrics = {m.name: m for m in REGISTRY.collect()}
citation_count = metrics.get("geo_citations_detected_total")
citation_count = metrics.get("geo_citations_detected")
assert citation_count is not None
@ -211,15 +212,15 @@ class TestMetricsCollection:
"""测试注册表收集所有指标"""
metrics = {m.name: m for m in REGISTRY.collect()}
# 验证关键指标存在
assert "geo_api_requests_total" in metrics
assert "geo_agent_executions_total" in metrics
assert "geo_llm_tokens_total" in metrics
# 验证关键指标存在 (prometheus_client strips _total from Counter names)
assert "geo_api_requests" in metrics
assert "geo_agent_executions" in metrics
assert "geo_llm_tokens" in metrics
assert "geo_llm_cost_estimated" in metrics
assert "geo_brands_total" in metrics
assert "geo_queries_total" in metrics
assert "geo_content_generated_total" in metrics
assert "geo_citations_detected_total" in metrics
assert "geo_queries" in metrics
assert "geo_content_generated" in metrics
assert "geo_citations_detected" in metrics
def test_metric_labels_are_valid(self):
"""测试指标标签有效性"""
@ -253,8 +254,9 @@ class TestMetricsHistory:
# 获取初始值
metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total")
for sample in api_requests.collect()[0].samples:
# prometheus_client strips _total suffix from Counter names in collect()
api_requests = metrics.get("geo_api_requests")
for sample in api_requests.samples:
if sample.labels.get("endpoint") == test_endpoint:
initial_count = sample.value
break
@ -266,8 +268,9 @@ class TestMetricsHistory:
# 验证增加
metrics = {m.name: m for m in REGISTRY.collect()}
api_requests = metrics.get("geo_api_requests_total")
for sample in api_requests.collect()[0].samples:
# prometheus_client strips _total suffix from Counter names in collect()
api_requests = metrics.get("geo_api_requests")
for sample in api_requests.samples:
if sample.labels.get("endpoint") == test_endpoint:
if initial_count is not None:
assert sample.value >= initial_count + 3
@ -285,7 +288,7 @@ class TestMetricsHistory:
llm_cost = metrics.get("geo_llm_cost_estimated")
found = False
for sample in llm_cost.collect()[0].samples:
for sample in llm_cost.samples:
if sample.labels.get("provider") == "test":
assert sample.value == test_value
found = True

View File

@ -18,6 +18,7 @@ from app.models.competitor import Competitor
from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, hash_password
from tests.fixtures.auth import _to_uuid
# Only the tables needed for performance tests (avoids JSONB/SQLite incompatibility)
_TEST_TABLES = (
@ -72,14 +73,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user with properly hashed password."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="perf_test@example.com",
password_hash=hash_password("PerfTest123!"),
name="Performance Test User",
password=hash_password("PerfTest123!"),
firstName="Performance Test User",
plan="free",
max_queries=50,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -155,7 +156,7 @@ class TestAPIPerformance:
# Create several brands for a more realistic test
for i in range(10):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name=f"Brand {i}",
platforms=["wenxin"],
status="active",
@ -176,7 +177,7 @@ class TestAPIPerformance:
# Create several queries for a more realistic test
for i in range(10):
query = Query(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword=f"query keyword {i}",
target_brand=f"Brand {i}",
platforms=["wenxin"],
@ -247,7 +248,7 @@ class TestConcurrency:
# Pre-create data
for i in range(5):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name=f"Concurrent Brand {i}",
platforms=["wenxin"],
status="active",
@ -277,7 +278,7 @@ class TestConcurrency:
"""Concurrent query list reads should all succeed."""
for i in range(5):
query = Query(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword=f"concurrent query {i}",
target_brand=f"Brand {i}",
platforms=["wenxin"],

View File

@ -19,6 +19,7 @@ from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, create_refresh_token, hash_password
from app.config import settings
from tests.fixtures.auth import _to_uuid
# Only the tables needed for security tests (avoids JSONB/SQLite incompatibility)
_TEST_TABLES = (
@ -73,14 +74,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user with properly hashed password."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="security_test@example.com",
password_hash=hash_password("SecurePass123!"),
name="Security Test User",
password=hash_password("SecurePass123!"),
firstName="Security Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -92,14 +93,14 @@ async def test_user(async_session):
async def second_user(async_session):
"""Create a second test user for cross-user isolation tests."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="second_user@example.com",
password_hash=hash_password("SecondPass456!"),
name="Second User",
password=hash_password("SecondPass456!"),
firstName="Second User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -376,7 +377,7 @@ class TestXSSProtection:
"""XSS payloads in brand aliases should be stored as plain text."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Safe Brand",
platforms=["wenxin"],
status="active",
@ -555,7 +556,7 @@ class TestAuthSecurity:
# Create a brand for second_user
brand = Brand(
id=uuid.uuid4(),
user_id=second_user.id,
user_id=_to_uuid(second_user.id),
name="Second User's Brand",
platforms=["wenxin"],
status="active",

View File

@ -16,6 +16,7 @@ from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -53,14 +54,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -134,7 +135,7 @@ class TestFullBrandQueryFlow:
# Step 3: Create a query (using Query model directly)
query = QueryModel(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword="AI assistant",
target_brand="TestBrand",
brand_aliases=["TestBrand", "TB"],
@ -250,7 +251,7 @@ class TestCSVExportFlow:
# Step 2: Create a query with citations
query = QueryModel(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
keyword="export test keyword",
target_brand="ExportTestBrand",
brand_aliases=["ETB"],

View File

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import select, and_
from app.models.api_key import APIKey
from tests.fixtures.auth import _to_uuid
class TestAPIKeyModel:
@ -16,7 +17,7 @@ class TestAPIKeyModel:
"""Test creating a new API key."""
api_key = APIKey(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
encrypted_key="encrypted_test_key",
key_hint="sk-...abc",
@ -29,7 +30,7 @@ class TestAPIKeyModel:
await async_session.refresh(api_key)
assert api_key.id is not None
assert api_key.user_id == test_user.id
assert api_key.user_id == _to_uuid(test_user.id)
assert api_key.engine_type == "chatgpt"
assert api_key.encrypted_key == "encrypted_test_key"
assert api_key.key_hint == "sk-...abc"
@ -43,7 +44,7 @@ class TestAPIKeyModel:
async def test_api_key_default_values(self, async_session, test_user):
"""Test API key default values."""
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="kimi",
encrypted_key="encrypted_kimi_key",
key_hint="sk-...xyz",
@ -62,7 +63,7 @@ class TestAPIKeyModel:
"""Test API key field validation and constraints."""
now = datetime.now(timezone.utc)
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="deepseek",
encrypted_key="encrypted_deepseek_key_data",
key_hint="sk-...def",
@ -90,7 +91,7 @@ class TestAPIKeyModel:
key_id = uuid.uuid4()
api_key = APIKey(
id=key_id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="gemini",
encrypted_key="encrypted_gemini_key",
key_hint="AIza...123",
@ -111,13 +112,13 @@ class TestAPIKeyModel:
async def test_api_key_query_by_user_id(self, async_session, test_user):
"""Test querying API keys by user ID."""
key1 = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
encrypted_key="encrypted_key_1",
key_hint="sk-...1",
)
key2 = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="kimi",
encrypted_key="encrypted_key_2",
key_hint="sk-...2",
@ -127,7 +128,7 @@ class TestAPIKeyModel:
await async_session.commit()
result = await async_session.execute(
select(APIKey).where(APIKey.user_id == test_user.id)
select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
)
keys = result.scalars().all()
@ -137,13 +138,13 @@ class TestAPIKeyModel:
async def test_api_key_query_by_user_and_engine(self, async_session, test_user):
"""Test querying API keys by user ID and engine type."""
key1 = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
encrypted_key="encrypted_chatgpt_key",
key_hint="sk-...chat",
)
key2 = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="kimi",
encrypted_key="encrypted_kimi_key",
key_hint="sk-...kimi",
@ -155,7 +156,7 @@ class TestAPIKeyModel:
result = await async_session.execute(
select(APIKey).where(
and_(
APIKey.user_id == test_user.id,
APIKey.user_id == _to_uuid(test_user.id),
APIKey.engine_type == "chatgpt"
)
)
@ -169,7 +170,7 @@ class TestAPIKeyModel:
async def test_api_key_timestamps(self, async_session, test_user):
"""Test API key created_at and updated_at timestamps."""
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="qwen",
encrypted_key="encrypted_qwen_key",
key_hint="sk-...qwen",
@ -187,7 +188,7 @@ class TestAPIKeyModel:
async def test_api_key_update(self, async_session, test_user):
"""Test updating API key fields."""
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="wenxin",
encrypted_key="encrypted_wenxin_key",
key_hint="sk-...wenxin",
@ -211,7 +212,7 @@ class TestAPIKeyModel:
async def test_api_key_delete(self, async_session, test_user):
"""Test deleting an API key."""
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="doubao",
encrypted_key="encrypted_doubao_key",
key_hint="sk-...doubao",
@ -238,7 +239,7 @@ class TestAPIKeyModel:
for i, status in enumerate(statuses):
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type=f"engine_{i}",
encrypted_key=f"encrypted_key_{i}",
key_hint=f"sk-...{i}",
@ -263,7 +264,7 @@ class TestAPIKeyModel:
for data in keys_data:
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type=data["engine_type"],
encrypted_key=f"encrypted_{data['engine_type']}",
key_hint=data["key_hint"],
@ -275,7 +276,7 @@ class TestAPIKeyModel:
result = await async_session.execute(
select(APIKey)
.where(APIKey.user_id == test_user.id)
.where(APIKey.user_id == _to_uuid(test_user.id))
.order_by(APIKey.priority.desc())
)
keys = result.scalars().all()
@ -289,7 +290,7 @@ class TestAPIKeyModel:
"""Test that user_id field has an index."""
for i in range(5):
api_key = APIKey(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type=f"engine_{i}",
encrypted_key=f"encrypted_key_{i}",
key_hint=f"hint_{i}",
@ -299,7 +300,7 @@ class TestAPIKeyModel:
await async_session.commit()
result = await async_session.execute(
select(APIKey).where(APIKey.user_id == test_user.id)
select(APIKey).where(APIKey.user_id == _to_uuid(test_user.id))
)
keys = result.scalars().all()

View File

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import select
from app.models.brand import Brand
from tests.fixtures.auth import _to_uuid
class TestBrandModel:
@ -16,7 +17,7 @@ class TestBrandModel:
"""Test creating a new brand."""
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -30,7 +31,7 @@ class TestBrandModel:
await async_session.refresh(brand)
assert brand.id is not None
assert brand.user_id == test_user.id
assert brand.user_id == _to_uuid(test_user.id)
assert brand.name == "Test Brand"
assert brand.aliases == ["TestBrand", "TB"]
assert brand.website == "https://testbrand.com"
@ -45,7 +46,7 @@ class TestBrandModel:
async def test_brand_default_values(self, async_session, test_user):
"""Test brand default values."""
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Default Brand",
platforms=["wenxin"],
)
@ -59,13 +60,12 @@ class TestBrandModel:
assert brand.frequency == "weekly"
assert brand.status == "active"
assert brand.last_queried_at is None
assert brand.next_query_at is None
@pytest.mark.asyncio
async def test_brand_fields(self, async_session, test_user):
"""Test brand field validation and constraints."""
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Field Test Brand",
aliases=["FTA", "FieldTest"],
website="https://fieldtest.com",
@ -80,7 +80,6 @@ class TestBrandModel:
await async_session.commit()
await async_session.refresh(brand)
# Verify all fields
assert brand.name == "Field Test Brand"
assert len(brand.name) == 16
assert brand.aliases == ["FTA", "FieldTest"]
@ -91,7 +90,6 @@ class TestBrandModel:
assert brand.frequency == "daily"
assert brand.status == "active"
assert brand.last_queried_at is not None
assert brand.next_query_at is not None
@pytest.mark.asyncio
async def test_brand_query_by_id(self, async_session, test_user):
@ -99,7 +97,7 @@ class TestBrandModel:
brand_id = uuid.uuid4()
brand = Brand(
id=brand_id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Query Test Brand",
platforms=["wenxin"],
)
@ -118,14 +116,14 @@ class TestBrandModel:
@pytest.mark.asyncio
async def test_brand_query_by_user_id(self, async_session, test_user):
"""Test querying brands by user ID."""
# Create multiple brands for the same user
uid = _to_uuid(test_user.id)
brand1 = Brand(
user_id=test_user.id,
user_id=uid,
name="User Brand 1",
platforms=["wenxin"],
)
brand2 = Brand(
user_id=test_user.id,
user_id=uid,
name="User Brand 2",
platforms=["kimi"],
)
@ -134,7 +132,7 @@ class TestBrandModel:
await async_session.commit()
result = await async_session.execute(
select(Brand).where(Brand.user_id == test_user.id)
select(Brand).where(Brand.user_id == uid)
)
brands = result.scalars().all()
@ -144,7 +142,7 @@ class TestBrandModel:
async def test_brand_timestamps(self, async_session, test_user):
"""Test brand created_at and updated_at timestamps."""
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Timestamp Brand",
platforms=["wenxin"],
)
@ -161,7 +159,7 @@ class TestBrandModel:
async def test_brand_update(self, async_session, test_user):
"""Test updating brand fields."""
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Update Test Brand",
platforms=["wenxin"],
frequency="weekly",
@ -169,7 +167,6 @@ class TestBrandModel:
async_session.add(brand)
await async_session.commit()
# Update brand
brand.name = "Updated Brand Name"
brand.frequency = "daily"
brand.aliases = ["Updated", "Alias"]
@ -184,7 +181,7 @@ class TestBrandModel:
async def test_brand_delete(self, async_session, test_user):
"""Test deleting a brand."""
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Delete Test Brand",
platforms=["wenxin"],
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from app.models.brand import Brand
from app.models.competitor import Competitor
from tests.fixtures.auth import _to_uuid
class TestCompetitorModel:
@ -18,7 +19,7 @@ class TestCompetitorModel:
# First create a brand
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand for Competitor",
platforms=["wenxin"],
)
@ -48,7 +49,7 @@ class TestCompetitorModel:
"""Test competitor default values."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Default Competitor",
platforms=["wenxin"],
)
@ -72,7 +73,7 @@ class TestCompetitorModel:
"""Test competitor field validation."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Field Test",
platforms=["wenxin", "kimi"],
)
@ -98,7 +99,7 @@ class TestCompetitorModel:
"""Test querying competitor by ID."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Query Test",
platforms=["wenxin"],
)
@ -129,7 +130,7 @@ class TestCompetitorModel:
"""Test querying competitors by brand ID."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Multi Competitor Test",
platforms=["wenxin"],
)
@ -164,7 +165,7 @@ class TestCompetitorModel:
"""Test competitor created_at timestamp."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Timestamp Test",
platforms=["wenxin"],
)
@ -188,7 +189,7 @@ class TestCompetitorModel:
"""Test updating competitor fields."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Update Test",
platforms=["wenxin"],
)
@ -218,7 +219,7 @@ class TestCompetitorModel:
"""Test deleting a competitor."""
# Create brand first
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Delete Test",
platforms=["wenxin"],
)
@ -249,7 +250,7 @@ class TestCompetitorModel:
"""Test that competitors are deleted when brand is deleted."""
# Create brand with competitors
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Brand for Cascade Test",
platforms=["wenxin"],
)

View File

@ -6,6 +6,7 @@ from sqlalchemy import select
from app.models.organization import Organization, OrgMember
from app.models.user import User
from tests.fixtures.auth import _to_uuid
class TestOrganizationModel:

View File

@ -29,7 +29,7 @@ class TestSubscriptionModel:
def test_subscription_field_types(self):
columns = Subscription.__table__.columns
assert "UUID" in str(columns["id"].type).upper()
assert "UUID" in str(columns["user_id"].type).upper()
assert "VARCHAR" in str(columns["user_id"].type).upper() or "STRING" in str(columns["user_id"].type).upper()
assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper()
assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper()
assert "DATE" in str(columns["start_date"].type).upper()

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from app.models.brand import Brand
from app.models.suggestion import Suggestion
from app.models.user import User
from tests.fixtures.auth import _to_uuid
class TestSuggestionModel:
@ -81,7 +82,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio
async def test_suggestion_create(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Suggestion Test Brand",
platforms=["wenxin"],
)
@ -124,7 +125,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio
async def test_suggestion_default_values(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Default Suggestion Brand",
platforms=["kimi"],
)
@ -152,7 +153,7 @@ class TestSuggestionModel:
@pytest.mark.asyncio
async def test_suggestion_query_by_brand(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Query Suggestion Brand",
platforms=["wenxin"],
)

View File

@ -7,6 +7,7 @@ from sqlalchemy import select, func, and_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from app.models.usage_record import UsageRecord
from tests.fixtures.auth import _to_uuid
class TestUsageRecordModel:
@ -16,7 +17,7 @@ class TestUsageRecordModel:
async def test_usage_record_create(self, async_session, test_user):
"""Test creating a new usage record."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
query="What is SEO optimization?",
input_tokens=100,
@ -29,7 +30,7 @@ class TestUsageRecordModel:
await async_session.refresh(record)
assert record.id is not None
assert record.user_id == test_user.id
assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "chatgpt"
assert record.query == "What is SEO optimization?"
assert record.input_tokens == 100
@ -43,7 +44,7 @@ class TestUsageRecordModel:
async def test_usage_record_default_values(self, async_session, test_user):
"""Test usage record default values."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="kimi",
query="Test query",
)
@ -60,7 +61,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_query_by_user_id(self, async_session, test_user):
"""Test querying usage records by user ID."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
for i in range(3):
record = UsageRecord(
user_id=user_id,
@ -81,7 +82,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_query_by_user_and_engine(self, async_session, test_user):
"""Test querying usage records by user ID and engine type."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
record1 = UsageRecord(
user_id=user_id,
engine_type="chatgpt",
@ -114,7 +115,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_query_by_time_range(self, async_session, test_user):
"""Test querying usage records by time range."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
now = datetime.now(timezone.utc)
old_record = UsageRecord(
@ -152,7 +153,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_aggregate_by_user(self, async_session, test_user):
"""Test aggregating usage records by user."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
records_data = [
{"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01},
{"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02},
@ -192,7 +193,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_aggregate_by_day(self, async_session, test_user):
"""Test aggregating usage records by day."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
now = datetime.now(timezone.utc)
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
@ -238,7 +239,7 @@ class TestUsageRecordModel:
"""Test usage record with brand association."""
brand_id = uuid.uuid4()
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
brand_id=brand_id,
engine_type="wenxin",
query="Brand query",
@ -253,7 +254,7 @@ class TestUsageRecordModel:
@pytest.mark.asyncio
async def test_usage_record_index_user_engine(self, async_session, test_user):
"""Test composite index on user_id and engine_type."""
user_id = test_user.id
user_id = _to_uuid(test_user.id)
for i in range(5):
record = UsageRecord(
user_id=user_id,
@ -280,7 +281,7 @@ class TestUsageRecordModel:
async def test_usage_record_update(self, async_session, test_user):
"""Test updating usage record fields."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="xinghuo",
query="Original query",
cost=1.0,
@ -300,7 +301,7 @@ class TestUsageRecordModel:
async def test_usage_record_delete(self, async_session, test_user):
"""Test deleting a usage record."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="yuanbao",
query="Delete me",
cost=1.0,
@ -323,7 +324,7 @@ class TestUsageRecordModel:
async def test_usage_record_timestamps(self, async_session, test_user):
"""Test usage record timestamp fields."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="perplexity",
query="Timestamp test",
cost=1.0,
@ -343,7 +344,7 @@ class TestUsageRecordModel:
other_user_id = uuid.uuid4()
user1_record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
query="User 1 query",
cost=1.0,
@ -359,7 +360,7 @@ class TestUsageRecordModel:
await async_session.commit()
result_user1 = await async_session.execute(
select(UsageRecord).where(UsageRecord.user_id == test_user.id)
select(UsageRecord).where(UsageRecord.user_id == _to_uuid(test_user.id))
)
result_user2 = await async_session.execute(
select(UsageRecord).where(UsageRecord.user_id == other_user_id)
@ -372,7 +373,7 @@ class TestUsageRecordModel:
async def test_usage_record_empty_query_field(self, async_session, test_user):
"""Test usage record with empty query field."""
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="deepseek",
query="",
cost=0.0,
@ -394,7 +395,7 @@ class TestUsageRecordModel:
"nested": {"key": {"deep": {"value": [1, 2, 3]}}},
}
record = UsageRecord(
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
engine_type="chatgpt",
query="Large metadata test",
extra_data=large_metadata,

View File

@ -12,6 +12,7 @@ from app.models.user import User
from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -47,14 +48,14 @@ async def async_session(async_engine):
async def test_user_free(async_session):
"""Create a free plan test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="free@example.com",
password_hash="hashed_password",
name="Free User",
password="hashed_password",
firstName="Free User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -66,14 +67,14 @@ async def test_user_free(async_session):
async def test_user_basic(async_session):
"""Create a basic plan test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="basic@example.com",
password_hash="hashed_password",
name="Basic User",
password="hashed_password",
firstName="Basic User",
plan="basic",
max_queries=50,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -85,14 +86,14 @@ async def test_user_basic(async_session):
async def test_user_pro(async_session):
"""Create a pro plan test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="pro@example.com",
password_hash="hashed_password",
name="Pro User",
password="hashed_password",
firstName="Pro User",
plan="pro",
max_queries=500,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -110,7 +111,7 @@ class TestByDayAggregation:
for i in range(3):
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "chatgpt",
"query": f"Query {i}",
"cost": 0.01,
@ -130,7 +131,7 @@ class TestByDayAggregation:
for i in range(5):
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "deepseek",
"query": f"Query {i}",
"cost": 0.02,
@ -149,7 +150,7 @@ class TestByDayAggregation:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "qwen",
"query": "Query 1",
"input_tokens": 100,
@ -157,7 +158,7 @@ class TestByDayAggregation:
"cost": 0.01,
})
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "qwen",
"query": "Query 2",
"input_tokens": 150,
@ -188,14 +189,14 @@ class TestByDayAggregation:
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini",
"query": "Yesterday query",
"cost": 0.05,
"timestamp": yesterday,
})
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini",
"query": "Today query",
"cost": 0.05,
@ -268,7 +269,7 @@ class TestUserQuotaService:
for i in range(5):
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "chatgpt",
"query": f"Query {i}",
"cost": 1.0,
@ -291,7 +292,7 @@ class TestUserQuotaService:
for i in range(10):
await repo.create({
"user_id": test_user_basic.id,
"user_id": _to_uuid(test_user_basic.id),
"engine_type": "deepseek",
"query": f"Query {i}",
"cost": 2.0,
@ -314,7 +315,7 @@ class TestUserQuotaService:
for i in range(10):
await repo.create({
"user_id": test_user_pro.id,
"user_id": _to_uuid(test_user_pro.id),
"engine_type": "qwen",
"query": f"Query {i}",
"cost": 10.0,
@ -336,7 +337,7 @@ class TestUserQuotaService:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user_free.id,
"user_id": _to_uuid(test_user_free.id),
"engine_type": "gemini",
"query": "Expensive query",
"cost": 15.0,

View File

@ -11,6 +11,7 @@ from app.database import Base
from app.models.user import User
from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -46,14 +47,14 @@ async def async_session(async_engine):
async def test_user(async_session):
"""Create a test user."""
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
password="hashed_password",
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -70,7 +71,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
data = {
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt",
"query": "Test query",
"input_tokens": 100,
@ -82,7 +83,7 @@ class TestUsageRepository:
record = await repo.create(data)
assert record.id is not None
assert record.user_id == test_user.id
assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "chatgpt"
assert record.query == "Test query"
assert record.input_tokens == 100
@ -96,7 +97,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
data = {
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "deepseek",
"query": "Minimal query",
}
@ -104,7 +105,7 @@ class TestUsageRepository:
record = await repo.create(data)
assert record.id is not None
assert record.user_id == test_user.id
assert record.user_id == _to_uuid(test_user.id)
assert record.engine_type == "deepseek"
assert record.query == "Minimal query"
assert record.input_tokens == 0
@ -119,7 +120,7 @@ class TestUsageRepository:
for i in range(3):
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt",
"query": f"Query {i}",
"input_tokens": 100,
@ -143,13 +144,13 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt",
"query": "ChatGPT query",
"cost": 0.02,
})
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "deepseek",
"query": "DeepSeek query",
"cost": 0.01,
@ -170,7 +171,7 @@ class TestUsageRepository:
for i in range(2):
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "qwen",
"query": f"Query {i}",
"cost": 0.01,
@ -190,14 +191,14 @@ class TestUsageRepository:
brand_id = uuid.uuid4()
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"brand_id": brand_id,
"engine_type": "kimi",
"query": "Brand query",
"cost": 0.05,
})
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "kimi",
"query": "No brand query",
"cost": 0.03,
@ -219,7 +220,7 @@ class TestUsageRepository:
for i in range(5):
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "gemini",
"query": f"Query {i}",
"cost": 1.0,
@ -238,7 +239,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt",
"query": "Expensive query",
"cost": 85.0,
@ -255,7 +256,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "chatgpt",
"query": "Very expensive query",
"cost": 120.0,
@ -272,7 +273,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
created = await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "wenxin",
"query": "Get by ID test",
"cost": 0.5,
@ -291,7 +292,7 @@ class TestUsageRepository:
for i in range(3):
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "doubao",
"query": f"User query {i}",
"cost": 0.1,
@ -307,13 +308,13 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "xinghuo",
"query": "Xinghuo query",
"cost": 0.1,
})
await repo.create({
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "perplexity",
"query": "Perplexity query",
"cost": 0.2,
@ -358,7 +359,7 @@ class TestUsageRepository:
repo = UsageRepository(async_session)
data = {
"user_id": test_user.id,
"user_id": _to_uuid(test_user.id),
"engine_type": "yuanbao",
"query": "UUID test",
"cost": 0.5,
@ -368,4 +369,4 @@ class TestUsageRepository:
summary = await repo.get_summary(test_user.id, period="month")
assert summary["total_queries"] == 1
assert record.user_id == test_user.id
assert record.user_id == _to_uuid(test_user.id)

View File

@ -11,6 +11,7 @@ from app.database import Base
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
@ -42,14 +43,14 @@ async def async_session(async_engine):
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
id=str(uuid.uuid4()),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
password=hash_password("Test@123456"),
firstName="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
isActive=True,
emailVerified=True,
)
async_session.add(user)
await async_session.commit()
@ -61,7 +62,7 @@ async def test_user(async_session):
async def test_brand(async_session, test_user):
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
@ -83,7 +84,7 @@ class TestDetectionTaskModel:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="每日品牌检测",
frequency="daily",
engines=["chatgpt", "perplexity"],
@ -96,7 +97,7 @@ class TestDetectionTaskModel:
assert task.id is not None
assert task.brand_id == test_brand.id
assert task.user_id == test_user.id
assert task.user_id == _to_uuid(test_user.id)
assert task.name == "每日品牌检测"
assert task.frequency == "daily"
assert task.engines == ["chatgpt", "perplexity"]
@ -114,7 +115,7 @@ class TestDetectionTaskModel:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="简单检测",
frequency="weekly",
engines=["chatgpt"],
@ -142,13 +143,13 @@ class TestDetectionSchedulerService:
"queries": ["最佳保险品牌", "保险推荐"],
"competitor_names": ["竞品A"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.id is not None
assert task.name == "每日品牌检测"
assert task.frequency == "daily"
assert task.brand_id == test_brand.id
assert task.user_id == test_user.id
assert task.user_id == _to_uuid(test_user.id)
@pytest.mark.asyncio
async def test_update_task(self, async_session, test_brand, test_user):
@ -157,7 +158,7 @@ class TestDetectionSchedulerService:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="旧名称",
frequency="weekly",
engines=["chatgpt"],
@ -173,7 +174,7 @@ class TestDetectionSchedulerService:
"frequency": "daily",
"engines": ["chatgpt", "perplexity"],
}
updated = await service.update_task(task.id, update_data, test_user.id, async_session)
updated = await service.update_task(task.id, update_data, _to_uuid(test_user.id), async_session)
assert updated.name == "新名称"
assert updated.frequency == "daily"
@ -186,7 +187,7 @@ class TestDetectionSchedulerService:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="待删除",
frequency="weekly",
engines=["chatgpt"],
@ -197,7 +198,7 @@ class TestDetectionSchedulerService:
await async_session.refresh(task)
service = DetectionSchedulerService()
result = await service.delete_task(task.id, test_user.id, async_session)
result = await service.delete_task(task.id, _to_uuid(test_user.id), async_session)
assert result is True
stmt = select(DetectionTask).where(DetectionTask.id == task.id)
@ -212,7 +213,7 @@ class TestDetectionSchedulerService:
for i in range(3):
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name=f"任务{i}",
frequency="daily",
engines=["chatgpt"],
@ -222,7 +223,7 @@ class TestDetectionSchedulerService:
await async_session.commit()
service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
assert len(tasks) == 3
@pytest.mark.asyncio
@ -232,7 +233,7 @@ class TestDetectionSchedulerService:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="手动触发测试",
frequency="daily",
engines=["chatgpt"],
@ -245,7 +246,7 @@ class TestDetectionSchedulerService:
service = DetectionSchedulerService()
with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
mock_execute.return_value = {"status": "success", "results": []}
result = await service.trigger_task(task.id, test_user.id, async_session)
result = await service.trigger_task(task.id, _to_uuid(test_user.id), async_session)
assert result["status"] == "success"
mock_execute.assert_called_once()
@ -261,7 +262,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "hourly"
assert task.next_run_at is not None
@ -276,7 +277,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "daily"
assert task.next_run_at is not None
@ -291,7 +292,7 @@ class TestDetectionSchedulerService:
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
assert task.frequency == "weekly"
assert task.next_run_at is not None
@ -307,7 +308,7 @@ class TestDetectionSchedulerService:
"queries": ["查询"],
}
with pytest.raises(ValueError, match="frequency"):
await service.create_task(task_data, test_brand.id, test_user.id, async_session)
await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)
@pytest.mark.asyncio
async def test_execute_task_flow(self, async_session, test_brand, test_user):
@ -316,7 +317,7 @@ class TestDetectionSchedulerService:
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
user_id=_to_uuid(test_user.id),
name="执行流程测试",
frequency="daily",
engines=["chatgpt"],
@ -356,7 +357,7 @@ class TestDetectionSchedulerService:
from app.services.detection.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
result = await service.delete_task(uuid.uuid4(), test_user.id, async_session)
result = await service.delete_task(uuid.uuid4(), _to_uuid(test_user.id), async_session)
assert result is False
@pytest.mark.asyncio
@ -365,12 +366,12 @@ class TestDetectionSchedulerService:
service = DetectionSchedulerService()
with pytest.raises(TaskNotFoundError):
await service.update_task(uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session)
await service.update_task(uuid.uuid4(), {"name": "新名称"}, _to_uuid(test_user.id), async_session)
@pytest.mark.asyncio
async def test_get_tasks_empty(self, async_session, test_brand, test_user):
from app.services.detection.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
assert tasks == []

View File

@ -0,0 +1,886 @@
---
title: "refactor: GEO Agent Framework — 统一架构重构为独立项目"
type: refactor
status: active
date: 2026-06-04
---
## Summary
将 GEO 项目的 8 个 Agent 从代码重复的独立类重构为独立 Python 包(`fischer-agentkit`),采用统一 Agent Core + 配置驱动 + 可插拔 Tool/Skill/Memory 架构。新增 MCP 协议支持、三层记忆系统Working/Episodic/Semantic、Level 3 自我进化闭环(经验积累 → Prompt 自动优化 → 策略调整)、多 Agent 协同增强并行执行、Handoff、动态 Pipeline并实现 Agent 与业务系统的完全解耦。
## Problem Frame
当前 GEO 项目的 Agent 体系存在以下结构性问题:
1. **代码重复**8 个 Agent 的 `execute()` 方法包含完全相同的计时、try/except、TaskResult 构建逻辑(~250 行模板代码),每新增一个 Agent 需复制 ~150 行
2. **架构耦合**Agent 直接依赖业务 ServiceCitationService、MonitorService 等),无法独立部署和复用
3. **能力缺失**无记忆系统每次执行无状态、无自我进化Prompt 硬编码不可调优)、无技能插件(能力静态绑定)、无 MCP 支持
4. **编排局限**Pipeline 仅支持串行 DAG、无 Agent 间 Handoff、无动态路由
5. **配置僵化**AgentType 硬编码枚举、Prompt 与 Agent 紧耦合、新增 Agent 需改 5+ 文件
---
## Requirements
### Agent Core 统一架构
R1. Agent 框架必须作为独立 Python 包(`fischer-agentkit`),可独立安装、版本管理、跨项目复用
R2. 所有 Agent 共享统一的 `BaseAgent` 生命周期start → listen → execute → reflect → evolve → stop子类只需实现 `handle_task()` 返回业务数据
R3. Agent 定义支持三种模式Python 声明式(`@task` 装饰器、YAML 配置驱动(零代码)、混合模式
R4. AgentType 从硬编码枚举改为动态注册,新增 Agent 类型无需修改框架代码
R5. Agent 的 input/output 支持 JSON Schema 声明与校验
### Tool/Skill 插件系统
R6. 实现 `Tool` 抽象基类,支持 `FunctionTool`(函数工具)、`AgentTool`Agent 即工具)、`MCPTool`MCP 协议工具)三种类型
R7. 实现 `ToolRegistry`,支持工具的注册、发现、版本管理、标签分类
R8. 实现 MCP Server将现有 Agent 能力暴露为 MCP 工具供外部调用
R9. 实现 MCP ClientAgent 可调用外部 MCP 工具服务器
R10. 工具支持组合:顺序链、并行扇出/扇入、动态选择
### 记忆系统
R11. 实现 Working MemoryRedis-based存储当前任务的上下文和中间状态生命周期为单次任务
R12. 实现 Episodic Memory向量+关系数据库),记录每次任务的输入/输出/效果/反思,支持语义检索相似历史案例
R13. 实现 Semantic Memory复用现有 RAG 知识库,所有 Agent 均可通过统一接口检索知识
R14. 记忆检索采用混合策略:向量语义 + 关键词 + 知识图谱 + 时间衰减 + RRF 融合排序
### 自我进化Level 3
R15. 实现经验积累:每次任务执行后自动记录成功/失败模式、效果指标、反思总结
R16. 实现 Prompt 自动优化:基于 DSPy 风格的编译器,从任务结果中自动生成/优化 Prompt 指令和 few-shot 示例
R17. 实现策略调整:根据历史效果数据自动调整 Agent 参数temperature、tool 选择权重、Pipeline 路径)
R18. 进化变更必须经过 A/B 测试验证后才可生效,支持回滚
### 多 Agent 协同与业务编排
R19. Pipeline Engine 支持同层并行执行(无依赖的 stages 使用 `asyncio.gather` 并行)
R20. 实现 Handoff 机制Agent 可在运行时将任务转交给另一个 Agent携带上下文
R21. 支持动态 Pipeline运行时根据条件选择子流程、嵌套 Pipeline
R22. Agent 间通信支持事件驱动Redis Pub/Sub替代轮询等待
### Agent 与业务解耦
R23. Agent 框架不依赖任何 GEO 业务代码Service、Model、Repository通过 Tool 接口调用业务能力
R24. 业务系统通过 Tool 注册将自身能力暴露给 AgentAgent 通过 ToolRegistry 发现和调用
R25. Agent 配置Prompt、Tool 绑定、Memory 策略)存储在数据库,支持热更新
---
## Key Technical Decisions
KTD1. **独立 Python 包架构**`fischer-agentkit` 作为独立包发布到私有 PyPIGEO 项目通过 `pip install` 引入。包结构遵循 `src/` layout支持独立测试和版本管理。理由解耦 Agent 基础设施与业务代码,支持跨项目复用。
KTD2. **统一 Agent Core 的 execute 模板上移**将计时、try/except、TaskResult 构建、进度上报等公共逻辑全部上移到 `BaseAgent.execute()`,子类只需实现 `handle_task(task) -> dict`。理由:消除 8 个 Agent 中 ~250 行重复代码。
KTD3. **Tool 抽象三层架构**`FunctionTool`(进程内函数调用)→ `MCPTool`(跨进程 MCP 协议调用)→ `AgentTool`(将 Agent 包装为 Tool。理由覆盖从简单到复杂的所有工具场景MCP 作为标准协议确保生态兼容性。
KTD4. **MCP 双向支持**:同时实现 MCP Server暴露 Agent 能力)和 MCP Client调用外部工具。传输层优先支持 Streamable HTTP2025 标准),兼容 SSE。理由MCP 是 2025 年工具协议事实标准,双向支持最大化生态连接能力。
KTD5. **三层记忆架构**Working MemoryRedis单任务生命周期→ Episodic Memorypgvector + PostgreSQL任务经验→ Semantic Memory复用现有 RAG + 知识图谱)。理由:不同记忆类型有不同生命周期和检索模式,分层实现职责清晰。
KTD6. **Level 3 进化采用 DSPy 风格编译器**:定义 `Signature`(输入/输出 schema`Module`(可组合 Prompt 策略)→ `Optimizer`(自动优化),从任务结果中自动构建 few-shot 示例和优化指令。理由DSPy 是目前最成熟的 Prompt 自动优化范式,其编译器模式与 Agent 的 execute-reflect-evolve 生命周期天然契合。
KTD7. **配置驱动的 Agent 定义**Agent 的元数据、Prompt、Tool 绑定、Memory 策略均通过 YAML/数据库配置,运行时由 `ConfigDrivenAgent` 自动组装。理由:将新增 Agent 从写 150 行代码降为 10-20 行配置。
KTD8. **Pipeline 并行执行**:将拓扑排序后的 stages 按依赖层级分组,同层内使用 `asyncio.gather` 并行执行。理由:引用检测和趋势分析等无依赖任务应并行,当前串行执行浪费时间。
KTD9. **Handoff 基于 Redis Pub/Sub**Agent 通过 `agent:{name}:handoff` 频道发送转交请求,目标 Agent 监听并接管。理由:与现有 Redis Queue 架构一致,无需引入新组件。
---
## High-Level Technical Design
### 整体架构
```mermaid
flowchart TB
subgraph FischerAgentKit["fischer-agentkit (独立包)"]
direction TB
Core["Agent Core"]
Tools["Tool System"]
Memory["Memory System"]
Evolution["Evolution Engine"]
Orchestrator["Orchestrator"]
MCP["MCP Layer"]
subgraph Core
BaseAgent["BaseAgent<br/>生命周期管理"]
ConfigDriven["ConfigDrivenAgent<br/>配置驱动"]
Registry["AgentRegistry<br/>注册发现"]
Dispatcher["TaskDispatcher<br/>任务调度"]
Protocol["Protocol<br/>通信协议"]
end
subgraph Tools
ToolRegistry["ToolRegistry<br/>注册发现"]
FunctionTool["FunctionTool<br/>函数工具"]
MCPTool["MCPTool<br/>MCP工具"]
AgentTool["AgentTool<br/>Agent即工具"]
end
subgraph Memory
Working["Working Memory<br/>Redis"]
Episodic["Episodic Memory<br/>pgvector+PG"]
Semantic["Semantic Memory<br/>RAG+Graph"]
Retriever["MemoryRetriever<br/>混合检索"]
end
subgraph Evolution
Reflector["Reflector<br/>执行反思"]
PromptOptimizer["PromptOptimizer<br/>DSPy编译器"]
StrategyTuner["StrategyTuner<br/>策略调优"]
ABTester["ABTester<br/>A/B测试"]
end
subgraph Orchestrator
PipelineEngine["PipelineEngine<br/>DAG+并行"]
Handoff["Handoff<br/>任务转交"]
DynamicPipeline["DynamicPipeline<br/>运行时组合"]
end
subgraph MCP
MCPServer["MCP Server<br/>暴露能力"]
MCPClient["MCP Client<br/>调用外部"]
end
end
subgraph GEOBusiness["GEO 业务系统"]
Services["业务 Services"]
Models["数据 Models"]
Repos["Repositories"]
end
Core --> Tools
Core --> Memory
Core --> Evolution
Core --> Orchestrator
Tools --> MCP
Services -.->|"注册为 FunctionTool"| ToolRegistry
Semantic -.->|"复用"| Services
```
### Agent 生命周期
```mermaid
stateDiagram-v2
[*] --> Offline
Offline --> Online: start()
Online --> Listening: listen_for_tasks()
Listening --> Executing: receive_task()
Executing --> Reflecting: handle_task() 完成
Reflecting --> Evolving: reflect() 有改进建议
Reflecting --> Listening: reflect() 无改进
Evolving --> Listening: evolve() 完成
Executing --> Listening: handle_task() 异常
Online --> Offline: stop()
Listening --> Offline: stop()
state Executing {
[*] --> LoadMemory: 加载记忆
LoadMemory --> PlanTask: 规划任务
PlanTask --> RunTools: 执行工具
RunTools --> StoreMemory: 存储记忆
StoreMemory --> BuildResult: 构建结果
BuildResult --> [*]
}
state Reflecting {
[*] --> EvaluateOutcome: 评估结果
EvaluateOutcome --> ExtractPatterns: 提取模式
ExtractPatterns --> GenerateInsights: 生成洞察
GenerateInsights --> [*]
}
state Evolving {
[*] --> OptimizePrompt: 优化Prompt
OptimizePrompt --> TuneStrategy: 调优策略
TuneStrategy --> ABTest: A/B测试
ABTest --> ApplyOrRollback: 应用或回滚
ApplyOrRollback --> [*]
}
```
### Tool 系统架构
```mermaid
flowchart LR
subgraph Agent["Agent"]
Executor["Task Executor"]
end
subgraph ToolRegistry["ToolRegistry"]
FT["FunctionTool<br/>name, description<br/>input_schema, output_schema<br/>execute(**kwargs)"]
MT["MCPTool<br/>server_url, tool_name<br/>input_schema, output_schema<br/>execute(**kwargs)"]
AT["AgentTool<br/>agent_name<br/>input_mapping, output_mapping<br/>execute(**kwargs)"]
end
subgraph MCPServer["MCP Server"]
MCPEndpoints["/tools/list<br/>/tools/call<br/>/resources/read"]
end
subgraph ExternalMCP["External MCP Servers"]
ExtTool1["File System"]
ExtTool2["GitHub"]
ExtTool3["Postgres"]
end
Executor -->|"select & call"| ToolRegistry
FT -->|"进程内调用"| BusinessLogic["业务函数"]
MT -->|"HTTP/SSE"| MCPServer
MT -->|"HTTP/SSE"| ExternalMCP
AT -->|"dispatch"| TargetAgent["目标 Agent"]
MCPServer -->|"暴露"| AgentCapabilities["Agent 能力"]
```
### 记忆系统数据流
```mermaid
flowchart TB
Task["新任务到达"] --> WM["Working Memory<br/>加载当前任务上下文"]
WM --> EM["Episodic Memory<br/>检索相似历史案例"]
EM --> SM["Semantic Memory<br/>检索知识库"]
SM --> Context["组装上下文"]
Context --> Execute["执行任务"]
Execute --> WM_Write["写入 Working Memory<br/>中间状态"]
WM_Write --> Execute
Execute --> EM_Write["写入 Episodic Memory<br/>任务经验"]
Execute --> Reflect["反思评估"]
Reflect --> EM_Update["更新 Episodic Memory<br/>反思结果"]
subgraph 检索策略
VS["向量语义检索<br/>权重 0.5"]
KS["关键词精确匹配<br/>权重 0.2"]
GT["知识图谱关联<br/>权重 0.3"]
RRF["RRF 融合排序"]
TD["时间衰减"]
VS --> RRF
KS --> RRF
GT --> RRF
RRF --> TD
end
```
### 进化闭环
```mermaid
flowchart TB
Execute["任务执行"] --> Result["任务结果"]
Result --> Evaluate["效果评估<br/>成功/失败/质量分"]
Evaluate --> Pattern["模式提取<br/>成功模式/失败模式"]
Pattern --> Optimize["优化生成"]
Optimize --> PromptOpt["Prompt 优化<br/>DSPy 编译器"]
Optimize --> StrategyOpt["策略调优<br/>参数/工具权重"]
Optimize --> PipelineOpt["Pipeline 优化<br/>路径选择"]
PromptOpt --> ABTest["A/B 测试"]
StrategyOpt --> ABTest
PipelineOpt --> ABTest
ABTest -->|"效果提升"| Apply["应用变更"]
ABTest -->|"效果下降"| Rollback["回滚"]
ABTest -->|"不确定"| Extend["延长测试"]
Apply --> Record["记录进化日志"]
Rollback --> Record
Record --> Execute
```
---
## Output Structure
```
fischer-agentkit/ # 独立 Python 包
├── src/
│ └── agentkit/
│ ├── __init__.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── base.py # BaseAgent 生命周期
│ │ ├── config_driven.py # ConfigDrivenAgent 配置驱动
│ │ ├── registry.py # AgentRegistry 注册发现
│ │ ├── dispatcher.py # TaskDispatcher 任务调度
│ │ ├── protocol.py # 通信协议与数据结构
│ │ ├── exceptions.py # 异常体系
│ │ └── standalone.py # 独立启动器(自动发现)
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── base.py # Tool 抽象基类
│ │ ├── function_tool.py # FunctionTool
│ │ ├── agent_tool.py # AgentTool
│ │ ├── mcp_tool.py # MCPTool
│ │ └── registry.py # ToolRegistry
│ ├── memory/
│ │ ├── __init__.py
│ │ ├── base.py # Memory 抽象基类
│ │ ├── working.py # Working Memory (Redis)
│ │ ├── episodic.py # Episodic Memory (pgvector+PG)
│ │ ├── semantic.py # Semantic Memory (RAG+Graph)
│ │ └── retriever.py # 混合检索器
│ ├── evolution/
│ │ ├── __init__.py
│ │ ├── reflector.py # 执行反思
│ │ ├── prompt_optimizer.py # DSPy 风格 Prompt 优化器
│ │ ├── strategy_tuner.py # 策略调优
│ │ ├── ab_tester.py # A/B 测试框架
│ │ └── evolution_store.py # 进化日志存储
│ ├── orchestrator/
│ │ ├── __init__.py
│ │ ├── pipeline_engine.py # DAG + 并行 Pipeline
│ │ ├── pipeline_schema.py # Pipeline 数据模型
│ │ ├── pipeline_loader.py # YAML 加载器
│ │ ├── handoff.py # Handoff 机制
│ │ └── dynamic_pipeline.py # 动态 Pipeline 组合
│ ├── mcp/
│ │ ├── __init__.py
│ │ ├── server.py # MCP Server
│ │ ├── client.py # MCP Client
│ │ └── transport.py # 传输层 (HTTP/SSE)
│ └── prompts/
│ ├── __init__.py
│ ├── template.py # PromptTemplate
│ └── section.py # PromptSection
├── tests/
│ ├── unit/
│ │ ├── test_base_agent.py
│ │ ├── test_config_driven.py
│ │ ├── test_tool_registry.py
│ │ ├── test_function_tool.py
│ │ ├── test_mcp_tool.py
│ │ ├── test_agent_tool.py
│ │ ├── test_working_memory.py
│ │ ├── test_episodic_memory.py
│ │ ├── test_semantic_memory.py
│ │ ├── test_memory_retriever.py
│ │ ├── test_reflector.py
│ │ ├── test_prompt_optimizer.py
│ │ ├── test_strategy_tuner.py
│ │ ├── test_ab_tester.py
│ │ ├── test_pipeline_parallel.py
│ │ ├── test_handoff.py
│ │ ├── test_dynamic_pipeline.py
│ │ ├── test_mcp_server.py
│ │ └── test_mcp_client.py
│ └── integration/
│ ├── test_agent_lifecycle.py
│ ├── test_tool_composition.py
│ ├── test_evolution_loop.py
│ └── test_mcp_roundtrip.py
├── pyproject.toml
└── README.md
geo/backend/ # GEO 业务系统(改造后)
├── app/
│ ├── agent_framework/ # 保留为适配层
│ │ ├── __init__.py
│ │ ├── agents/ # 改为配置驱动
│ │ │ ├── __init__.py
│ │ │ ├── configs/ # Agent YAML 配置
│ │ │ │ ├── citation_detector.yaml
│ │ │ │ ├── content_generator.yaml
│ │ │ │ ├── deai_agent.yaml
│ │ │ │ ├── geo_optimizer.yaml
│ │ │ │ ├── monitor.yaml
│ │ │ │ ├── schema_advisor.yaml
│ │ │ │ ├── competitor_analyzer.yaml
│ │ │ │ └── trend_agent.yaml
│ │ │ └── custom_handlers/ # 仅复杂 Agent 需自定义 handler
│ │ │ ├── citation_handler.py
│ │ │ └── monitor_handler.py
│ │ ├── tools/ # 业务 Tool 注册
│ │ │ ├── __init__.py
│ │ │ ├── citation_tools.py
│ │ │ ├── content_tools.py
│ │ │ ├── knowledge_tools.py
│ │ │ ├── monitor_tools.py
│ │ │ └── schema_tools.py
│ │ └── prompts/ # Prompt 模板(保留,可热更新)
│ ├── models/
│ │ └── agent.py # 新增 evolution_logs, episodic_memories 表
│ └── ...
```
---
## Implementation Units
### U1. 独立包脚手架与 BaseAgent 重构
**Goal:** 创建 `fischer-agentkit` 独立包,重构 BaseAgent 将 execute 模板代码上移,子类只需实现 `handle_task()`
**Dependencies:** 无
**Files:**
- `fischer-agentkit/src/agentkit/__init__.py`
- `fischer-agentkit/src/agentkit/core/__init__.py`
- `fischer-agentkit/src/agentkit/core/base.py`
- `fischer-agentkit/src/agentkit/core/protocol.py`
- `fischer-agentkit/src/agentkit/core/exceptions.py`
- `fischer-agentkit/src/agentkit/core/registry.py`
- `fischer-agentkit/src/agentkit/core/dispatcher.py`
- `fischer-agentkit/pyproject.toml`
- `fischer-agentkit/tests/unit/test_base_agent.py`
**Approach:**
1. 创建 `fischer-agentkit` 包,使用 `src/` layout`pyproject.toml` 声明依赖(`redis[hiredis]`, `pydantic>=2.0`, `sqlalchemy[asyncio]>=2.0`
2. 从 GEO 项目迁移 `protocol.py`AgentCapability, TaskMessage, TaskResult, TaskProgress, TaskStatus, AgentStatus移除 AgentType 硬编码枚举,改为动态注册
3. 重构 `BaseAgent`
- `execute()` 方法变为 final非抽象包含完整的计时、try/except、TaskResult 构建、进度上报
- 新增抽象方法 `handle_task(task) -> dict`,子类只需返回 output_data
- 新增生命周期钩子:`on_task_start()`, `on_task_complete()`, `on_task_failed()`
- 新增 `tools: list[Tool]` 属性默认空U3 实现)
- 新增 `memory: Memory` 属性默认空U4 实现)
4. 迁移 `registry.py`, `dispatcher.py`, `exceptions.py`,去除 GEO 业务依赖
5. `AgentRegistry.get_available_agent()` 增加负载均衡策略(轮询/最少任务)
**Patterns to follow:** 现有 `backend/app/agent_framework/base.py` 的双模式Redis/本地)设计
**Test scenarios:**
- BaseAgent 子类实现 handle_task 返回 dictexecute 自动包装为 TaskResult
- handle_task 抛异常时 execute 自动构建 FAILED TaskResult
- on_task_start/complete/failed 钩子按序调用
- AgentRegistry 动态注册新 AgentType 不报错
- AgentRegistry.get_available_agent 轮询策略返回不同 Agent
**Verification:** `fischer-agentkit` 可独立 `pip install -e .`,单元测试全部通过
---
### U2. Protocol 扩展与配置驱动 Agent
**Goal:** 扩展 Protocol 支持 JSON Schema 校验、Handoff 消息;实现 ConfigDrivenAgent支持 YAML/数据库配置驱动 Agent 定义。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/core/protocol.py`(扩展)
- `fischer-agentkit/src/agentkit/core/config_driven.py`
- `fischer-agentkit/src/agentkit/core/standalone.py`
- `fischer-agentkit/src/agentkit/prompts/template.py`
- `fischer-agentkit/src/agentkit/prompts/section.py`
- `fischer-agentkit/tests/unit/test_config_driven.py`
- `fischer-agentkit/tests/unit/test_protocol.py`
**Approach:**
1. Protocol 扩展:
- `AgentCapability` 增加 `input_schema: dict | None`、`output_schema: dict | None`JSON Schema
- 新增 `HandoffMessage` 数据类source_agent, target_agent, task_context, reason
- 新增 `EvolutionEvent` 数据类agent_name, change_type, before, after, metrics
- `TaskMessage` 增加 `conversation_id: str | None` 支持多轮对话
2. ConfigDrivenAgent
- 接受 YAML 配置或数据库配置自动组装Prompt 模板 + LLM 参数 + Tool 绑定 + Memory 策略
- 内置 LLM 调用 + JSON 输出解析 + 降级兜底的标准流程
- 支持三种任务模式:`llm_generate`Prompt → LLM → 解析)、`tool_call`(调用指定 Tool、`custom`(自定义 handler
3. PromptTemplate 迁移并增强:支持动态 section 组合、版本标记
4. Standalone runner 改为自动发现:扫描 `agent_configs/` 目录下的 YAML 文件,自动注册和启动
**Patterns to follow:** 现有 `prompts/base_template.py` 的 PromptSection 设计
**Test scenarios:**
- ConfigDrivenAgent 从 YAML 加载配置并正确组装
- llm_generate 模式:渲染 Prompt → 调用 Mock LLM → 解析 JSON 输出
- tool_call 模式:调用注册的 FunctionTool 并返回结果
- input_schema 校验:缺少必填字段时返回校验错误
- HandoffMessage 序列化/反序列化正确
- 自动发现:在 configs/ 目录放置 YAML 后 standalone 可自动加载
**Verification:** 通过 YAML 配置定义一个新 Agent无需写 Python 代码即可运行
---
### U3. Tool/Skill 插件系统
**Goal:** 实现 Tool 抽象基类、FunctionTool、AgentTool、ToolRegistry支持工具注册、发现、组合。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/tools/__init__.py`
- `fischer-agentkit/src/agentkit/tools/base.py`
- `fischer-agentkit/src/agentkit/tools/function_tool.py`
- `fischer-agentkit/src/agentkit/tools/agent_tool.py`
- `fischer-agentkit/src/agentkit/tools/registry.py`
- `fischer-agentkit/tests/unit/test_tool_registry.py`
- `fischer-agentkit/tests/unit/test_function_tool.py`
- `fischer-agentkit/tests/unit/test_agent_tool.py`
- `fischer-agentkit/tests/integration/test_tool_composition.py`
**Approach:**
1. `Tool` 抽象基类:
- 属性:`name`, `description`, `input_schema`JSON Schema, `output_schema`JSON Schema, `version`, `tags`
- 抽象方法:`async execute(**kwargs) -> dict`
- 生命周期钩子:`before_execute()`, `after_execute()`, `on_error()`
2. `FunctionTool`:包装普通 Python 函数为 Tool自动从函数签名推断 input_schema
3. `AgentTool`:包装另一个 Agent 为 Tool输入/输出通过 mapping 适配,内部通过 Dispatcher 分发任务
4. `ToolRegistry`
- `register(tool)`, `unregister(name)`, `get(name)`, `list_tools(tag=None)`
- 支持版本管理:同一工具多版本共存,默认使用最新版
- 支持标签分类:`[citation, analysis, generation, optimization]`
5. 工具组合:
- `SequentialChain([tool_a, tool_b])`:顺序执行,前一个输出作为后一个输入
- `ParallelFanOut([tool_a, tool_b, tool_c])`:并行执行,结果合并
- `DynamicSelector(llm, tools)`LLM 根据任务动态选择工具
**Patterns to follow:** 现有 `LLMFactory` 的注册-创建模式
**Test scenarios:**
- FunctionTool 从函数自动推断 schema 并执行
- AgentTool 分发任务到目标 Agent 并返回结果
- ToolRegistry 注册/发现/按标签过滤
- SequentialChain 顺序执行两个工具
- ParallelFanOut 并行执行三个工具
- DynamicSelector 根据 LLM 判断选择合适工具
- 工具版本管理:注册 v1 和 v2默认返回 v2
**Verification:** 通过 ToolRegistry 注册 3 个 FunctionToolAgent 可声明式绑定并调用
---
### U4. 记忆系统
**Goal:** 实现三层记忆系统Working/Episodic/Semantic和混合检索器。
**Dependencies:** U1
**Files:**
- `fischer-agentkit/src/agentkit/memory/__init__.py`
- `fischer-agentkit/src/agentkit/memory/base.py`
- `fischer-agentkit/src/agentkit/memory/working.py`
- `fischer-agentkit/src/agentkit/memory/episodic.py`
- `fischer-agentkit/src/agentkit/memory/semantic.py`
- `fischer-agentkit/src/agentkit/memory/retriever.py`
- `fischer-agentkit/tests/unit/test_working_memory.py`
- `fischer-agentkit/tests/unit/test_episodic_memory.py`
- `fischer-agentkit/tests/unit/test_semantic_memory.py`
- `fischer-agentkit/tests/unit/test_memory_retriever.py`
**Approach:**
1. `Memory` 抽象基类:
- `async store(key, value, metadata)`, `async retrieve(query, top_k)`, `async delete(key)`
- `async search(query, top_k, filters)` — 语义检索
2. `WorkingMemory`Redis
- 以 `agent:{name}:working_memory:{task_id}` 为 key 前缀
- 支持自动过期TTL = 任务超时时间 × 2
- 提供 `get_context()` 方法,返回格式化的上下文字符串
3. `EpisodicMemory`pgvector + PostgreSQL
- 表结构:`id, agent_name, task_type, input_summary, output_summary, outcome(success/fail), quality_score, reflection, embedding, created_at`
- 写入时自动生成 embedding复用 Embedder 接口)
- 检索:向量语义 + 关键词 + 时间衰减 + RRF 融合
4. `SemanticMemory`
- 适配器模式,对接 GEO 项目的 `RAGService``GraphQuery`
- 提供统一的 `search(query, knowledge_base_ids, top_k)` 接口
5. `MemoryRetriever`(混合检索器):
- 并行查询三层记忆,按权重融合排序
- 时间衰减:`score *= exp(-0.01 * age_hours)`
- 上下文窗口管理:总 token 不超过预算
**Patterns to follow:** 现有 `HybridRetriever` 的 RRF 融合排序模式
**Test scenarios:**
- WorkingMemory 存取任务上下文TTL 过期后自动清除
- EpisodicMemory 写入任务经验并按语义检索相似案例
- EpisodicMemory 时间衰减:近期经验权重高于远期
- SemanticMemory 通过适配器调用 RAGService
- MemoryRetriever 混合检索三层记忆并按权重融合
- 上下文窗口管理:检索结果超过 token 预算时智能截断
**Verification:** Agent 执行任务后自动写入 EpisodicMemory后续相似任务可检索到历史经验
---
### U5. MCP Server 与 Client
**Goal:** 实现 MCP Server暴露 Agent 能力)和 MCP Client调用外部 MCP 工具),完成 MCPTool。
**Dependencies:** U3
**Files:**
- `fischer-agentkit/src/agentkit/mcp/__init__.py`
- `fischer-agentkit/src/agentkit/mcp/server.py`
- `fischer-agentkit/src/agentkit/mcp/client.py`
- `fischer-agentkit/src/agentkit/mcp/transport.py`
- `fischer-agentkit/src/agentkit/tools/mcp_tool.py`
- `fischer-agentkit/tests/unit/test_mcp_server.py`
- `fischer-agentkit/tests/unit/test_mcp_client.py`
- `fischer-agentkit/tests/unit/test_mcp_tool.py`
- `fischer-agentkit/tests/integration/test_mcp_roundtrip.py`
**Approach:**
1. MCP Server
- 基于 FastAPI 实现,支持 Streamable HTTP 传输
- 端点:`/tools/list`(列出可用工具)、`/tools/call`(调用工具)、`/resources/read`(读取资源)
- 自动将 ToolRegistry 中注册的工具暴露为 MCP 工具
- 支持 SSE 流式响应
2. MCP Client
- 连接外部 MCP ServerHTTP/SSE
- 自动发现远程工具并注册到本地 ToolRegistry
- 支持工具调用的流式响应
3. MCPTool
- 继承 Tool 基类,内部通过 MCP Client 调用远程工具
- 自动从 MCP Server 获取 input_schema
4. Transport 层:
- 抽象 `Transport` 接口(`send_request`, `receive_response`
- 实现 `HTTPTransport`Streamable HTTP`SSETransport`Server-Sent Events
**Patterns to follow:** MCP 官方 Python SDK 的 Server/Client 模式
**Test scenarios:**
- MCP Server 启动后 `/tools/list` 返回已注册工具列表
- MCP Client 连接 Server 并调用工具返回正确结果
- MCPTool 通过 Client 调用远程工具
- SSE 流式响应正确传输
- MCP Server 自动暴露 ToolRegistry 中的 FunctionTool
- 外部 MCP Server 的工具自动注册到本地 ToolRegistry
**Verification:** 启动 MCP Server通过 MCP Client 调用 Agent 能力,端到端成功
---
### U6. 自我进化引擎Level 3
**Goal:** 实现执行反思、DSPy 风格 Prompt 自动优化、策略调优、A/B 测试框架。
**Dependencies:** U1, U4
**Files:**
- `fischer-agentkit/src/agentkit/evolution/__init__.py`
- `fischer-agentkit/src/agentkit/evolution/reflector.py`
- `fischer-agentkit/src/agentkit/evolution/prompt_optimizer.py`
- `fischer-agentkit/src/agentkit/evolution/strategy_tuner.py`
- `fischer-agentkit/src/agentkit/evolution/ab_tester.py`
- `fischer-agentkit/src/agentkit/evolution/evolution_store.py`
- `fischer-agentkit/tests/unit/test_reflector.py`
- `fischer-agentkit/tests/unit/test_prompt_optimizer.py`
- `fischer-agentkit/tests/unit/test_strategy_tuner.py`
- `fischer-agentkit/tests/unit/test_ab_tester.py`
- `fischer-agentkit/tests/integration/test_evolution_loop.py`
**Approach:**
1. `Reflector`(执行反思):
- 每次任务完成后自动评估:成功/失败、质量评分(基于 output_schema 约束和业务指标)
- 提取模式:常见失败原因、高效策略、低效路径
- 生成反思总结存入 EpisodicMemory
2. `PromptOptimizer`DSPy 风格编译器):
- `Signature`:定义 `input_fields -> output_fields` 的结构化签名
- `Module`:可组合的 Prompt 策略ChainOfThought, ReAct, FewShot
- `Optimizer`
- `BootstrapFewShot`:从成功案例中自动构建 few-shot 示例
- `MIPROv2`:多目标 Prompt 优化(指令 + few-shot 联合优化)
- 基于历史任务数据编译,产出优化后的 Prompt
3. `StrategyTuner`(策略调优):
- 可调参数temperature, tool_selection_weights, pipeline_path
- 基于 Bayesian Optimization 搜索最优参数组合
- 每次调优记录 before/after 指标
4. `ABTester`A/B 测试框架):
- 支持配置分流比例(如 80% 原版 / 20% 实验版)
- 自动收集实验组和对照组的效果指标
- 统计显著性检验t-test达到置信度后自动决策
5. `EvolutionStore`(进化日志):
- 表结构:`id, agent_name, change_type(prompt/strategy/pipeline), before, after, ab_test_id, status(active/rolled_back), created_at`
- 支持回滚:`rollback(evolution_id)`
**Patterns to follow:** DSPy 的 Signature/Module/Optimizer 三层架构
**Test scenarios:**
- Reflector 评估成功任务生成正面反思
- Reflector 评估失败任务提取失败模式
- PromptOptimizer 从 10 个成功案例自动生成 few-shot 示例
- PromptOptimizer 优化后的 Prompt 在测试集上效果提升
- StrategyTuner 调整 temperature 后效果改善
- ABTester 80/20 分流,实验组效果显著后自动应用
- EvolutionStore 记录变更并支持回滚
**Verification:** Agent 执行 20 次任务后PromptOptimizer 自动优化 PromptABTest 验证效果提升后自动应用
---
### U7. 多 Agent 协同增强
**Goal:** Pipeline Engine 支持并行执行、Handoff 机制、动态 Pipeline 组合。
**Dependencies:** U1, U2
**Files:**
- `fischer-agentkit/src/agentkit/orchestrator/__init__.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_engine.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_schema.py`
- `fischer-agentkit/src/agentkit/orchestrator/pipeline_loader.py`
- `fischer-agentkit/src/agentkit/orchestrator/handoff.py`
- `fischer-agentkit/src/agentkit/orchestrator/dynamic_pipeline.py`
- `fischer-agentkit/tests/unit/test_pipeline_parallel.py`
- `fischer-agentkit/tests/unit/test_handoff.py`
- `fischer-agentkit/tests/unit/test_dynamic_pipeline.py`
**Approach:**
1. Pipeline 并行执行:
- 拓扑排序后按依赖层级分组(同层无依赖)
- 同层内使用 `asyncio.gather` 并行执行
- 并行 stage 的结果合并后传给下一层
2. Handoff 机制:
- `HandoffManager` 管理转交请求
- Agent 调用 `self.handoff(target_agent, context, reason)` 发起转交
- 通过 Redis Pub/Sub `agent:{target}:handoff` 频道通知目标 Agent
- 目标 Agent 接收后创建新任务,携带源 Agent 的上下文
- 源 Agent 的任务状态变为 `HANDOFF`
3. 动态 Pipeline
- 支持 `sub_pipeline` stage 类型:引用另一个 YAML Pipeline
- 支持 `conditional_pipeline`:根据运行时条件选择子流程
- 支持 `loop_pipeline`:循环执行直到条件满足
4. 事件驱动替代轮询:
- Pipeline Engine 通过 Redis Pub/Sub 订阅 `agent:{name}:result` 频道
- 任务完成后自动触发下一 stage无需轮询
**Patterns to follow:** 现有 `PipelineEngine` 的 DAG 拓扑排序和变量解析
**Test scenarios:**
- 3 个无依赖 stage 并行执行,总耗时约等于最长单个 stage
- HandoffAgent A 转交任务给 Agent BB 接收并执行
- 动态 Pipeline根据条件选择不同的子流程
- 子 Pipeline 嵌套执行并正确传递变量
- 事件驱动:任务完成后自动触发下一 stage无轮询间隔
- 循环 Pipeline条件满足后退出循环
**Verification:** 内容生产 Pipeline 的 deai_processing 和 geo_optimization 并行执行,总耗时减少约 40%
---
### U8. GEO 业务系统适配与迁移
**Goal:** 将 GEO 项目的 8 个 Agent 迁移到新框架,业务 Service 注册为 Tool实现完全解耦。
**Dependencies:** U1, U2, U3, U4, U5, U6, U7
**Files:**
- `geo/backend/app/agent_framework/agents/configs/citation_detector.yaml`
- `geo/backend/app/agent_framework/agents/configs/content_generator.yaml`
- `geo/backend/app/agent_framework/agents/configs/deai_agent.yaml`
- `geo/backend/app/agent_framework/agents/configs/geo_optimizer.yaml`
- `geo/backend/app/agent_framework/agents/configs/monitor.yaml`
- `geo/backend/app/agent_framework/agents/configs/schema_advisor.yaml`
- `geo/backend/app/agent_framework/agents/configs/competitor_analyzer.yaml`
- `geo/backend/app/agent_framework/agents/configs/trend_agent.yaml`
- `geo/backend/app/agent_framework/agents/custom_handlers/citation_handler.py`
- `geo/backend/app/agent_framework/agents/custom_handlers/monitor_handler.py`
- `geo/backend/app/agent_framework/tools/__init__.py`
- `geo/backend/app/agent_framework/tools/citation_tools.py`
- `geo/backend/app/agent_framework/tools/content_tools.py`
- `geo/backend/app/agent_framework/tools/knowledge_tools.py`
- `geo/backend/app/agent_framework/tools/monitor_tools.py`
- `geo/backend/app/agent_framework/tools/schema_tools.py`
- `geo/backend/app/agent_framework/tools/competitor_tools.py`
- `geo/backend/app/agent_framework/tools/trend_tools.py`
- `geo/backend/app/models/agent.py`(新增表)
- `geo/backend/requirements.txt`(添加 fischer-agentkit 依赖)
- `geo/backend/app/agent_framework/__init__.py`(适配层)
**Approach:**
1. 业务 Tool 注册:
- `CitationTools``execute_single_platform`, `get_or_create_task`, `calculate_next_query_at`
- `ContentTools``retrieve_knowledge`(包装 RAGService
- `MonitorTools``check_and_compare`, `generate_change_report`, `create_monitoring_record`
- `SchemaTools``identify_missing_dimensions`, `match_templates`, `validate_json_ld`
- `CompetitorTools``analyze_competitor`
- `TrendTools``analyze_trends`, `get_hotspots`
2. Agent 配置迁移:
- LLM 驱动型 AgentContentGenerator, DeAI, GEOOptimizer, SchemaAdvisor→ YAML 配置 + `llm_generate` 模式,零代码
- Service 代理型 AgentCitationDetector, Monitor→ YAML 配置 + `custom` handler仅保留业务逻辑方法
- CompetitorAnalyzer, TrendAgent → YAML 配置 + `tool_call` 模式
3. 数据库迁移:
- 新增 `episodic_memories`
- 新增 `evolution_logs`
- 新增 `ab_test_configs`
4. 适配层:
- `geo/backend/app/agent_framework/` 保留为适配层import from `agentkit`
- 现有 API 端点(`/agents/`, `/agents/tasks/`)保持不变
- 现有 Pipeline YAML 保持兼容
**Patterns to follow:** 现有 Agent 的业务逻辑实现,迁移时保持行为一致
**Test scenarios:**
- 8 个 Agent 迁移后行为与原版一致(回归测试)
- 新增 Agent 只需 YAML 配置,无需写 Python 代码
- 业务 Tool 注册后 Agent 可通过 ToolRegistry 发现和调用
- EpisodicMemory 自动记录任务经验
- EvolutionStore 记录进化变更
- 现有 API 端点功能不变
- 现有 Pipeline YAML 正常执行
**Verification:** 运行现有 Agent 集成测试全部通过,新增一个 YAML-only Agent 成功执行任务
---
## Scope Boundaries
### In Scope
- `fischer-agentkit` 独立包的完整实现Core, Tools, Memory, Evolution, Orchestrator, MCP
- GEO 项目 8 个 Agent 的迁移和适配
- MCP Server/Client 双向支持
- Level 3 自我进化(反思 + Prompt 优化 + 策略调优 + A/B 测试)
- Pipeline 并行执行、Handoff、动态 Pipeline
- 三层记忆系统
- 数据库迁移(新增表)
### Deferred for Later
- A2A Protocol 支持(跨组织 Agent 协作,待 MCP 稳定后再评估)
- Agent 可视化编排 UI前端拖拽式 Pipeline 编辑器)
- Agent 市场/共享机制(跨组织共享 Agent 配置和 Tool
- 多租户隔离增强Agent 级别的资源隔离和配额)
- Agent 安全沙箱Tool 执行的权限控制和审计)
### Outside This Project's Identity
- 通用 LLM Gateway不属于 Agent 框架范畴)
- 替代现有 LLMFactory保持现有 LLM 调用体系不变)
- 前端 Agent 管理界面重构(本次聚焦后端架构)
---
## Risks & Dependencies
| Risk | Impact | Mitigation |
|------|--------|------------|
| MCP 协议规范变动2025-2026 仍在演进) | MCPTool/MCPClient 需要跟进修改 | 抽象 Transport 层,协议变更只影响 Transport 实现 |
| DSPy 风格 Prompt 优化可能需要大量训练数据 | 优化效果不明显 | BootstrapFewShot 从少量数据开始MIPROv2 需 50+ 案例时才启用 |
| 独立包与 GEO 项目的版本同步 | GEO 升级 agentkit 版本可能引入 breaking change | 语义化版本管理GEO 项目 pin 版本CI 自动测试兼容性 |
| Episodic Memory 数据量增长 | pgvector 查询性能下降 | 设置 TTL 自动清理、定期归档、HNSW 索引优化 |
| 迁移期间 8 个 Agent 行为不一致 | 业务回归 | 逐个迁移 + 回归测试,保留旧代码直到新代码验证通过 |
| Pipeline 并行执行引入并发问题 | 结果不一致 | 同层 stage 只读共享变量,写入隔离 |
---
## System-Wide Impact
- **数据库**:新增 3 张表episodic_memories, evolution_logs, ab_test_configs需 Alembic 迁移
- **Redis**:新增 Working Memory key 空间和 Handoff 频道,内存使用增加约 20%
- **API**:现有 `/agents/` 端点保持兼容,新增 `/agents/{name}/evolution` 查看进化历史
- **部署**`fischer-agentkit` 需发布到私有 PyPIGEO 项目 Dockerfile 需添加安装步骤
- **监控**:新增进化相关 Prometheus 指标evolution_attempts_total, evolution_success_rate, ab_test_decisions
- **依赖**:新增 `mcp` Python 包依赖MCP SDK
---
## Sources & Research
- 现有 Agent 框架代码:`backend/app/agent_framework/` 全部文件
- 现有 RAG 服务:`backend/app/services/knowledge/rag_service.py`, `retriever.py`, `enhanced_rag.py`
- 现有 LLM 工厂:`backend/app/services/llm/factory.py`, `base.py`
- MCP 官方规范Model Context Protocol (Anthropic, 2024-2025)
- DSPy 框架Stanford NLP, Prompt 自动优化范式
- Google A2A ProtocolAgent-to-Agent HTTP 协议 (2025.4)
- LangGraph 0.2.xStateGraph + Checkpoint 架构
- OpenAI Agents SDK 1.0Handoff 模式
- Google ADK 0.1Agent 组合模式Sequential/Parallel/Loop

File diff suppressed because one or more lines are too long

View File

@ -1,35 +1,4 @@
{
"status": "failed",
"failedTests": [
"49bb04a247fb3a93c942-35c9c801dc12bf7d35f2",
"49bb04a247fb3a93c942-698fa46448ec3c4fcc54",
"49bb04a247fb3a93c942-1e4abb88d57e30e27b50",
"49bb04a247fb3a93c942-e8f0348aed6f07ff875b",
"49bb04a247fb3a93c942-425d1f1370bc2355e391",
"49bb04a247fb3a93c942-256ee0a55d33f5229d15",
"49bb04a247fb3a93c942-5f9ad201eca7dc8ab295",
"49bb04a247fb3a93c942-ef50deff2e29416bfdcb",
"49bb04a247fb3a93c942-ea929ccac5306e953e14",
"49bb04a247fb3a93c942-fa6e92367349adbf7ce6",
"49bb04a247fb3a93c942-66abb2794b5dc22a4d2c",
"49bb04a247fb3a93c942-c8f2635d09bd86abde8c",
"49bb04a247fb3a93c942-495ca225b90357a78c73",
"49bb04a247fb3a93c942-9542aeae6bc1bec1de36",
"49bb04a247fb3a93c942-6d102a4fe6c78e5adea3",
"49bb04a247fb3a93c942-16202bb8288239216298",
"49bb04a247fb3a93c942-103e4c5c525b115390ab",
"49bb04a247fb3a93c942-8ff6f91ccaf60e9cf222",
"49bb04a247fb3a93c942-594d166a39b1685af15e",
"c05d59a3df11e753f1a9-ebd92d58fe79ae1246f9",
"30f4937bf66a39abf0e1-654253a85d12a575b58b",
"30f4937bf66a39abf0e1-531589fe51a02f263015",
"655e7220bbb90e50e0e9-a2f8fcdf1070b299cd54",
"655e7220bbb90e50e0e9-49b697c7e1ba346c76b7",
"655e7220bbb90e50e0e9-c5cd58b8cddad9907a64",
"655e7220bbb90e50e0e9-696b9830da3889bddb0f",
"655e7220bbb90e50e0e9-e8bb52f185bd86dce2da",
"fe534f0825407f213faa-71fec5acbf34a381ec37",
"fe534f0825407f213faa-c40f01c5980780ee6589",
"fe534f0825407f213faa-a6a9b88ab9323419b7d3"
]
"status": "passed",
"failedTests": []
}