169 lines
4.6 KiB
Python
169 lines
4.6 KiB
Python
"""实体和关系抽取服务"""
|
||
import re
|
||
import json
|
||
from typing import Optional
|
||
from dataclasses import dataclass, field
|
||
|
||
from app.services.llm.factory import LLMFactory
|
||
|
||
|
||
@dataclass
class ExtractedEntity:
    """A single entity extracted from text.

    Attributes:
        name: Entity name.
        entity_type: Entity type label.
        description: Optional description of the entity; ``None`` when
            not supplied.
        properties: Extra key/value data attached to the entity. Each
            instance gets its own fresh dict by default.
    """

    name: str
    entity_type: str
    description: Optional[str] = field(default=None)
    properties: dict = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class ExtractedRelation:
    """A directed relation between two extracted entities.

    Attributes:
        source_entity: Name of the entity the relation starts from.
        target_entity: Name of the entity the relation points to.
        relation_type: Relation type label.
        properties: Extra key/value data attached to the relation. Each
            instance gets its own fresh dict by default.
    """

    source_entity: str
    target_entity: str
    relation_type: str
    properties: dict = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class ExtractionResult:
    """Container for the outcome of one extraction.

    Attributes:
        entities: Extracted entities; a fresh empty list by default.
        relations: Extracted relations; a fresh empty list by default.
    """

    entities: list[ExtractedEntity] = field(default_factory=list)
    relations: list[ExtractedRelation] = field(default_factory=list)
|
||
|
||
|
||
class EntityExtractor:
    """Entity and relation extraction service backed by an LLM.

    ``extract`` builds a prompt that lists the entity and relation types
    below, sends it to ``self.llm.generate`` and converts the JSON found in
    the reply into an ``ExtractionResult``.
    """

    # Entity types listed in the prompt.
    ENTITY_TYPES = [
        "ORGANIZATION",  # company / organization
        "PRODUCT",  # product
        "PERSON",  # person
        "LOCATION",  # location
        "TECHNOLOGY",  # technology
        "BRAND",  # brand
        "EVENT",  # event
        "CONCEPT",  # concept
    ]

    # Relation types listed in the prompt.
    RELATION_TYPES = [
        "COMPETES_WITH",  # competitor
        "PARTNERS_WITH",  # partner
        "PRODUCES",  # produces
        "USES_TECHNOLOGY",  # uses technology
        "LOCATED_IN",  # located in
        "FOUNDED_IN",  # founded in
        "CEO_OF",  # CEO
        "FOUNDER_OF",  # founder
        "RELATED_TO",  # related
        "PART_OF",  # part of
    ]

    def __init__(self):
        """Create the LLM client through ``LLMFactory.create()``."""
        self.llm = LLMFactory.create()

    async def extract(self, text: str, context: Optional[str] = None) -> "ExtractionResult":
        """Extract entities and relations from ``text``.

        Args:
            text: The text to process.
            context: Optional extra context (e.g. brand name, industry)
                appended to the prompt.

        Returns:
            The entities and relations parsed from the LLM reply. Empty or
            whitespace-only ``text`` yields an empty result without calling
            the LLM, since there is nothing entities could be grounded in.
        """
        if not text or not text.strip():
            return ExtractionResult(entities=[], relations=[])

        prompt = self._build_extraction_prompt(text, context)
        response = await self.llm.generate(prompt)
        return self._parse_response(response)

    def _build_extraction_prompt(self, text: str, context: Optional[str] = None) -> str:
        """Build the extraction prompt for ``text`` and optional ``context``."""
        entity_types = "\n".join(f"- {t}" for t in self.ENTITY_TYPES)
        relation_types = "\n".join(f"- {t}" for t in self.RELATION_TYPES)

        # Empty when no context is given, so the prompt just keeps blank lines.
        context_hint = f"\n\n附加上下文:{context}" if context else ""

        # NOTE(review): the template lines are flush-left because they are
        # part of the runtime prompt text; confirm against the repository
        # whether the JSON example is meant to carry indentation.
        return f"""从以下文本中抽取知识图谱的实体和关系。

要求:
1. 实体必须从文本中明确提及,不能臆造
2. 关系必须有文本依据,不能臆造
3. 每个实体和关系都要有置信度说明(high/medium/low)

实体类型:
{entity_types}

关系类型:
{relation_types}

{context_hint}

文本内容:
{text}

请以JSON格式返回结果:
{{
"entities": [
{{
"name": "实体名称",
"entity_type": "实体类型",
"description": "实体描述(可选)",
"confidence": "high/medium/low"
}}
],
"relations": [
{{
"source_entity": "源实体名称",
"target_entity": "目标实体名称",
"relation_type": "关系类型",
"confidence": "high/medium/low"
}}
]
}}

只返回JSON,不要有其他内容:"""

    @staticmethod
    def _extract_json(response) -> Optional[dict]:
        """Return the first JSON object embedded in ``response``, or ``None``.

        Every ``{`` is tried as a possible start of an object, so code
        fences, leading prose and trailing commentary (even commentary that
        contains braces) do not prevent the payload from being found.
        """
        if not isinstance(response, str):
            return None

        decoder = json.JSONDecoder()
        start = response.find("{")
        while start != -1:
            try:
                payload, _ = decoder.raw_decode(response, start)
            except json.JSONDecodeError:
                # Not a valid object at this brace; try the next one.
                start = response.find("{", start + 1)
            else:
                return payload
        return None

    @staticmethod
    def _valid_items(value, required: tuple) -> list:
        """Return the dict elements of ``value`` usable as entity/relation records.

        An element is kept only when it is a dict and each key in
        ``required`` maps to a non-blank string. A ``value`` that is not a
        list (for example ``None``) yields an empty list.
        """
        if not isinstance(value, list):
            return []
        return [
            item
            for item in value
            if isinstance(item, dict)
            and all(
                isinstance(item.get(key), str) and item[key].strip()
                for key in required
            )
        ]

    def _parse_response(self, response: str) -> "ExtractionResult":
        """Parse the LLM reply into an ``ExtractionResult``.

        Returns an empty result when no JSON object can be found. Entity or
        relation records that are malformed (not a dict, or missing a
        required string field) are skipped rather than failing the whole
        extraction. A missing ``confidence`` defaults to ``"medium"``.
        """
        data = self._extract_json(response)
        if data is None:
            return ExtractionResult(entities=[], relations=[])

        entities = [
            ExtractedEntity(
                name=e["name"],
                entity_type=e["entity_type"],
                description=e.get("description"),
                properties={"confidence": e.get("confidence", "medium")},
            )
            for e in self._valid_items(data.get("entities"), ("name", "entity_type"))
        ]

        relations = [
            ExtractedRelation(
                source_entity=r["source_entity"],
                target_entity=r["target_entity"],
                relation_type=r["relation_type"],
                properties={"confidence": r.get("confidence", "medium")},
            )
            for r in self._valid_items(
                data.get("relations"),
                ("source_entity", "target_entity", "relation_type"),
            )
        ]

        return ExtractionResult(entities=entities, relations=relations)
|