216 lines
7.8 KiB
Python
216 lines
7.8 KiB
Python
import asyncio
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
|
|
|
|
|
def _make_result(
    engine_type: EngineType,
    query: str = "test query",
    has_brand: bool = False,
    has_competitor: bool = False,
) -> AIQueryResult:
    """Build a canned ``AIQueryResult`` attributed to ``engine_type``.

    The two flags drive both the ``has_*_citation`` booleans and the
    matching context fields, so the fixture is internally consistent:
    a brand citation carries a brand context, and a competitor citation
    carries exactly one competitor context.

    Args:
        engine_type: Engine the result claims to come from.
        query: Query text stored on the result.
        has_brand: Whether the result reports a brand citation.
        has_competitor: Whether the result reports a competitor citation.

    Returns:
        A new ``AIQueryResult`` with a fixed raw response, no citations,
        a 100 ms response time and a timestamp of "now" in UTC.
    """
    brand_context = None
    if has_brand:
        brand_context = "brand context"

    # A fresh list per call so results never share mutable state.
    competitor_contexts = []
    if has_competitor:
        competitor_contexts = ["comp context"]

    return AIQueryResult(
        engine_type=engine_type,
        query=query,
        raw_response="some response",
        citations=[],
        has_brand_citation=has_brand,
        has_competitor_citation=has_competitor,
        brand_context=brand_context,
        competitor_contexts=competitor_contexts,
        response_time_ms=100,
        timestamp=datetime.now(UTC),
    )
|
|
|
|
|
|
class _StubAdapter(AIEngineAdapter):
    """In-memory ``AIEngineAdapter`` that replays a canned outcome.

    ``query`` either raises the configured ``side_effect`` or returns the
    configured ``result``; it performs no I/O of its own.
    """

    def __init__(
        self,
        engine_type: EngineType,
        result: AIQueryResult | None = None,
        side_effect=None,
    ):
        super().__init__(api_key="test-key")
        # Engine reported by get_engine_type().
        self._engine_type = engine_type
        # Value handed back by query() when no failure is configured.
        self._result = result
        # Raised by query() whenever truthy; takes precedence over _result.
        self._side_effect = side_effect

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Raise the configured failure, otherwise return the canned result.

        The arguments exist only to satisfy the adapter interface and are
        not inspected.
        """
        failure = self._side_effect
        if not failure:
            return self._result
        raise failure

    def get_engine_type(self) -> EngineType:
        """Return the engine type this stub was constructed with."""
        return self._engine_type
|
|
|
|
|
|
class TestBatchQueryServiceInit:
    """Construction of ``BatchQueryService``."""

    def test_init_with_adapters(self):
        """The service keeps the exact adapter mapping it was given.

        Construction is synchronous and nothing is awaited here, so this is
        a plain test. This matches TestCitationRateCalculation, which builds
        the service outside an event loop as well.
        """
        from app.services.ai_engine.batch_query import BatchQueryService

        adapters = {
            "chatgpt": _StubAdapter(EngineType.CHATGPT),
            "perplexity": _StubAdapter(EngineType.PERPLEXITY),
        }
        service = BatchQueryService(adapters)

        # Identity check: the service must not copy or rebuild the mapping.
        assert service.adapters is adapters
        assert len(service.adapters) == 2
|
|
|
|
|
|
class TestBatchQuerySingleEngine:
    """``BatchQueryService.query_single`` dispatching to one adapter."""

    @pytest.mark.asyncio
    async def test_query_single_success(self):
        """A registered engine's adapter result comes back unchanged."""
        from app.services.ai_engine.batch_query import BatchQueryService

        canned = _make_result(EngineType.CHATGPT, has_brand=True)
        service = BatchQueryService(
            {"chatgpt": _StubAdapter(EngineType.CHATGPT, result=canned)}
        )

        got = await service.query_single(
            EngineType.CHATGPT, "best insurance", "BrandX", ["CompY"]
        )

        assert got == canned
        assert got.engine_type == EngineType.CHATGPT
        assert got.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_query_single_unknown_engine(self):
        """Asking for an engine with no registered adapter raises ValueError."""
        from app.services.ai_engine.batch_query import BatchQueryService

        chatgpt_only = BatchQueryService({"chatgpt": _StubAdapter(EngineType.CHATGPT)})

        with pytest.raises(ValueError, match="Unknown engine type"):
            await chatgpt_only.query_single(EngineType.KIMI, "test", "BrandX")
