# geo/tests/test_business_flow.py
#
# Business-flow integration tests: registration/login, query CRUD, the
# free-plan query quota, citation statistics, CSV export, run-now task
# creation, and per-user data isolation.
#
# NOTE: the web file viewer this was copied from reported 441 lines / 14 KiB
# and warned that the file contains ambiguous Unicode characters (the section
# comments below are written in Chinese).

import uuid
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.main import app
from app.api.deps import get_current_user
from app.database import get_db
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.models.user import User
from app.services.auth import hash_password
# ---------------------------------------------------------------------------
# Helper fixtures / functions
# ---------------------------------------------------------------------------
@pytest_asyncio.fixture
async def user_a(test_session: AsyncSession):
    """Persist and return the first test account, ``user_a@example.com``.

    The account is on the ``free`` plan with ``max_queries=5``; its plaintext
    password is ``password123`` (the ``auth_client_a`` fixture logs in with it).
    """
    account = User(
        name="User A",
        email="user_a@example.com",
        password_hash=hash_password("password123"),
        plan="free",
        max_queries=5,
    )
    test_session.add(account)
    await test_session.commit()
    # Re-load the row's attributes after the commit before handing it out.
    await test_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def user_b(test_session: AsyncSession):
    """Persist and return a second, independent account, ``user_b@example.com``.

    Same ``free`` plan and ``max_queries=5`` as ``user_a``, but with its own
    password (``password456``) so isolation between the two can be tested.
    """
    account = User(
        name="User B",
        email="user_b@example.com",
        password_hash=hash_password("password456"),
        plan="free",
        max_queries=5,
    )
    test_session.add(account)
    await test_session.commit()
    # Re-load the row's attributes after the commit before handing it out.
    await test_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def auth_client_a(async_client, user_a):
    """Log in as ``user_a`` and return its ``Authorization`` headers.

    Despite the fixture name, this returns a headers dict (not a client);
    pass it as ``headers=`` on ``async_client`` requests. ``user_a`` is
    requested only so the account exists before the login call.
    """
    response = await async_client.post(
        "/api/v1/auth/login",
        json={"email": "user_a@example.com", "password": "password123"},
    )
    # Include the response body so a failed login explains itself.
    assert response.status_code == 200, response.text
    token = response.json()["access_token"]
    return {"Authorization": f"Bearer {token}"}
