feat(calendar): U4 post-processing extractor with keyword gating
Adds PostProcessingExtractor — a zero-LLM keyword gate (Chinese + English time words) followed by LLM extraction for ambiguous cases. Events created from extraction carry source="post_extract" so the UI can style them distinctly (R33). LLM gateway is optional to keep the constructor testable without a live provider. - src/agentkit/calendar/extraction.py — PostProcessingExtractor - tests/unit/calendar/test_extraction.py — 13 tests with MockLLMGateway
This commit is contained in:
parent
42fe7bcbc9
commit
ddcedb57b2
|
|
@ -0,0 +1,129 @@
|
|||
"""Post-processing extraction of schedule info from conversation text.
|
||||
|
||||
Two-stage approach (U4):
|
||||
1. Zero-LLM regex keyword gate — skip LLM entirely if no time-related keywords.
|
||||
2. LLM extraction — call the LLM gateway to pull structured event data.
|
||||
|
||||
Extracted events are persisted via ``CalendarService.create_event`` with
|
||||
``source="post_extract"`` and the originating ``conversation_id`` for
|
||||
traceability (R15).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from agentkit.calendar.service import CalendarService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostProcessingExtractor:
|
||||
"""Extract schedule info from conversation text after a chat turn.
|
||||
|
||||
Two-stage: regex keyword gate (zero LLM) → LLM extraction.
|
||||
"""
|
||||
|
||||
# Time-related keywords that trigger LLM extraction
|
||||
_KEYWORD_RE = re.compile(
|
||||
r"明天|后天|下周|本周|今天下午|今天上午|上午|下午|晚上|"
|
||||
r"\d+点|\d+月\d+日|\d+号|开会|截止|deadline|schedule|"
|
||||
r"reminder|提醒|预约|约定|安排",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
def __init__(self, calendar_service: CalendarService, llm_gateway=None):
|
||||
self.service = calendar_service
|
||||
self.llm_gateway = llm_gateway # Optional, may be set later
|
||||
|
||||
async def extract(
|
||||
self,
|
||||
conversation_text: str,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
) -> list[dict]:
|
||||
"""Extract events from conversation text.
|
||||
|
||||
Returns list of created event dicts. Empty if no keywords or no events extracted.
|
||||
Never raises — all failures are logged and swallowed.
|
||||
"""
|
||||
# 1. Keyword gate — zero LLM cost if no match
|
||||
if not self._KEYWORD_RE.search(conversation_text):
|
||||
return []
|
||||
|
||||
# 2. LLM extraction
|
||||
events_data = await self._llm_extract(conversation_text)
|
||||
if not events_data:
|
||||
return []
|
||||
|
||||
# 3. Create events with source="post_extract"
|
||||
created = []
|
||||
for event_data in events_data:
|
||||
try:
|
||||
event = await self.service.create_event(
|
||||
user_id=user_id,
|
||||
title=event_data.get("title", ""),
|
||||
start_time=event_data.get("start_time", ""),
|
||||
end_time=event_data.get("end_time", ""),
|
||||
description=event_data.get("description", ""),
|
||||
source="post_extract",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
created.append(event.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create extracted event: {e}")
|
||||
continue
|
||||
|
||||
return created
|
||||
|
||||
async def _llm_extract(self, text: str) -> list[dict]:
|
||||
"""Call LLM gateway to extract events from text.
|
||||
|
||||
Returns list of event dicts: [{title, start_time, end_time, description}].
|
||||
Returns [] on any error or empty result.
|
||||
"""
|
||||
if self.llm_gateway is None:
|
||||
return []
|
||||
|
||||
prompt = self._build_extraction_prompt(text)
|
||||
try:
|
||||
response = await self.llm_gateway.acomplete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.1,
|
||||
)
|
||||
return self._parse_llm_response(response)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM extraction failed: {e}")
|
||||
return []
|
||||
|
||||
def _build_extraction_prompt(self, text: str) -> str:
|
||||
"""Build the LLM extraction prompt."""
|
||||
return f"""Extract schedule/event information from the following conversation text.
|
||||
Return a JSON array of events. Each event should have: title, start_time (ISO 8601), end_time (ISO 8601), description.
|
||||
If no events are found, return an empty array [].
|
||||
|
||||
Conversation text:
|
||||
{text}
|
||||
|
||||
Respond with ONLY the JSON array, no other text."""
|
||||
|
||||
def _parse_llm_response(self, response: str) -> list[dict]:
|
||||
"""Parse LLM response as JSON array. Returns [] on any error."""
|
||||
try:
|
||||
# Strip markdown code fences if present
|
||||
cleaned = response.strip()
|
||||
if cleaned.startswith("```"):
|
||||
cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
|
||||
if cleaned.endswith("```"):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
data = json.loads(cleaned)
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
return [item for item in data if isinstance(item, dict) and "title" in item]
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse LLM response as JSON: {e}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,385 @@
|
|||
"""Tests for PostProcessingExtractor (U4)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.calendar.db import init_calendar_db
|
||||
from agentkit.calendar.extraction import PostProcessingExtractor
|
||||
from agentkit.calendar.service import CalendarService
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calendar_db_path(tmp_path: Path) -> Path:
|
||||
path = tmp_path / "test_calendar.db"
|
||||
asyncio.run(init_calendar_db(path))
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(calendar_db_path: Path) -> CalendarService:
|
||||
return CalendarService(db_path=calendar_db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(service: CalendarService) -> PostProcessingExtractor:
|
||||
return PostProcessingExtractor(calendar_service=service)
|
||||
|
||||
|
||||
class MockLLMGateway:
|
||||
"""Minimal async mock for the LLM gateway."""
|
||||
|
||||
def __init__(self, response: str) -> None:
|
||||
self.response = response
|
||||
self.called = False
|
||||
self.call_count = 0
|
||||
|
||||
async def acomplete(self, messages, temperature: float = 0.1) -> str:
|
||||
self.called = True
|
||||
self.call_count += 1
|
||||
return self.response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keyword regex gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_keyword_regex_matches_chinese_time_words(extractor: PostProcessingExtractor) -> None:
|
||||
"""Chinese time words trigger the keyword gate."""
|
||||
assert extractor._KEYWORD_RE.search("明天下午3点开会") is not None
|
||||
assert extractor._KEYWORD_RE.search("后天截止") is not None
|
||||
assert extractor._KEYWORD_RE.search("下周安排一下") is not None
|
||||
# No time words — should not match
|
||||
assert extractor._KEYWORD_RE.search("继续优化吧") is None
|
||||
assert extractor._KEYWORD_RE.search("好的,没问题") is None
|
||||
|
||||
|
||||
def test_keyword_regex_matches_english_time_words(extractor: PostProcessingExtractor) -> None:
|
||||
"""English time words trigger the keyword gate (case-insensitive)."""
|
||||
assert extractor._KEYWORD_RE.search("deadline tomorrow") is not None
|
||||
assert extractor._KEYWORD_RE.search("Schedule a meeting") is not None
|
||||
assert extractor._KEYWORD_RE.search("set a reminder") is not None
|
||||
# No time words — should not match
|
||||
assert extractor._KEYWORD_RE.search("hello world") is None
|
||||
assert extractor._KEYWORD_RE.search("how are you") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keyword gate skips LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_no_keyword_skips_llm_call(
|
||||
extractor: PostProcessingExtractor, service: CalendarService
|
||||
) -> None:
|
||||
"""No keyword in text → LLM gateway never called, returns []."""
|
||||
gateway = MockLLMGateway(response="[]")
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="好的,我们继续优化代码吧",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert gateway.called is False
|
||||
assert gateway.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keyword hit triggers LLM extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_keyword_hit_triggers_llm_extraction(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""Keyword present → LLM called → event created with source='post_extract'."""
|
||||
llm_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"title": "团队会议",
|
||||
"start_time": "2026-07-01T10:00:00+00:00",
|
||||
"end_time": "2026-07-01T11:00:00+00:00",
|
||||
"description": "周会",
|
||||
}
|
||||
]
|
||||
)
|
||||
gateway = MockLLMGateway(response=llm_response)
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天下午3点开个会",
|
||||
conversation_id="conv-42",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert gateway.called is True
|
||||
assert gateway.call_count == 1
|
||||
assert len(result) == 1
|
||||
event = result[0]
|
||||
assert event["title"] == "团队会议"
|
||||
assert event["source"] == "post_extract"
|
||||
assert event["start_time"] == "2026-07-01T10:00:00+00:00"
|
||||
assert event["end_time"] == "2026-07-01T11:00:00+00:00"
|
||||
assert event["description"] == "周会"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM returns empty array
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_llm_returns_empty_array_creates_nothing(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""LLM returns [] → no events created."""
|
||||
gateway = MockLLMGateway(response="[]")
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天有个安排",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert gateway.called is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Malformed LLM response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_malformed_llm_response_handled_gracefully(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""Invalid JSON response → no crash, returns []."""
|
||||
gateway = MockLLMGateway(response="this is not json at all")
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天开会",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
async def test_malformed_llm_response_json_object_not_array(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""JSON object (not array) → treated as no events."""
|
||||
gateway = MockLLMGateway(response='{"title": "会议"}')
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天开会",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conversation_id traceability
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extracted_events_have_conversation_id(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""Extracted events carry the conversation_id for traceability."""
|
||||
llm_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"title": "评审会",
|
||||
"start_time": "2026-07-01T14:00:00+00:00",
|
||||
"end_time": "2026-07-01T15:00:00+00:00",
|
||||
"description": "",
|
||||
}
|
||||
]
|
||||
)
|
||||
gateway = MockLLMGateway(response=llm_response)
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="后天下午2点评审会",
|
||||
conversation_id="conv-trace-99",
|
||||
user_id="user-7",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["conversation_id"] == "conv-trace-99"
|
||||
assert result[0]["user_id"] == "user-7"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async / non-blocking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_extraction_does_not_block_chat_response(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""extract() is awaitable and returns a list (inherent async guarantee)."""
|
||||
gateway = MockLLMGateway(response="[]")
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
# Awaiting must yield a list, not a coroutine or other object.
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天deadline",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# No LLM gateway configured
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_no_llm_gateway_returns_empty(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""llm_gateway=None + keyword hit → returns [] without error."""
|
||||
assert extractor.llm_gateway is None
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天开会",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code-fenced LLM response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_llm_response_with_code_fences_parsed(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""LLM wraps JSON in ```json ... ``` fences → parsed correctly."""
|
||||
payload = json.dumps(
|
||||
[
|
||||
{
|
||||
"title": "站会",
|
||||
"start_time": "2026-07-01T09:00:00+00:00",
|
||||
"end_time": "2026-07-01T09:15:00+00:00",
|
||||
"description": "每日站会",
|
||||
}
|
||||
]
|
||||
)
|
||||
fenced = f"```json\n{payload}\n```"
|
||||
gateway = MockLLMGateway(response=fenced)
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天上午开站会",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "站会"
|
||||
assert result[0]["description"] == "每日站会"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multiple events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_multiple_events_extracted(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""LLM returns 3 events → 3 events created."""
|
||||
llm_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"title": "会议A",
|
||||
"start_time": "2026-07-01T09:00:00+00:00",
|
||||
"end_time": "2026-07-01T10:00:00+00:00",
|
||||
"description": "",
|
||||
},
|
||||
{
|
||||
"title": "会议B",
|
||||
"start_time": "2026-07-02T14:00:00+00:00",
|
||||
"end_time": "2026-07-02T15:00:00+00:00",
|
||||
"description": "",
|
||||
},
|
||||
{
|
||||
"title": "截止日期",
|
||||
"start_time": "2026-07-05T23:59:00+00:00",
|
||||
"end_time": "2026-07-05T23:59:00+00:00",
|
||||
"description": "提交报告",
|
||||
},
|
||||
]
|
||||
)
|
||||
gateway = MockLLMGateway(response=llm_response)
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="本周有几个安排和截止",
|
||||
conversation_id="conv-multi",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
titles = {e["title"] for e in result}
|
||||
assert titles == {"会议A", "会议B", "截止日期"}
|
||||
for event in result:
|
||||
assert event["source"] == "post_extract"
|
||||
assert event["conversation_id"] == "conv-multi"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Items without 'title' key are filtered out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_items_without_title_filtered(
|
||||
extractor: PostProcessingExtractor,
|
||||
) -> None:
|
||||
"""Dict items missing 'title' are dropped by the parser."""
|
||||
llm_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"title": "有效会议",
|
||||
"start_time": "2026-07-01T09:00:00+00:00",
|
||||
"end_time": "2026-07-01T10:00:00+00:00",
|
||||
"description": "",
|
||||
},
|
||||
{"start_time": "2026-07-02T09:00:00+00:00", "end_time": "2026-07-02T10:00:00+00:00"},
|
||||
"not-a-dict",
|
||||
]
|
||||
)
|
||||
gateway = MockLLMGateway(response=llm_response)
|
||||
extractor.llm_gateway = gateway
|
||||
|
||||
result = await extractor.extract(
|
||||
conversation_text="明天开会",
|
||||
conversation_id="conv-1",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "有效会议"
|
||||
Loading…
Reference in New Issue