|
|
|
|
|
|
class TestBatchQueryParallel:
    """``BatchQueryService.query_batch`` fanning out across adapters."""

    @staticmethod
    def _chatgpt_and_perplexity():
        """Return a fresh two-engine list so no test shares a mutable list."""
        return [EngineType.CHATGPT, EngineType.PERPLEXITY]

    @pytest.mark.asyncio
    async def test_query_batch_multiple_engines(self):
        """Each successful engine contributes exactly one result."""
        from app.services.ai_engine.batch_query import BatchQueryService

        chatgpt_hit = _make_result(EngineType.CHATGPT, has_brand=True)
        perplexity_miss = _make_result(EngineType.PERPLEXITY, has_brand=False)
        service = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, result=chatgpt_hit),
                "perplexity": _StubAdapter(EngineType.PERPLEXITY, result=perplexity_miss),
            }
        )

        results = await service.query_batch(
            self._chatgpt_and_perplexity(), "best insurance", "BrandX"
        )

        assert len(results) == 2
        assert {r.engine_type for r in results} == {
            EngineType.CHATGPT,
            EngineType.PERPLEXITY,
        }

    @pytest.mark.asyncio
    async def test_query_batch_partial_failure(self):
        """A failing engine is left out while the others still report."""
        from app.services.ai_engine.batch_query import BatchQueryService

        healthy = _make_result(EngineType.CHATGPT, has_brand=True)
        service = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, result=healthy),
                "perplexity": _StubAdapter(
                    EngineType.PERPLEXITY, side_effect=Exception("API error")
                ),
            }
        )

        results = await service.query_batch(
            self._chatgpt_and_perplexity(), "best insurance", "BrandX"
        )

        # Exactly one survivor, and it is the ChatGPT result.
        assert [r.engine_type for r in results] == [EngineType.CHATGPT]

    @pytest.mark.asyncio
    async def test_query_batch_all_fail(self):
        """When every engine fails the batch yields an empty list."""
        from app.services.ai_engine.batch_query import BatchQueryService

        service = BatchQueryService(
            {
                "chatgpt": _StubAdapter(EngineType.CHATGPT, side_effect=Exception("err")),
                "perplexity": _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("err")),
            }
        )

        results = await service.query_batch(
            self._chatgpt_and_perplexity(), "test", "BrandX"
        )

        assert results == []
|
|
|
|
|
|
class TestBatchQueryAggregation:
    """Per-engine citation flags survive a three-engine ``query_batch``."""

    @pytest.mark.asyncio
    async def test_results_aggregation(self):
        """Brand and competitor citations are counted across all engines."""
        from app.services.ai_engine.batch_query import BatchQueryService

        # (engine, adapter key, brand cited?, competitor cited?)
        scenario = [
            (EngineType.CHATGPT, "chatgpt", True, False),
            (EngineType.PERPLEXITY, "perplexity", False, True),
            (EngineType.KIMI, "kimi", True, True),
        ]
        adapters = {
            key: _StubAdapter(
                engine,
                result=_make_result(engine, has_brand=brand, has_competitor=competitor),
            )
            for engine, key, brand, competitor in scenario
        }
        service = BatchQueryService(adapters)

        results = await service.query_batch(
            [engine for engine, *_ in scenario],
            "test",
            "BrandX",
            ["CompY"],
        )

        assert len(results) == 3
        assert sum(1 for r in results if r.has_brand_citation) == 2
        assert sum(1 for r in results if r.has_competitor_citation) == 2
|
|
|
|
|
|
class TestCitationRateCalculation:
    """Summary statistics from ``BatchQueryService.calculate_citation_rate``."""

    @staticmethod
    def _rates(results):
        """Run ``calculate_citation_rate`` on a service with no adapters."""
        from app.services.ai_engine.batch_query import BatchQueryService

        return BatchQueryService({}).calculate_citation_rate(results)

    def test_brand_citation_rate(self):
        """Two of three brand-cited engines gives a rate of 2/3."""
        brand_flags = [
            (EngineType.CHATGPT, True),
            (EngineType.PERPLEXITY, True),
            (EngineType.KIMI, False),
        ]
        rate = self._rates(
            [_make_result(engine, has_brand=flag) for engine, flag in brand_flags]
        )

        assert rate["total_engines"] == 3
        assert rate["brand_citation_count"] == 2
        assert rate["brand_citation_rate"] == pytest.approx(2 / 3)

    def test_competitor_citation_rate(self):
        """Three of four competitor-cited engines gives a rate of 3/4."""
        competitor_flags = [
            (EngineType.CHATGPT, True),
            (EngineType.PERPLEXITY, False),
            (EngineType.KIMI, True),
            (EngineType.WENXIN, True),
        ]
        rate = self._rates(
            [
                _make_result(engine, has_competitor=flag)
                for engine, flag in competitor_flags
            ]
        )

        assert rate["total_engines"] == 4
        assert rate["competitor_citation_count"] == 3
        assert rate["competitor_citation_rate"] == pytest.approx(3 / 4)

    def test_empty_results(self):
        """With no results every count and rate is zero."""
        rate = self._rates([])

        for key in (
            "total_engines",
            "brand_citation_count",
            "brand_citation_rate",
            "competitor_citation_count",
            "competitor_citation_rate",
        ):
            assert rate[key] == 0
|