geo/backend/tests/test_services/test_batch_query_service.py

219 lines
7.9 KiB
Python

import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
def _make_result(
    engine_type: EngineType,
    query: str = "test query",
    has_brand: bool = False,
    has_competitor: bool = False,
) -> AIQueryResult:
    """Create an ``AIQueryResult`` test fixture attributed to one engine.

    Args:
        engine_type: Engine the result is attributed to.
        query: Query text stored on the result.
        has_brand: When truthy, the result is flagged as citing the brand and
            ``brand_context`` is filled in; otherwise ``brand_context`` is ``None``.
        has_competitor: When truthy, the result is flagged as citing a
            competitor and ``competitor_contexts`` holds one entry; otherwise
            it is an empty list.

    Returns:
        A result with a fixed raw response, no citations, a response time of
        100 ms and a timestamp taken from the current UTC time.
    """
    # Context fields mirror the citation flags so the fixture stays coherent.
    if has_brand:
        brand_ctx = "brand context"
    else:
        brand_ctx = None

    if has_competitor:
        competitor_ctxs = ["comp context"]
    else:
        competitor_ctxs = []

    fields = {
        "engine_type": engine_type,
        "query": query,
        "raw_response": "some response",
        "citations": [],
        "has_brand_citation": has_brand,
        "has_competitor_citation": has_competitor,
        "brand_context": brand_ctx,
        "competitor_contexts": competitor_ctxs,
        "response_time_ms": 100,
        "timestamp": datetime.now(UTC),
    }
    return AIQueryResult(**fields)
class _StubAdapter(AIEngineAdapter):
    """Test double for ``AIEngineAdapter`` with a pre-programmed outcome.

    ``query`` ignores its arguments: it raises ``side_effect`` when one was
    configured (and is truthy), otherwise it returns ``result``.
    """

    def __init__(self, engine_type: EngineType, result: AIQueryResult | None = None, side_effect=None):
        """Store the engine type and the canned outcome.

        Args:
            engine_type: Value reported by ``get_engine_type``.
            result: Object returned from ``query`` when no side effect is set.
            side_effect: Exception raised from ``query`` instead of returning.
        """
        # The base initializer runs first, with a dummy API key.
        super().__init__(api_key="test-key")
        self._engine_type = engine_type
        self._result = result
        self._side_effect = side_effect

    async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult:
        """Return the canned result, or raise the configured side effect."""
        failure = self._side_effect
        if not failure:
            return self._result
        raise failure

    def get_engine_type(self) -> EngineType:
        """Return the engine type supplied at construction."""
        return self._engine_type

    def _get_env_key(self) -> str | None:
        """Return an empty string as the environment key."""
        return ""
class TestBatchQueryServiceInit:
    """Construction of ``BatchQueryService`` from an adapter mapping."""

    @pytest.mark.asyncio
    async def test_init_with_adapters(self):
        """The service must expose the exact mapping object it was given."""
        from app.services.ai_engine.batch_query import BatchQueryService

        adapter_map = dict(
            chatgpt=_StubAdapter(EngineType.CHATGPT),
            perplexity=_StubAdapter(EngineType.PERPLEXITY),
        )

        svc = BatchQueryService(adapter_map)

        # Identity, not just equality: the mapping is expected to be kept as-is.
        assert svc.adapters is adapter_map
        assert len(svc.adapters) == 2
class TestBatchQuerySingleEngine:
    """Behaviour of ``BatchQueryService.query_single``."""

    @pytest.mark.asyncio
    async def test_query_single_success(self):
        """A registered engine yields the result produced by its adapter."""
        from app.services.ai_engine.batch_query import BatchQueryService

        canned = _make_result(EngineType.CHATGPT, has_brand=True)
        stub = _StubAdapter(EngineType.CHATGPT, result=canned)
        svc = BatchQueryService({"chatgpt": stub})

        outcome = await svc.query_single(
            EngineType.CHATGPT,
            "best insurance",
            "BrandX",
            ["CompY"],
        )

        assert outcome == canned
        assert outcome.engine_type == EngineType.CHATGPT
        assert outcome.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_query_single_unknown_engine(self):
        """Querying an engine with no registered adapter raises ``ValueError``."""
        from app.services.ai_engine.batch_query import BatchQueryService

        svc = BatchQueryService({"chatgpt": _StubAdapter(EngineType.CHATGPT)})

        # Only "chatgpt" is registered, so KIMI is expected to be rejected.
        with pytest.raises(ValueError, match="Unknown engine type"):
            await svc.query_single(EngineType.KIMI, "test", "BrandX")
