geo/backend/app/workers/citation_extractor.py

255 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
引用源分析引擎 - 从AI回答中提取引用的URL和来源信息
功能:
1. 提取文本中的URL链接
2. 提取Markdown格式的引用链接 [text](url)
3. 提取脚注引用标记 [1] 及其对应URL
4. 识别来源标注(如"来源xxx""据xxx报道"等)
5. 提取数据来源标记 [data_source: xxx]
"""
import logging
import re
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class ExtractedCitation:
"""提取的引用源信息"""
source_url: str | None = None
source_title: str | None = None
citation_context: str | None = None # 引用出现的上下文片段
@dataclass
class CitationAnalysisResult:
    """Outcome of analysing one response for citations."""

    # Origin tag of the response; expected values are "ai_platform" or
    # "search_engine", with "unknown" as the default.
    data_source: str = "unknown"
    # Extracted citations; each instance gets its own fresh list.
    citations: list[ExtractedCitation] = field(default_factory=list)
    # Response text with the data_source marker removed.
    clean_response: str = ""
# Bare http/https URL: the scheme, then a host made of word characters, "-",
# "." or percent-escapes, then an optional run of path/query/fragment
# characters.
# NOTE(review): re.ASCII is not set, so \w is Unicode-aware for str patterns
# and also matches CJK characters; text that follows a URL with no separator
# may be captured as part of it - confirm this is intended.
_URL_PATTERN = re.compile(
    r'https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+'
    r'[/\w\-._~:/?#\[\]@!$&\'()*+,;=%]*',
    re.IGNORECASE,
)
# Markdown inline link [text](url): group 1 is the link text, group 2 is the
# http/https URL, which ends at the first whitespace or ")" character.
_MD_LINK_PATTERN = re.compile(
    r'\[([^\]]+)\]\((https?://[^\s\)]+)\)',
    re.IGNORECASE,
)
# Numeric footnote marker such as [1] or [2]; group 1 is the number.
_FOOTNOTE_REF_PATTERN = re.compile(r'\[(\d+)\]')
# Footnote definition, either "[1]: url" or "[1]: text url": group 1 is the
# number, group 2 the optional text before the URL, group 3 the URL itself
# (a run of non-whitespace characters starting with http:// or https://).
_FOOTNOTE_DEF_PATTERN = re.compile(
    r'\[(\d+)\]:\s*(?:([^\n]+?))?\s*(https?://\S+)',
    re.MULTILINE,
)
# Free-text source attributions. Group 1 of every pattern is the source name.
#
# Chinese text normally uses fullwidth punctuation, so both the ASCII and the
# fullwidth forms are accepted. The fullwidth characters are written as
# explicit escapes because they are easily confused with their ASCII
# look-alikes:
#   \uff1a fullwidth colon      \uff0c fullwidth comma
#   \uff1b fullwidth semicolon  \u3002 ideographic full stop
#
# Label patterns ("来源", "参考", "引用", "出处") require a colon and capture up to
# the next newline, comma, full stop or semicolon. The "据...报道" pattern
# captures the 2-20 characters between "据" and a reporting verb, and may not
# cross any of those delimiters either.
_SOURCE_ANNOTATION_PATTERNS = [
    re.compile(r'来源[:\uff1a]\s*([^\n,\uff0c\u3002;\uff1b]+)', re.IGNORECASE),
    re.compile(
        r'据([^\n,\uff0c\u3002;\uff1b]{2,20}?)(?:报道|消息|透露|表示)',
        re.IGNORECASE,
    ),
    re.compile(r'参考[:\uff1a]\s*([^\n,\uff0c\u3002;\uff1b]+)', re.IGNORECASE),
    re.compile(r'引用[:\uff1a]\s*([^\n,\uff0c\u3002;\uff1b]+)', re.IGNORECASE),
    re.compile(r'出处[:\uff1a]\s*([^\n,\uff0c\u3002;\uff1b]+)', re.IGNORECASE),
]
# "[data_source: xxx]" marker at the start of a line; group 1 is the value
# (word characters only). Whitespace after the marker, including the line
# break, is part of the match.
_DATA_SOURCE_PATTERN = re.compile(r'^\[data_source:\s*(\w+)\]\s*\n?', re.MULTILINE)
def extract_data_source(text: str) -> tuple[str, str]:
    """Split the ``[data_source: xxx]`` marker off *text*.

    Args:
        text: Raw text that may contain a data_source marker at the start
            of a line.

    Returns:
        A ``(data_source, clean_text)`` tuple. When no marker is present the
        result is ``("unknown", text)`` with *text* unchanged.
    """
    marker = _DATA_SOURCE_PATTERN.search(text)
    if marker is None:
        return "unknown", text
    # The value comes from the first marker; sub() strips every marker found.
    return marker.group(1), _DATA_SOURCE_PATTERN.sub("", text)
def extract_urls_with_context(text: str) -> list[ExtractedCitation]:
    """Collect bare URLs from *text* together with a snippet of context.

    Each distinct URL is reported once, in order of first appearance, with
    no title. The context spans up to 100 characters before the match and
    up to 50 characters after it, with surrounding whitespace stripped.
    """
    # Insertion-ordered mapping: de-duplicates while keeping first-seen order.
    by_url: dict[str, ExtractedCitation] = {}
    for hit in _URL_PATTERN.finditer(text):
        # Drop punctuation that trails the URL in running text.
        cleaned = hit.group(0).rstrip('.,;:!?),。;:!?)')
        if cleaned in by_url:
            continue
        lo = max(0, hit.start() - 100)
        hi = min(len(text), hit.end() + 50)
        by_url[cleaned] = ExtractedCitation(
            source_url=cleaned,
            source_title=None,
            citation_context=text[lo:hi].strip(),
        )
    return list(by_url.values())
def extract_markdown_links(text: str) -> list[ExtractedCitation]:
    """Collect Markdown ``[text](url)`` links from *text*.

    The link text becomes the citation title. Each distinct URL is reported
    once, in order of first appearance. The context spans up to 80
    characters before the link and up to 50 after it, whitespace-stripped.
    """
    results: list[ExtractedCitation] = []
    already_seen: set[str] = set()
    text_len = len(text)
    for link in _MD_LINK_PATTERN.finditer(text):
        label, target = link.groups()
        # Drop punctuation left at the tail of the captured URL.
        target = target.strip().rstrip('.,;:!?),。;:!?)')
        if target not in already_seen:
            already_seen.add(target)
            window = text[max(0, link.start() - 80):min(text_len, link.end() + 50)]
            results.append(
                ExtractedCitation(
                    source_url=target,
                    source_title=label.strip(),
                    citation_context=window.strip(),
                )
            )
    return results
def extract_footnotes(text: str) -> list[ExtractedCitation]:
    """Extract footnote citations by pairing ``[n]`` markers with definitions.

    Definitions of the form ``[n]: url`` or ``[n]: title url`` are collected
    first; when a number is defined more than once, the last definition wins.
    Every ``[n]`` marker whose number has a definition then yields a citation
    whose context is the text around that marker (up to 80 characters before
    and 50 after, whitespace-stripped). Each distinct URL is reported once.

    The marker pattern also matches the ``[n]`` that opens a definition line,
    so a definition with no in-text reference is still reported, using the
    definition line itself as context.

    Args:
        text: Text to scan.

    Returns:
        Citations in order of first matching marker.
    """
    # Trailing punctuation to strip: ASCII forms plus their fullwidth / CJK
    # counterparts, written as escapes because the fullwidth characters are
    # easily confused with the ASCII ones.
    #   \uff0c comma  \u3002 full stop  \uff1b semicolon  \uff1a colon
    #   \uff01 exclamation mark  \uff1f question mark  \uff09 right parenthesis
    url_trailing = '.,;:!?)\uff0c\u3002\uff1b\uff1a\uff01\uff1f\uff09'
    title_trailing = '.,;:\uff0c\u3002\uff1b\uff1a'

    # Pass 1: footnote number -> (title or None, url).
    footnote_defs: dict[str, tuple[str | None, str]] = {}
    for match in _FOOTNOTE_DEF_PATTERN.finditer(text):
        title = match.group(2)
        if title:
            # A title made only of punctuation carries no information.
            title = title.strip().rstrip(title_trailing) or None
        url = match.group(3).strip().rstrip(url_trailing)
        footnote_defs[match.group(1)] = (title, url)

    # Pass 2: locate markers and attach the surrounding context.
    citations: list[ExtractedCitation] = []
    seen_urls: set[str] = set()
    for match in _FOOTNOTE_REF_PATTERN.finditer(text):
        definition = footnote_defs.get(match.group(1))
        if definition is None:
            continue
        title, url = definition
        if url in seen_urls:
            continue
        seen_urls.add(url)
        start = max(0, match.start() - 80)
        end = min(len(text), match.end() + 50)
        citations.append(ExtractedCitation(
            source_url=url,
            source_title=title,
            citation_context=text[start:end].strip(),
        ))
    return citations
def extract_source_annotations(text: str) -> list[ExtractedCitation]:
    """Collect free-text source attributions such as "来源: xxx" or "据xxx报道".

    The returned citations carry a title but no URL. Names shorter than two
    characters are ignored and each distinct name is reported once. Results
    are grouped by pattern, in the order the patterns are declared. The
    context spans up to 50 characters on either side of the match.
    """
    collected: list[ExtractedCitation] = []
    known_names: set[str] = set()
    # Flatten the per-pattern match streams into one, pattern order first.
    every_match = (
        found
        for pattern in _SOURCE_ANNOTATION_PATTERNS
        for found in pattern.finditer(text)
    )
    for found in every_match:
        name = found.group(1).strip()
        if len(name) >= 2 and name not in known_names:
            known_names.add(name)
            lo = max(0, found.start() - 50)
            hi = min(len(text), found.end() + 50)
            collected.append(
                ExtractedCitation(
                    source_url=None,
                    source_title=name,
                    citation_context=text[lo:hi].strip(),
                )
            )
    return collected
def analyze_citations(
    raw_response: str,
    *,
    max_citations: int = 20,
) -> CitationAnalysisResult:
    """Analyse the citation information contained in an AI answer.

    Args:
        raw_response: Raw response text returned by the platform adapter.
        max_citations: Upper bound on the number of citations returned.
            Defaults to 20; negative values are treated as 0.

    Returns:
        CitationAnalysisResult holding the data-source tag, the extracted
        citations and the response text with the data_source marker removed.
        An empty or missing response yields a default-constructed result.
    """
    if not raw_response:
        return CitationAnalysisResult()

    # 1. Split off the data_source marker.
    data_source, clean_text = extract_data_source(raw_response)

    # 2. URL-bearing citations, in priority order: Markdown links first
    #    (they carry a title), then footnotes, then bare URLs. A URL already
    #    reported by a higher-priority extractor is not reported again.
    all_citations: list[ExtractedCitation] = []
    seen_urls: set[str] = set()
    url_extractors = (
        extract_markdown_links,
        extract_footnotes,
        extract_urls_with_context,
    )
    for extractor in url_extractors:
        for c in extractor(clean_text):
            if c.source_url and c.source_url not in seen_urls:
                seen_urls.add(c.source_url)
                all_citations.append(c)

    # 3. Free-text source attributions have no URL, so nothing to de-duplicate.
    all_citations.extend(extract_source_annotations(clean_text))

    # 4. Cap the result. Attributions come last, so they are dropped first.
    all_citations = all_citations[:max(0, max_citations)]

    logger.info(
        f"引用源分析完成: data_source={data_source}, "
        f"提取到 {len(all_citations)} 个引用源"
    )
    return CitationAnalysisResult(
        data_source=data_source,
        citations=all_citations,
        clean_response=clean_text,
    )