feat(calendar): U4 post-processing extractor with keyword gating
Adds PostProcessingExtractor — a zero-LLM keyword gate (Chinese + English time words) followed by LLM extraction whenever the gate matches. Events created from extraction carry source="post_extract" so the UI can style them distinctly (R33). The LLM gateway is optional, which keeps the constructor testable without a live provider. Files: src/agentkit/calendar/extraction.py — PostProcessingExtractor; tests/unit/calendar/test_extraction.py — 13 tests with MockLLMGateway.
This commit is contained in:
parent
42fe7bcbc9
commit
ddcedb57b2
|
|
@ -0,0 +1,129 @@
|
||||||
|
"""Post-processing extraction of schedule info from conversation text.
|
||||||
|
|
||||||
|
Two-stage approach (U4):
|
||||||
|
1. Zero-LLM regex keyword gate — skip LLM entirely if no time-related keywords.
|
||||||
|
2. LLM extraction — call the LLM gateway to pull structured event data.
|
||||||
|
|
||||||
|
Extracted events are persisted via ``CalendarService.create_event`` with
|
||||||
|
``source="post_extract"`` and the originating ``conversation_id`` for
|
||||||
|
traceability (R15).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from agentkit.calendar.service import CalendarService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PostProcessingExtractor:
    """Extract schedule info from conversation text after a chat turn.

    Two-stage: regex keyword gate (zero LLM) → LLM extraction. Events the
    LLM reports are persisted through the calendar service with
    ``source="post_extract"`` and the originating conversation id.
    """

    # Time-related keywords that trigger LLM extraction
    _KEYWORD_RE = re.compile(
        r"明天|后天|下周|本周|今天下午|今天上午|上午|下午|晚上|"
        r"\d+点|\d+月\d+日|\d+号|开会|截止|deadline|schedule|"
        r"reminder|提醒|预约|约定|安排",
        re.IGNORECASE,
    )

    def __init__(self, calendar_service: CalendarService, llm_gateway=None):
        """Store the collaborators.

        Args:
            calendar_service: Service whose ``create_event`` persists events.
            llm_gateway: Object exposing an async ``acomplete`` method.
                Optional; while it is ``None`` no extraction is attempted.
        """
        self.service = calendar_service
        self.llm_gateway = llm_gateway  # Optional, may be set later

    async def extract(
        self,
        conversation_text: str,
        conversation_id: str,
        user_id: str,
    ) -> list[dict]:
        """Extract events from conversation text and persist them.

        Args:
            conversation_text: Raw text of the conversation turn(s).
            conversation_id: Stored on every created event for traceability.
            user_id: Owner of the created events.

        Returns:
            The created events as dicts (``event.to_dict()``). Empty when the
            text is not a string, contains no keyword, or yields no events.

        Never raises — all failures are logged and swallowed.
        """
        # 1. Keyword gate — zero LLM cost if no match. The isinstance check
        #    keeps the "never raises" promise for None / non-text input.
        if not isinstance(conversation_text, str):
            return []
        if not self._KEYWORD_RE.search(conversation_text):
            return []

        # 2. LLM extraction
        events_data = await self._llm_extract(conversation_text)
        if not events_data:
            return []

        # 3. Create events with source="post_extract"
        created: list[dict] = []
        for event_data in events_data:
            try:
                event = await self.service.create_event(
                    user_id=user_id,
                    # "or" maps JSON null to the same "" default that a
                    # missing key gets, so None never reaches the service.
                    title=event_data.get("title") or "",
                    start_time=event_data.get("start_time") or "",
                    end_time=event_data.get("end_time") or "",
                    description=event_data.get("description") or "",
                    source="post_extract",
                    conversation_id=conversation_id,
                )
                created.append(event.to_dict())
            except Exception as exc:
                # Best effort: one bad event must not block the others.
                logger.warning("Failed to create extracted event: %s", exc)

        return created

    async def _llm_extract(self, text: str) -> list[dict]:
        """Call LLM gateway to extract events from text.

        Returns list of event dicts: [{title, start_time, end_time, description}].
        Returns [] when no gateway is configured, on any error, or on an
        empty result.
        """
        if self.llm_gateway is None:
            return []

        prompt = self._build_extraction_prompt(text)
        try:
            response = await self.llm_gateway.acomplete(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
            )
            return self._parse_llm_response(response)
        except Exception as exc:
            logger.warning("LLM extraction failed: %s", exc)
            return []

    def _build_extraction_prompt(self, text: str) -> str:
        """Build the LLM extraction prompt.

        The prompt carries the current date/time so relative expressions
        (the very words the keyword gate fires on, e.g. "明天") can be
        resolved into the ISO 8601 timestamps it asks for.
        """
        from datetime import datetime

        # NOTE(review): this is the host's local timezone — confirm whether
        # the user's own timezone should be used instead.
        reference_time = datetime.now().astimezone().isoformat(timespec="seconds")
        return (
            "Extract schedule/event information from the following conversation text.\n"
            "Return a JSON array of events. Each event should have: title, "
            "start_time (ISO 8601), end_time (ISO 8601), description.\n"
            "If no events are found, return an empty array [].\n"
            f"Current date and time: {reference_time}. Resolve relative dates and "
            "times (e.g. tomorrow, next week, 明天) against it.\n"
            "\n"
            "Conversation text:\n"
            f"{text}\n"
            "\n"
            "Respond with ONLY the JSON array, no other text."
        )

    def _parse_llm_response(self, response: str) -> list[dict]:
        """Parse LLM response as JSON array. Returns [] on any error.

        Two candidates are tried in order and the first that decodes wins:
        the response with surrounding markdown code fences stripped, then
        the span from the first ``[`` to the last ``]`` (covers replies that
        wrap the array in prose or add text after the closing fence).

        Only dict items whose ``title`` is a non-blank string are returned;
        a decoded value that is not a list yields [].
        """
        if not isinstance(response, str):
            logger.warning(
                "Failed to parse LLM response as JSON: expected str, got %s",
                type(response).__name__,
            )
            return []

        # Strip markdown code fences if present
        cleaned = response.strip()
        if cleaned.startswith("```"):
            # Drop the whole opening fence line (it may carry a language tag).
            _, newline, remainder = cleaned.partition("\n")
            cleaned = remainder if newline else cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()

        candidates = [cleaned]
        first, last = response.find("["), response.rfind("]")
        if 0 <= first < last:
            bracketed = response[first : last + 1]
            if bracketed != cleaned:
                candidates.append(bracketed)

        last_error = None
        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError as exc:
                last_error = exc
                continue
            break
        else:
            logger.warning("Failed to parse LLM response as JSON: %s", last_error)
            return []

        if not isinstance(data, list):
            return []
        return [
            item
            for item in data
            if isinstance(item, dict)
            and isinstance(item.get("title"), str)
            and item["title"].strip()
        ]
|
||||||
|
|
@ -0,0 +1,385 @@
|
||||||
|
"""Tests for PostProcessingExtractor (U4)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.calendar.db import init_calendar_db
|
||||||
|
from agentkit.calendar.extraction import PostProcessingExtractor
|
||||||
|
from agentkit.calendar.service import CalendarService
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def calendar_db_path(tmp_path: Path) -> Path:
    """Return the path of a freshly initialised calendar DB in pytest's tmp dir."""
    db_file = tmp_path / "test_calendar.db"
    # The initialiser is awaitable; run it to completion before handing out the path.
    asyncio.run(init_calendar_db(db_file))
    return db_file
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service(calendar_db_path: Path) -> CalendarService:
    """Provide a CalendarService bound to the temporary test database."""
    calendar_service = CalendarService(db_path=calendar_db_path)
    return calendar_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def extractor(service: CalendarService) -> PostProcessingExtractor:
    """Provide an extractor with no LLM gateway attached; tests set one as needed."""
    under_test = PostProcessingExtractor(calendar_service=service)
    return under_test
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMGateway:
    """Async stand-in for the LLM gateway that replays one canned reply.

    Records whether and how often ``acomplete`` was awaited so tests can
    assert on gateway usage.
    """

    def __init__(self, response: str) -> None:
        self.call_count = 0
        self.called = False
        self.response = response

    async def acomplete(self, messages, temperature: float = 0.1) -> str:
        """Note the call and hand back the canned reply; arguments are ignored."""
        self.call_count = self.call_count + 1
        self.called = True
        return self.response
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Keyword regex gate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_keyword_regex_matches_chinese_time_words(extractor: PostProcessingExtractor) -> None:
    """The gate fires on Chinese scheduling words and stays shut otherwise."""
    gate = extractor._KEYWORD_RE
    for phrase in ("明天下午3点开会", "后天截止", "下周安排一下"):
        assert gate.search(phrase) is not None
    # Phrases without any time word must not open the gate.
    for phrase in ("继续优化吧", "好的,没问题"):
        assert gate.search(phrase) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_keyword_regex_matches_english_time_words(extractor: PostProcessingExtractor) -> None:
    """English scheduling words open the gate regardless of letter case."""
    gate = extractor._KEYWORD_RE
    for phrase in ("deadline tomorrow", "Schedule a meeting", "set a reminder"):
        assert gate.search(phrase) is not None
    # Small talk carries no keyword and must not match.
    for phrase in ("hello world", "how are you"):
        assert gate.search(phrase) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Keyword gate skips LLM
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_keyword_skips_llm_call(
    extractor: PostProcessingExtractor, service: CalendarService
) -> None:
    """Text without keywords returns [] and never reaches the gateway."""
    stub = MockLLMGateway(response="[]")
    extractor.llm_gateway = stub

    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="好的,我们继续优化代码吧",
    )

    assert stub.call_count == 0
    assert stub.called is False
    assert outcome == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Keyword hit triggers LLM extraction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_keyword_hit_triggers_llm_extraction(
    extractor: PostProcessingExtractor,
) -> None:
    """A keyword hit calls the LLM once and persists its event as 'post_extract'."""
    meeting = {
        "title": "团队会议",
        "start_time": "2026-07-01T10:00:00+00:00",
        "end_time": "2026-07-01T11:00:00+00:00",
        "description": "周会",
    }
    stub = MockLLMGateway(response=json.dumps([meeting]))
    extractor.llm_gateway = stub

    created = await extractor.extract(
        conversation_text="明天下午3点开个会",
        conversation_id="conv-42",
        user_id="user-1",
    )

    assert stub.called is True
    assert stub.call_count == 1
    assert len(created) == 1

    stored = created[0]
    assert stored["title"] == "团队会议"
    assert stored["source"] == "post_extract"
    assert stored["start_time"] == "2026-07-01T10:00:00+00:00"
    assert stored["end_time"] == "2026-07-01T11:00:00+00:00"
    assert stored["description"] == "周会"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLM returns empty array
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_llm_returns_empty_array_creates_nothing(
    extractor: PostProcessingExtractor,
) -> None:
    """An empty array from the LLM yields no events, though the LLM was consulted."""
    stub = MockLLMGateway(response="[]")
    extractor.llm_gateway = stub

    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="明天有个安排",
    )

    assert stub.called is True
    assert outcome == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Malformed LLM response
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_malformed_llm_response_handled_gracefully(
    extractor: PostProcessingExtractor,
) -> None:
    """A reply that is not JSON is swallowed and produces []."""
    extractor.llm_gateway = MockLLMGateway(response="this is not json at all")

    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="明天开会",
    )

    assert outcome == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_malformed_llm_response_json_object_not_array(
    extractor: PostProcessingExtractor,
) -> None:
    """A top-level JSON object instead of an array counts as no events."""
    extractor.llm_gateway = MockLLMGateway(response='{"title": "会议"}')

    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="明天开会",
    )

    assert outcome == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# conversation_id traceability
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_extracted_events_have_conversation_id(
    extractor: PostProcessingExtractor,
) -> None:
    """Created events record the conversation and user they came from."""
    review = {
        "title": "评审会",
        "start_time": "2026-07-01T14:00:00+00:00",
        "end_time": "2026-07-01T15:00:00+00:00",
        "description": "",
    }
    extractor.llm_gateway = MockLLMGateway(response=json.dumps([review]))

    created = await extractor.extract(
        conversation_text="后天下午2点评审会",
        conversation_id="conv-trace-99",
        user_id="user-7",
    )

    assert len(created) == 1
    (stored,) = created
    assert stored["conversation_id"] == "conv-trace-99"
    assert stored["user_id"] == "user-7"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Async / non-blocking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_extraction_does_not_block_chat_response(
    extractor: PostProcessingExtractor,
) -> None:
    """Awaiting extract() resolves to a plain list."""
    extractor.llm_gateway = MockLLMGateway(response="[]")

    # The awaited value must be the list itself, not a coroutine or wrapper.
    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="明天deadline",
    )

    assert isinstance(outcome, list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# No LLM gateway configured
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_llm_gateway_returns_empty(
    extractor: PostProcessingExtractor,
) -> None:
    """With no gateway configured, a keyword hit still returns [] cleanly."""
    # Precondition: the fixture builds the extractor without a gateway.
    assert extractor.llm_gateway is None

    outcome = await extractor.extract(
        user_id="user-1",
        conversation_id="conv-1",
        conversation_text="明天开会",
    )

    assert outcome == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Code-fenced LLM response
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_llm_response_with_code_fences_parsed(
    extractor: PostProcessingExtractor,
) -> None:
    """JSON wrapped in a ```json fenced block is still parsed into events."""
    standup = {
        "title": "站会",
        "start_time": "2026-07-01T09:00:00+00:00",
        "end_time": "2026-07-01T09:15:00+00:00",
        "description": "每日站会",
    }
    payload = json.dumps([standup])
    fenced = f"```json\n{payload}\n```"
    extractor.llm_gateway = MockLLMGateway(response=fenced)

    created = await extractor.extract(
        conversation_text="明天上午开站会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert len(created) == 1
    stored = created[0]
    assert stored["title"] == "站会"
    assert stored["description"] == "每日站会"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Multiple events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_multiple_events_extracted(
    extractor: PostProcessingExtractor,
) -> None:
    """Three events from the LLM produce three stored events."""
    meeting_a = {
        "title": "会议A",
        "start_time": "2026-07-01T09:00:00+00:00",
        "end_time": "2026-07-01T10:00:00+00:00",
        "description": "",
    }
    meeting_b = {
        "title": "会议B",
        "start_time": "2026-07-02T14:00:00+00:00",
        "end_time": "2026-07-02T15:00:00+00:00",
        "description": "",
    }
    due_date = {
        "title": "截止日期",
        "start_time": "2026-07-05T23:59:00+00:00",
        "end_time": "2026-07-05T23:59:00+00:00",
        "description": "提交报告",
    }
    extractor.llm_gateway = MockLLMGateway(
        response=json.dumps([meeting_a, meeting_b, due_date])
    )

    created = await extractor.extract(
        conversation_text="本周有几个安排和截止",
        conversation_id="conv-multi",
        user_id="user-1",
    )

    assert len(created) == 3
    assert {stored["title"] for stored in created} == {"会议A", "会议B", "截止日期"}
    # Every stored event carries the extraction marker and the source conversation.
    for stored in created:
        assert stored["source"] == "post_extract"
        assert stored["conversation_id"] == "conv-multi"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Items without 'title' key are filtered out
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_items_without_title_filtered(
    extractor: PostProcessingExtractor,
) -> None:
    """Array entries lacking a 'title', or that are not dicts, are discarded."""
    valid = {
        "title": "有效会议",
        "start_time": "2026-07-01T09:00:00+00:00",
        "end_time": "2026-07-01T10:00:00+00:00",
        "description": "",
    }
    untitled = {
        "start_time": "2026-07-02T09:00:00+00:00",
        "end_time": "2026-07-02T10:00:00+00:00",
    }
    extractor.llm_gateway = MockLLMGateway(
        response=json.dumps([valid, untitled, "not-a-dict"])
    )

    created = await extractor.extract(
        conversation_text="明天开会",
        conversation_id="conv-1",
        user_id="user-1",
    )

    assert len(created) == 1
    assert created[0]["title"] == "有效会议"
|
||||||
Loading…
Reference in New Issue