class TestBatchQueryParallel:
    """Behaviour of ``BatchQueryService.query_batch`` across several engines."""

    @pytest.mark.asyncio
    async def test_query_batch_multiple_engines(self):
        """Each requested engine contributes one result when all succeed."""
        from app.services.ai_engine.batch_query import BatchQueryService

        chatgpt_result = _make_result(EngineType.CHATGPT, has_brand=True)
        perplexity_result = _make_result(EngineType.PERPLEXITY, has_brand=False)
        svc = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, result=chatgpt_result),
                "perplexity": _StubAdapter(EngineType.PERPLEXITY, result=perplexity_result),
            }
        )

        requested = [EngineType.CHATGPT, EngineType.PERPLEXITY]
        results = await svc.query_batch(requested, "best insurance", "BrandX")

        assert len(results) == 2
        # Compared as a set, so the order of the returned results is not checked.
        seen_engines = {item.engine_type for item in results}
        assert seen_engines == {EngineType.CHATGPT, EngineType.PERPLEXITY}

    @pytest.mark.asyncio
    async def test_query_batch_partial_failure(self):
        """A failing adapter is left out while the successful result is kept."""
        from app.services.ai_engine.batch_query import BatchQueryService

        chatgpt_result = _make_result(EngineType.CHATGPT, has_brand=True)
        healthy = _StubAdapter(EngineType.CHATGPT, result=chatgpt_result)
        broken = _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("API error"))
        svc = BatchQueryService({"chatgpt": healthy, "perplexity": broken})

        requested = [EngineType.CHATGPT, EngineType.PERPLEXITY]
        results = await svc.query_batch(requested, "best insurance", "BrandX")

        assert len(results) == 1
        survivor = results[0]
        assert survivor.engine_type == EngineType.CHATGPT

    @pytest.mark.asyncio
    async def test_query_batch_all_fail(self):
        """When every adapter raises, the batch result is an empty list."""
        from app.services.ai_engine.batch_query import BatchQueryService

        svc = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, side_effect=Exception("err")),
                "perplexity": _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("err")),
            }
        )

        requested = [EngineType.CHATGPT, EngineType.PERPLEXITY]
        results = await svc.query_batch(requested, "test", "BrandX")

        assert results == []
class TestBatchQueryAggregation:
    """Citation flags carried by the results of a three-engine batch."""

    @pytest.mark.asyncio
    async def test_results_aggregation(self):
        """Brand and competitor flags survive the batch and can be tallied."""
        from app.services.ai_engine.batch_query import BatchQueryService

        # Fixture matrix: two engines cite the brand, two cite a competitor.
        chatgpt_result = _make_result(EngineType.CHATGPT, has_brand=True, has_competitor=False)
        perplexity_result = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
        kimi_result = _make_result(EngineType.KIMI, has_brand=True, has_competitor=True)

        svc = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, result=chatgpt_result),
                "perplexity": _StubAdapter(EngineType.PERPLEXITY, result=perplexity_result),
                "kimi": _StubAdapter(EngineType.KIMI, result=kimi_result),
            }
        )

        requested = [EngineType.CHATGPT, EngineType.PERPLEXITY, EngineType.KIMI]
        results = await svc.query_batch(requested, "test", "BrandX", ["CompY"])

        assert len(results) == 3

        brand_hits = sum(1 for item in results if item.has_brand_citation)
        competitor_hits = sum(1 for item in results if item.has_competitor_citation)
        assert brand_hits == 2
        assert competitor_hits == 2
class TestCitationRateCalculation:
    """Checks on ``BatchQueryService.calculate_citation_rate``."""

    def test_brand_citation_rate(self):
        """Two brand citations out of three results give a rate of 2/3."""
        from app.services.ai_engine.batch_query import BatchQueryService

        brand_flags = (
            (EngineType.CHATGPT, True),
            (EngineType.PERPLEXITY, True),
            (EngineType.KIMI, False),
        )
        fixtures = [_make_result(engine, has_brand=cited) for engine, cited in brand_flags]

        # No adapters are registered; only the supplied results are passed in.
        stats = BatchQueryService({}).calculate_citation_rate(fixtures)

        assert stats["total_engines"] == 3
        assert stats["brand_citation_count"] == 2
        assert stats["brand_citation_rate"] == pytest.approx(2 / 3)

    def test_competitor_citation_rate(self):
        """Three competitor citations out of four results give a rate of 3/4."""
        from app.services.ai_engine.batch_query import BatchQueryService

        competitor_flags = (
            (EngineType.CHATGPT, True),
            (EngineType.PERPLEXITY, False),
            (EngineType.KIMI, True),
            (EngineType.WENXIN, True),
        )
        fixtures = [
            _make_result(engine, has_competitor=cited) for engine, cited in competitor_flags
        ]

        stats = BatchQueryService({}).calculate_citation_rate(fixtures)

        assert stats["total_engines"] == 4
        assert stats["competitor_citation_count"] == 3
        assert stats["competitor_citation_rate"] == pytest.approx(3 / 4)

    def test_empty_results(self):
        """An empty result list yields zero for every count and rate."""
        from app.services.ai_engine.batch_query import BatchQueryService

        stats = BatchQueryService({}).calculate_citation_rate([])

        zero_fields = (
            "total_engines",
            "brand_citation_count",
            "brand_citation_rate",
            "competitor_citation_count",
            "competitor_citation_rate",
        )
        for field in zero_fields:
            assert stats[field] == 0