@pytest_asyncio.fixture
async def auth_client_b(async_client, user_b):
    """Authenticate as ``user_b`` and hand back bearer-token headers.

    ``user_b`` is requested only so the account exists before logging in.
    """
    credentials = {"email": "user_b@example.com", "password": "password456"}
    login = await async_client.post("/api/v1/auth/login", json=credentials)
    assert login.status_code == 200
    access_token = login.json()["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
# ---------------------------------------------------------------------------
# 1. 完整流程:注册 -> 登录 -> 创建查询词 -> 查看列表
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_full_user_flow(async_client, override_get_db):
    """End to end: register, log in, create one query, then list it."""
    email, password = "flow@example.com", "flowpass"

    # Sign up a brand-new account.
    signup = await async_client.post(
        "/api/v1/auth/register",
        json={"email": email, "password": password, "name": "Flow User"},
    )
    assert signup.status_code == 201
    assert signup.json()["email"] == email

    # Exchange the credentials for a bearer token.
    login = await async_client.post(
        "/api/v1/auth/login",
        json={"email": email, "password": password},
    )
    assert login.status_code == 200
    auth = {"Authorization": f"Bearer {login.json()['access_token']}"}

    # Create a single monitored query.
    payload = {
        "keyword": "flow keyword",
        "target_brand": "FlowBrand",
        "platforms": ["wenxin"],
        "frequency": "daily",
    }
    created = await async_client.post(
        "/api/v1/queries/", headers=auth, json=payload
    )
    assert created.status_code == 201
    created_body = created.json()
    assert created_body["keyword"] == payload["keyword"]
    assert created_body["target_brand"] == payload["target_brand"]

    # The listing should contain exactly that one query.
    listing = await async_client.get("/api/v1/queries/", headers=auth)
    assert listing.status_code == 200
    page = listing.json()
    assert page["total"] == 1
    assert len(page["items"]) == 1
    assert page["items"][0]["keyword"] == payload["keyword"]
# ---------------------------------------------------------------------------
# 2. 查询词生命周期:创建 -> 更新 -> 暂停 -> 恢复 -> 删除
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_query_lifecycle(async_client, override_get_db, auth_client_a):
    """Walk one query through create, update, pause, resume and delete."""
    headers = auth_client_a

    created = await async_client.post(
        "/api/v1/queries/",
        headers=headers,
        json={
            "keyword": "lifecycle keyword",
            "target_brand": "LifecycleBrand",
            "platforms": ["wenxin"],
            "frequency": "weekly",
        },
    )
    assert created.status_code == 201
    item_url = f"/api/v1/queries/{created.json()['id']}"

    # Edit keyword and frequency in a single PUT.
    edited = await async_client.put(
        item_url,
        headers=headers,
        json={"keyword": "updated keyword", "frequency": "daily"},
    )
    assert edited.status_code == 200
    assert edited.json()["keyword"] == "updated keyword"
    assert edited.json()["frequency"] == "daily"

    # Pause, then resume, via status-only updates.
    for new_status in ("paused", "active"):
        toggled = await async_client.put(
            item_url, headers=headers, json={"status": new_status}
        )
        assert toggled.status_code == 200
        assert toggled.json()["status"] == new_status

    # Delete, then confirm the query can no longer be fetched.
    removed = await async_client.delete(item_url, headers=headers)
    assert removed.status_code == 204
    gone = await async_client.get(item_url, headers=headers)
    assert gone.status_code == 404
# ---------------------------------------------------------------------------
# 3. 查询数量限制:free 用户最多 5 个 (query quota: free-plan users may create at most 5 queries)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_query_limit_free_user(async_client, override_get_db, auth_client_a):
    """A free-plan user can create exactly 5 queries; the next one gets 403.

    The limit matches both the free plan and the ``max_queries=5`` that the
    ``user_a`` fixture sets; keep them in sync if either changes.
    """
    headers = auth_client_a
    free_plan_limit = 5

    # Fill the quota; every creation up to the limit must succeed.
    for i in range(free_plan_limit):
        resp = await async_client.post(
            "/api/v1/queries/",
            headers=headers,
            json={
                "keyword": f"keyword {i}",
                "target_brand": f"Brand{i}",
                "platforms": ["wenxin"],
                "frequency": "daily",
            },
        )
        # Show the response body so an unexpected rejection is diagnosable.
        assert resp.status_code == 201, f"Failed to create query {i}: {resp.text}"

    # One query beyond the quota must be rejected.
    resp = await async_client.post(
        "/api/v1/queries/",
        headers=headers,
        json={
            "keyword": "over limit",
            "target_brand": "OverBrand",
            "platforms": ["wenxin"],
            "frequency": "daily",
        },
    )
    assert resp.status_code == 403, resp.text
    assert "Query limit exceeded" in resp.json()["detail"]
# ---------------------------------------------------------------------------
# 4. 引用统计数据正确性
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_citation_stats_correctness(
    async_client, override_get_db, auth_client_a, test_session
):
    """``/citations/stats`` aggregates the user's citation records correctly.

    Three records are inserted (wenxin: one cited at position 1 and one not
    cited; kimi: one cited at position 2). The expected values are:

    - 3 records in total, 2 of them cited;
    - a citation rate of 2/3;
    - an average cited position of (1 + 2) / 2 = 1.5.
    """
    headers = auth_client_a

    # Create a query to attach the records to.
    create_resp = await async_client.post(
        "/api/v1/queries/",
        headers=headers,
        json={
            "keyword": "stats keyword",
            "target_brand": "StatsBrand",
            "platforms": ["wenxin", "kimi"],
            "frequency": "weekly",
        },
    )
    assert create_resp.status_code == 201
    query_id = uuid.UUID(create_resp.json()["id"])

    # Insert citation records directly through the test DB session.
    test_session.add_all(
        [
            CitationRecord(
                query_id=query_id,
                platform="wenxin",
                cited=True,
                citation_position=1,
                citation_text="mention 1",
                competitor_brands=[],
            ),
            CitationRecord(
                query_id=query_id,
                platform="wenxin",
                cited=False,
                citation_position=None,
                citation_text=None,
                competitor_brands=[],
            ),
            CitationRecord(
                query_id=query_id,
                platform="kimi",
                cited=True,
                citation_position=2,
                citation_text="mention 2",
                competitor_brands=[],
            ),
        ]
    )
    await test_session.commit()

    stats_resp = await async_client.get("/api/v1/citations/stats", headers=headers)
    assert stats_resp.status_code == 200
    stats = stats_resp.json()
    assert stats["total_queries"] == 3
    assert stats["total_citations"] == 2
    # Compare against the exact ratio rather than the pre-rounded 0.67. A value
    # rounded to two decimals (0.67) still passes, but a mis-rounded 0.68 no
    # longer slips through the tolerance.
    assert stats["citation_rate"] == pytest.approx(2 / 3, abs=0.01)
    assert stats["avg_position"] == pytest.approx(1.5, abs=0.1)

    by_platform = stats["by_platform"]
    assert "wenxin" in by_platform
    assert "kimi" in by_platform
    assert by_platform["wenxin"]["queries"] == 2
    assert by_platform["wenxin"]["citations"] == 1
    assert by_platform["kimi"]["queries"] == 1
    assert by_platform["kimi"]["citations"] == 1
# ---------------------------------------------------------------------------
# 5. CSV 导出功能
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_export_csv(
    async_client, override_get_db, auth_client_a, test_session
):
    """The CSV export for a query is served as text/csv and includes its rows."""
    headers = auth_client_a

    created = await async_client.post(
        "/api/v1/queries/",
        headers=headers,
        json={
            "keyword": "csv keyword",
            "target_brand": "CSVBrand",
            "platforms": ["wenxin"],
            "frequency": "weekly",
        },
    )
    assert created.status_code == 201
    raw_id = created.json()["id"]

    # Seed a single cited record so the export has something to emit.
    test_session.add(
        CitationRecord(
            query_id=uuid.UUID(raw_id),
            platform="wenxin",
            cited=True,
            citation_position=1,
            citation_text="CSV test text",
            competitor_brands=[],
        )
    )
    await test_session.commit()

    exported = await async_client.get(
        f"/api/v1/reports/export/csv?query_id={raw_id}",
        headers=headers,
    )
    assert exported.status_code == 200
    assert exported.headers["content-type"].startswith("text/csv")
    csv_text = exported.text
    for expected in ("CSV test text", "wenxin"):
        assert expected in csv_text
# ---------------------------------------------------------------------------
# 6. 手动触发查询:POST run-now 应创建 QueryTask 记录 (manual trigger: POST run-now should create QueryTask records)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_now_creates_query_task(
    async_client, override_get_db, auth_client_a, test_session
):
    """``POST .../run-now`` answers 202 and stores one pending task per platform."""
    headers = auth_client_a
    platforms = ["wenxin", "kimi"]

    created = await async_client.post(
        "/api/v1/queries/",
        headers=headers,
        json={
            "keyword": "run now keyword",
            "target_brand": "RunNowBrand",
            "platforms": platforms,
            "frequency": "weekly",
        },
    )
    assert created.status_code == 201
    query_id = created.json()["id"]

    triggered = await async_client.post(
        f"/api/v1/queries/{query_id}/run-now", headers=headers
    )
    assert triggered.status_code == 202
    payload = triggered.json()
    assert payload["status"] == "pending"
    assert "task_id" in payload

    # Read back the QueryTask rows persisted for this query.
    rows = await test_session.execute(
        select(QueryTask).where(QueryTask.query_id == uuid.UUID(query_id))
    )
    tasks = rows.scalars().all()
    assert len(tasks) == len(platforms)
    assert {task.platform for task in tasks} == set(platforms)
    assert all(task.status == "pending" for task in tasks)
