geo/backend/app/services/knowledge/entity_extractor.py

169 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""实体和关系抽取服务"""
import re
import json
from typing import Optional
from dataclasses import dataclass, field
from app.services.llm.factory import LLMFactory
@dataclass
class ExtractedEntity:
    """A single entity extracted from text.

    Attributes:
        name: Entity name.
        entity_type: Entity type label. ``EntityExtractor`` offers the labels
            in ``EntityExtractor.ENTITY_TYPES`` to the LLM, but the value is
            not validated by this class.
        description: Optional free-text description of the entity.
        properties: Extra metadata; ``EntityExtractor`` stores the extraction
            confidence here under the ``"confidence"`` key.
    """

    name: str
    entity_type: str
    description: Optional[str] = None
    # default_factory gives every instance its own dict instead of a shared one.
    properties: dict = field(default_factory=dict)
@dataclass
class ExtractedRelation:
    """A single directed relation between two extracted entities.

    Attributes:
        source_entity: Name of the entity the relation starts from.
        target_entity: Name of the entity the relation points to.
        relation_type: Relation type label. ``EntityExtractor`` offers the
            labels in ``EntityExtractor.RELATION_TYPES`` to the LLM, but the
            value is not validated by this class.
        properties: Extra metadata; ``EntityExtractor`` stores the extraction
            confidence here under the ``"confidence"`` key.
    """

    source_entity: str
    target_entity: str
    relation_type: str
    # default_factory gives every instance its own dict instead of a shared one.
    properties: dict = field(default_factory=dict)
@dataclass
class ExtractionResult:
    """The outcome of one extraction run.

    Attributes:
        entities: Entities found in the text; empty when nothing was found or
            the LLM reply could not be parsed.
        relations: Relations found in the text; empty under the same
            conditions.
    """

    # default_factory gives every result its own lists instead of shared ones.
    entities: list[ExtractedEntity] = field(default_factory=list)
    relations: list[ExtractedRelation] = field(default_factory=list)
class EntityExtractor:
    """Extract knowledge-graph entities and relations from text via an LLM.

    The extractor builds a prompt listing the allowed entity and relation
    type labels, sends it to the LLM, and parses the JSON object found in the
    reply into an ``ExtractionResult``.
    """

    # Entity type labels offered to the LLM in the prompt.
    ENTITY_TYPES = [
        "ORGANIZATION",  # company / organization
        "PRODUCT",  # product
        "PERSON",  # person
        "LOCATION",  # location
        "TECHNOLOGY",  # technology
        "BRAND",  # brand
        "EVENT",  # event
        "CONCEPT",  # concept
    ]

    # Relation type labels offered to the LLM in the prompt.
    RELATION_TYPES = [
        "COMPETES_WITH",  # competitor
        "PARTNERS_WITH",  # partner
        "PRODUCES",  # produces
        "USES_TECHNOLOGY",  # uses a technology
        "LOCATED_IN",  # located in
        "FOUNDED_IN",  # founded in
        "CEO_OF",  # CEO of
        "FOUNDER_OF",  # founder of
        "RELATED_TO",  # related to
        "PART_OF",  # part of
    ]

    def __init__(self):
        """Create the extractor with the client from ``LLMFactory.create()``."""
        self.llm = LLMFactory.create()

    async def extract(self, text: str, context: Optional[str] = None) -> ExtractionResult:
        """Extract entities and relations from ``text``.

        Args:
            text: The text to process.
            context: Optional extra context (e.g. brand name, industry) that
                is appended to the prompt.

        Returns:
            The parsed entities and relations. The result is empty when
            ``text`` is blank (the LLM is not called in that case) or when
            the LLM reply contains no parsable JSON object.
        """
        # Nothing to extract from blank input; skip the LLM round trip.
        if not text or not text.strip():
            return ExtractionResult()

        prompt = self._build_extraction_prompt(text, context)
        response = await self.llm.generate(prompt)
        return self._parse_response(response)

    def _build_extraction_prompt(self, text: str, context: Optional[str] = None) -> str:
        """Build the extraction prompt sent to the LLM.

        Args:
            text: The text to extract from; embedded verbatim in the prompt.
            context: Optional extra context; when truthy it is added as an
                additional hint section.

        Returns:
            The prompt string, which lists the allowed entity and relation
            types and asks for a JSON-only answer.
        """
        entity_types = "\n".join(f"- {t}" for t in self.ENTITY_TYPES)
        relation_types = "\n".join(f"- {t}" for t in self.RELATION_TYPES)
        context_hint = f"\n\n附加上下文:{context}" if context else ""
        # The prompt text is runtime data and is kept exactly as authored.
        return f"""从以下文本中抽取知识图谱的实体和关系。
要求:
1. 实体必须从文本中明确提及,不能臆造
2. 关系必须有文本依据,不能臆造
3. 每个实体和关系都要有置信度说明high/medium/low
实体类型:
{entity_types}
关系类型:
{relation_types}
{context_hint}
文本内容:
{text}
请以JSON格式返回结果
{{
"entities": [
{{
"name": "实体名称",
"entity_type": "实体类型",
"description": "实体描述(可选)",
"confidence": "high/medium/low"
}}
],
"relations": [
{{
"source_entity": "源实体名称",
"target_entity": "目标实体名称",
"relation_type": "关系类型",
"confidence": "high/medium/low"
}}
]
}}
只返回JSON不要有其他内容"""

    def _parse_response(self, response: str) -> ExtractionResult:
        """Parse the LLM reply into an ``ExtractionResult``.

        Parsing is best-effort: a reply without a JSON object yields an empty
        result, and individual entity/relation items that are not objects or
        lack a required non-blank string field are skipped instead of raising.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            The entities and relations that could be parsed. Each item's
            ``properties`` holds its ``"confidence"`` value, defaulting to
            ``"medium"`` when the LLM did not provide one.
        """
        data = self._extract_json_object(response)
        if data is None:
            return ExtractionResult()

        entities = [
            ExtractedEntity(
                name=item["name"],
                entity_type=item["entity_type"],
                description=item.get("description"),
                properties={"confidence": item.get("confidence", "medium")},
            )
            for item in self._valid_items(
                data.get("entities"), ("name", "entity_type")
            )
        ]
        relations = [
            ExtractedRelation(
                source_entity=item["source_entity"],
                target_entity=item["target_entity"],
                relation_type=item["relation_type"],
                properties={"confidence": item.get("confidence", "medium")},
            )
            for item in self._valid_items(
                data.get("relations"),
                ("source_entity", "target_entity", "relation_type"),
            )
        ]
        return ExtractionResult(entities=entities, relations=relations)

    @staticmethod
    def _extract_json_object(response) -> Optional[dict]:
        """Return the first JSON object embedded in ``response``, or ``None``."""
        if not isinstance(response, str):
            return None

        decoder = json.JSONDecoder()
        # Try every "{" as a possible start so that prose, code fences or
        # stray braces around the JSON do not defeat the parse. raw_decode
        # stops at the end of the object and ignores whatever follows it.
        for brace in re.finditer(r"\{", response):
            try:
                candidate, _ = decoder.raw_decode(response, brace.start())
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict):
                return candidate
        return None

    @staticmethod
    def _valid_items(items, required_keys: tuple[str, ...]) -> list[dict]:
        """Return the dict items whose required keys hold non-blank strings."""
        # The LLM may emit null or a non-list value for the collection.
        if not isinstance(items, list):
            return []
        return [
            item
            for item in items
            if isinstance(item, dict)
            and all(
                isinstance(item.get(key), str) and item[key].strip()
                for key in required_keys
            )
        ]