386 lines
12 KiB
Python
386 lines
12 KiB
Python
"""Tests for PostProcessingExtractor (U4)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from agentkit.calendar.db import init_calendar_db
|
|
from agentkit.calendar.extraction import PostProcessingExtractor
|
|
from agentkit.calendar.service import CalendarService
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def calendar_db_path(tmp_path: Path) -> Path:
    """Return the path of a freshly initialised calendar database file.

    The file lives in pytest's per-test ``tmp_path``. ``init_calendar_db`` is
    awaited via ``asyncio.run`` because this fixture is synchronous.
    """
    db_file = tmp_path.joinpath("test_calendar.db")
    asyncio.run(init_calendar_db(db_file))
    return db_file
|
|
|
|
|
|
@pytest.fixture
def service(calendar_db_path: Path) -> CalendarService:
    """Provide a ``CalendarService`` bound to the per-test database file."""
    calendar_service = CalendarService(db_path=calendar_db_path)
    return calendar_service
|
|
|
|
|
|
@pytest.fixture
def extractor(service: CalendarService) -> PostProcessingExtractor:
    """Provide an extractor wired to the test ``CalendarService``.

    Only ``calendar_service`` is passed to the constructor. Tests that need an
    LLM assign ``extractor.llm_gateway`` themselves.
    """
    post_extractor = PostProcessingExtractor(calendar_service=service)
    return post_extractor
|
|
|
|
|
|
class MockLLMGateway:
    """Minimal async mock for the LLM gateway.

    ``acomplete`` always returns the canned ``response`` string. Every call is
    recorded so tests can assert whether the LLM was invoked, how often, and
    with which arguments.
    """

    def __init__(self, response: str) -> None:
        self.response = response
        self.called = False
        self.call_count = 0
        # One (messages, temperature) tuple per acomplete() call, in call order.
        self.calls: list[tuple[object, float]] = []

    async def acomplete(self, messages, temperature: float = 0.1) -> str:
        """Record the call and return the canned response unchanged."""
        self.called = True
        self.call_count += 1
        self.calls.append((messages, temperature))
        return self.response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Keyword regex gate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_keyword_regex_matches_chinese_time_words(extractor: PostProcessingExtractor) -> None:
    """Chinese time expressions pass the keyword gate; plain chatter does not."""
    keyword_re = extractor._KEYWORD_RE

    for text in ("明天下午3点开会", "后天截止", "下周安排一下"):
        assert keyword_re.search(text) is not None, text

    # Sentences without any time word must be rejected.
    for text in ("继续优化吧", "好的,没问题"):
        assert keyword_re.search(text) is None, text
|
|
|
|
|
|
def test_keyword_regex_matches_english_time_words(extractor: PostProcessingExtractor) -> None:
    """English time words trigger the keyword gate (case-insensitive).

    "Schedule a meeting" alone does not prove case-insensitivity, because a
    lowercase keyword later in the sentence may be what matches. The all-caps
    copy of a phrase already known to match pins that behaviour down.
    """
    keyword_re = extractor._KEYWORD_RE

    assert keyword_re.search("deadline tomorrow") is not None
    assert keyword_re.search("Schedule a meeting") is not None
    assert keyword_re.search("set a reminder") is not None
    # Same phrase as the first assertion, upper-cased: this should only
    # match if the pattern ignores case.
    assert keyword_re.search("DEADLINE TOMORROW") is not None

    # No time words — should not match
    assert keyword_re.search("hello world") is None
    assert keyword_re.search("how are you") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Keyword gate skips LLM
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_no_keyword_skips_llm_call(
    extractor: PostProcessingExtractor, service: CalendarService
) -> None:
    """Text with no time keyword short-circuits: the LLM is never asked, result is []."""
    gateway = MockLLMGateway(response="[]")
    extractor.llm_gateway = gateway

    extracted = await extractor.extract(
        conversation_text="好的,我们继续优化代码吧",
        conversation_id="conv-1",
        user_id="user-1",
    )

    # The gateway must not have been touched at all.
    assert gateway.called is False
    assert gateway.call_count == 0
    assert extracted == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Keyword hit triggers LLM extraction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_keyword_hit_triggers_llm_extraction(
    extractor: PostProcessingExtractor,
) -> None:
    """A keyword hit sends the text to the LLM and turns its reply into an event.

    The returned event mirrors every field the LLM produced and is tagged
    with ``source == "post_extract"``.
    """
    expected = {
        "title": "团队会议",
        "start_time": "2026-07-01T10:00:00+00:00",
        "end_time": "2026-07-01T11:00:00+00:00",
        "description": "周会",
    }
    gateway = MockLLMGateway(response=json.dumps([expected]))
    extractor.llm_gateway = gateway

    result = await extractor.extract(
        conversation_text="明天下午3点开个会",
        conversation_id="conv-42",
        user_id="user-1",
    )

    assert gateway.called is True
    assert gateway.call_count == 1
    assert len(result) == 1

    event = result[0]
    assert event["source"] == "post_extract"
    for field, value in expected.items():
        assert event[field] == value, field
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLM returns empty array
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_llm_returns_empty_array_creates_nothing(
    extractor: PostProcessingExtractor,
) -> None:
    """The LLM is consulted but answers with an empty array, so nothing comes back."""
    empty_gateway = MockLLMGateway(response="[]")
    extractor.llm_gateway = empty_gateway

    created = await extractor.extract(
        conversation_text="明天有个安排",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert not created and created == []
    assert empty_gateway.called is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Malformed LLM response
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_malformed_llm_response_handled_gracefully(
    extractor: PostProcessingExtractor,
) -> None:
    """Invalid JSON from the LLM → no crash, returns [].

    The gateway call is asserted as well. Without that check, the test would
    also pass if the keyword gate rejected the text and the malformed-response
    path were never reached.
    """
    gateway = MockLLMGateway(response="this is not json at all")
    extractor.llm_gateway = gateway

    result = await extractor.extract(
        conversation_text="明天开会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    # Prove the parser actually saw the malformed reply.
    assert gateway.call_count == 1
    assert result == []
|
|
|
|
|
|
async def test_malformed_llm_response_json_object_not_array(
    extractor: PostProcessingExtractor,
) -> None:
    """A JSON object (not an array) from the LLM is treated as "no events".

    The gateway call is asserted so the test cannot pass merely because the
    keyword gate skipped the LLM before the object reply was parsed.
    """
    gateway = MockLLMGateway(response='{"title": "会议"}')
    extractor.llm_gateway = gateway

    result = await extractor.extract(
        conversation_text="明天开会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    # Prove the parser actually saw the non-array reply.
    assert gateway.call_count == 1
    assert result == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# conversation_id traceability
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_extracted_events_have_conversation_id(
    extractor: PostProcessingExtractor,
) -> None:
    """Every extracted event records which conversation and user it came from."""
    review_meeting = {
        "title": "评审会",
        "start_time": "2026-07-01T14:00:00+00:00",
        "end_time": "2026-07-01T15:00:00+00:00",
        "description": "",
    }
    extractor.llm_gateway = MockLLMGateway(response=json.dumps([review_meeting]))

    result = await extractor.extract(
        conversation_text="后天下午2点评审会",
        conversation_id="conv-trace-99",
        user_id="user-7",
    )

    assert len(result) == 1
    event = result[0]
    assert event["conversation_id"] == "conv-trace-99"
    assert event["user_id"] == "user-7"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async / non-blocking
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_extraction_does_not_block_chat_response(
    extractor: PostProcessingExtractor,
) -> None:
    """extract() hands back an awaitable instead of doing the work inline.

    Calling ``extract`` must not reach the LLM before the returned awaitable
    is driven. This is what lets the chat path await it or schedule it in the
    background. Awaiting it must then yield a plain list.
    """
    import inspect

    gateway = MockLLMGateway(response="[]")
    extractor.llm_gateway = gateway

    pending = extractor.extract(
        conversation_text="明天deadline",
        conversation_id="conv-1",
        user_id="user-1",
    )
    # The call itself returns an awaitable and has not contacted the LLM yet.
    assert inspect.isawaitable(pending)
    assert gateway.call_count == 0

    result = await pending
    assert isinstance(result, list)
    assert result == []
    # "明天" is a keyword hit, so awaiting must consult the LLM exactly once.
    assert gateway.call_count == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# No LLM gateway configured
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_no_llm_gateway_returns_empty(
    extractor: PostProcessingExtractor,
) -> None:
    """With no gateway configured, a keyword hit still returns [] and does not raise."""
    # Precondition: the fixture builds the extractor without any gateway.
    assert extractor.llm_gateway is None

    events = await extractor.extract(
        conversation_text="明天开会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert events == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Code-fenced LLM response
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_llm_response_with_code_fences_parsed(
    extractor: PostProcessingExtractor,
) -> None:
    """A reply wrapped in a ```json Markdown fence is unwrapped before parsing."""
    standup = {
        "title": "站会",
        "start_time": "2026-07-01T09:00:00+00:00",
        "end_time": "2026-07-01T09:15:00+00:00",
        "description": "每日站会",
    }
    fenced_reply = "```json\n" + json.dumps([standup]) + "\n```"
    extractor.llm_gateway = MockLLMGateway(response=fenced_reply)

    result = await extractor.extract(
        conversation_text="明天上午开站会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert len(result) == 1
    first = result[0]
    assert first["title"] == "站会"
    assert first["description"] == "每日站会"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Multiple events
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_multiple_events_extracted(
    extractor: PostProcessingExtractor,
) -> None:
    """Three events in the LLM reply produce three tagged events in the result."""
    # (title, start_time, end_time, description) for each expected event.
    specs = [
        ("会议A", "2026-07-01T09:00:00+00:00", "2026-07-01T10:00:00+00:00", ""),
        ("会议B", "2026-07-02T14:00:00+00:00", "2026-07-02T15:00:00+00:00", ""),
        ("截止日期", "2026-07-05T23:59:00+00:00", "2026-07-05T23:59:00+00:00", "提交报告"),
    ]
    reply = json.dumps(
        [
            {"title": title, "start_time": start, "end_time": end, "description": desc}
            for title, start, end, desc in specs
        ]
    )
    extractor.llm_gateway = MockLLMGateway(response=reply)

    result = await extractor.extract(
        conversation_text="本周有几个安排和截止",
        conversation_id="conv-multi",
        user_id="user-1",
    )

    assert len(result) == 3
    assert {event["title"] for event in result} == {"会议A", "会议B", "截止日期"}
    for event in result:
        assert event["source"] == "post_extract"
        assert event["conversation_id"] == "conv-multi"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Items without 'title' key are filtered out
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_items_without_title_filtered(
    extractor: PostProcessingExtractor,
) -> None:
    """Only dict items that carry a 'title' survive parsing; the rest are discarded."""
    titled = {
        "title": "有效会议",
        "start_time": "2026-07-01T09:00:00+00:00",
        "end_time": "2026-07-01T10:00:00+00:00",
        "description": "",
    }
    untitled = {"start_time": "2026-07-02T09:00:00+00:00", "end_time": "2026-07-02T10:00:00+00:00"}
    not_a_mapping = "not-a-dict"
    extractor.llm_gateway = MockLLMGateway(
        response=json.dumps([titled, untitled, not_a_mapping])
    )

    result = await extractor.extract(
        conversation_text="明天开会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert len(result) == 1
    assert result[0]["title"] == "有效会议"
|