# ---------------------------------------------------------------------------
# 7. 权限隔离:用户 A 无法访问用户 B 的数据
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_permission_isolation(
    async_client, override_get_db, auth_client_a, auth_client_b
):
    """User B can neither see nor modify User A's query.

    Every cross-user request must answer 404, i.e. someone else's query is
    reported as not found rather than forbidden. Afterwards, A's query must
    be intact and must not appear in B's own listing.
    """
    headers_a = auth_client_a
    headers_b = auth_client_b

    # User A creates a query.
    create_resp = await async_client.post(
        "/api/v1/queries/",
        headers=headers_a,
        json={
            "keyword": "private keyword",
            "target_brand": "PrivateBrand",
            "platforms": ["wenxin"],
            "frequency": "weekly",
        },
    )
    assert create_resp.status_code == 201
    query_id = create_resp.json()["id"]
    query_url = f"/api/v1/queries/{query_id}"

    # User B tries to read User A's query.
    get_resp = await async_client.get(query_url, headers=headers_b)
    assert get_resp.status_code == 404

    # User B tries to update User A's query.
    put_resp = await async_client.put(
        query_url, headers=headers_b, json={"keyword": "hacked"}
    )
    assert put_resp.status_code == 404

    # User B tries to delete User A's query.
    del_resp = await async_client.delete(query_url, headers=headers_b)
    assert del_resp.status_code == 404

    # User B tries to run-now User A's query.
    run_resp = await async_client.post(f"{query_url}/run-now", headers=headers_b)
    assert run_resp.status_code == 404

    # A 404 alone does not prove nothing happened. Confirm A's query survived
    # B's update/delete attempts with its original keyword...
    own_resp = await async_client.get(query_url, headers=headers_a)
    assert own_resp.status_code == 200
    assert own_resp.json()["keyword"] == "private keyword"

    # ...and that it does not leak into B's own listing.
    list_b = await async_client.get("/api/v1/queries/", headers=headers_b)
    assert list_b.status_code == 200
    assert list_b.json()["total"] == 0