271 lines
8.2 KiB
Python
271 lines
8.2 KiB
Python
"""
|
||
LLM适配器 - 使用DeepSeek LLM API检测品牌引用
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Optional
|
||
|
||
from app.schemas.scoring import CitationResult
|
||
from app.config import settings
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Prompt template for the brand-citation check. It is filled in with
# str.format() using the named fields `keyword`, `brand_name` and
# `brand_aliases`; the doubled braces ({{ }}) are format escapes so the JSON
# example line is emitted with literal braces.
# The Chinese text is runtime data sent to the model and must stay unchanged.
BRAND_CITATION_PROMPT = """分析以下AI搜索查询中是否提到了目标品牌。

查询关键词: {keyword}
目标品牌: {brand_name}
品牌别名: {brand_aliases}

返回JSON格式:
{{"cited": true/false, "position": 1, "citation_text": "...", "sentiment": "positive/neutral/negative", "confidence": 0.95}}
"""
|
||
|
||
|
||
class LLMAdapterError(Exception):
    """Error raised by the LLM adapter (configuration, API call or parsing failures)."""
|
||
|
||
|
||
class LLMAdapter:
    """Detect brand citations through an OpenAI-compatible chat-completions API.

    Any endpoint speaking the OpenAI protocol can be used (Bailian/DashScope,
    DeepSeek, ...). The API key, base URL and model name are read from
    ``settings``; the underlying client is created lazily on first use.
    """

    # String spellings of "false" that a model may emit instead of a JSON
    # boolean. ``bool("false")`` is True, so these need explicit handling.
    _FALSY_STRINGS = frozenset({'', 'false', '0', 'no', 'n', 'none', 'null'})
    _VALID_SENTIMENTS = ('positive', 'neutral', 'negative')
    _MAX_CITATION_TEXT_LENGTH = 500
    # Matches a Markdown fenced block, optionally tagged ``json``.
    _FENCED_BLOCK_RE = re.compile(r'```(?:json)?\s*([\s\S]*?)\s*```')

    def __init__(self, api_key: Optional[str] = None, max_retries: int = 3):
        """Initialise the adapter.

        Args:
            api_key: API key. Falls back to ``settings.OPENAI_API_KEY`` and
                then ``settings.DEEPSEEK_API_KEY`` when not given.
            max_retries: Maximum number of attempts per query (values below 1
                are treated as 1).
        """
        # NOTE(review): the key and the base URL fall back independently, so a
        # DEEPSEEK_API_KEY could end up paired with OPENAI_BASE_URL. Confirm
        # that the deployment settings keep the two consistent.
        self.api_key = (
            api_key
            or getattr(settings, 'OPENAI_API_KEY', None)
            or getattr(settings, 'DEEPSEEK_API_KEY', None)
        )
        # base_url: OPENAI_BASE_URL first, then DEEPSEEK_BASE_URL.
        self.base_url = (
            getattr(settings, 'OPENAI_BASE_URL', None)
            or getattr(settings, 'DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
        )
        # model: OPENAI_MODEL first, then DEFAULT_LLM_MODEL, then a built-in default.
        self.model = (
            getattr(settings, 'OPENAI_MODEL', None)
            or getattr(settings, 'DEFAULT_LLM_MODEL', None)
            or 'qwen3-coder-plus'
        )
        self.max_retries = max_retries
        self._client = None

    @property
    def client(self):
        """Return the OpenAI-compatible client, creating it on first access.

        Raises:
            LLMAdapterError: The ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise LLMAdapterError("请安装openai库: pip install openai") from exc
            self._client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._client

    def _build_prompt(self, keyword: str, brand_name: str, brand_aliases: list[str]) -> str:
        """Fill the prompt template with the keyword, brand name and aliases."""
        aliases_str = ", ".join(brand_aliases) if brand_aliases else "无"
        return BRAND_CITATION_PROMPT.format(
            keyword=keyword,
            brand_name=brand_name,
            brand_aliases=aliases_str,
        )

    async def query_brand_citation(
        self,
        keyword: str,
        brand_name: str,
        brand_aliases: list[str]
    ) -> "CitationResult":
        """Ask the LLM whether the brand is cited for the given keyword.

        Args:
            keyword: Query keyword.
            brand_name: Target brand name.
            brand_aliases: Alternative names of the brand (may be empty).

        Returns:
            CitationResult with cited, position, citation_text, sentiment
            and confidence.

        Raises:
            LLMAdapterError: LLM detection is disabled, no API key is
                configured, or every attempt failed (call or parsing).
        """
        if not settings.ENABLE_LLM:
            raise LLMAdapterError(
                "LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 OPENAI_API_KEY 或 DEEPSEEK_API_KEY"
            )

        if not self.api_key:
            raise LLMAdapterError(
                "未配置LLM API Key。请设置 OPENAI_API_KEY 或 DEEPSEEK_API_KEY 环境变量"
            )

        prompt = self._build_prompt(keyword, brand_name, brand_aliases)

        # Always make at least one attempt, so a non-positive max_retries does
        # not fail without ever calling the API.
        attempts = max(1, self.max_retries)
        last_error: Optional[Exception] = None
        for attempt in range(1, attempts + 1):
            try:
                response = await self._call_deepseek(prompt)
                return self._parse_response(response)
            except Exception as e:
                # Model output varies between calls, so parsing failures are
                # retried just like transport failures.
                last_error = e
                logger.warning(
                    "LLM API调用失败 (尝试 %d/%d): %s", attempt, attempts, e
                )

        raise LLMAdapterError(
            f"LLM API调用失败,已重试{attempts}次: {last_error}"
        ) from last_error

    async def _call_deepseek(self, prompt: str) -> dict:
        """Run the blocking API call in a worker thread.

        Args:
            prompt: Prompt text.

        Returns:
            The JSON object parsed from the model reply.

        Raises:
            LLMAdapterError: The call failed or the reply could not be parsed.
        """
        try:
            # The OpenAI client used here is synchronous; keep it off the
            # event loop.
            return await asyncio.to_thread(self._sync_call_deepseek, prompt)
        except LLMAdapterError:
            # Already a domain error with a precise message; do not re-wrap.
            raise
        except json.JSONDecodeError as e:
            raise LLMAdapterError(f"JSON解析失败: {e}") from e
        except Exception as e:
            raise LLMAdapterError(f"API调用失败: {e}") from e

    def _sync_call_deepseek(self, prompt: str) -> dict:
        """Call the chat-completions endpoint synchronously.

        Intended to run in a worker thread.

        Args:
            prompt: Prompt text.

        Returns:
            The JSON value parsed from the model reply.

        Raises:
            LLMAdapterError: The reply is empty or contains no JSON object.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            temperature=0.1,
            max_tokens=500,
        )

        # Guard against an empty choices list as well as empty content.
        if not response.choices:
            raise LLMAdapterError("API返回空响应")
        content = response.choices[0].message.content
        if not content:
            raise LLMAdapterError("API返回空响应")

        # The JSON may be wrapped in a ```json fenced block or in prose.
        return json.loads(self._extract_json(content))

    @staticmethod
    def _is_json_object(candidate: str) -> bool:
        """Return True when *candidate* parses as a JSON object (dict)."""
        try:
            return isinstance(json.loads(candidate), dict)
        except json.JSONDecodeError:
            return False

    def _extract_json(self, text: str) -> str:
        """Extract the text of the first JSON object found in *text*.

        Tried in order: the whole text, each fenced code block, then every
        ``{`` position in the text. A candidate is accepted only if it
        actually parses as a JSON object.

        Raises:
            LLMAdapterError: No JSON object could be found.
        """
        # 1) The whole reply is already a JSON object.
        if self._is_json_object(text):
            return text

        # 2) A fenced code block that holds a JSON object.
        for match in self._FENCED_BLOCK_RE.finditer(text):
            candidate = match.group(1).strip()
            if self._is_json_object(candidate):
                return candidate

        # 3) An object embedded in prose. raw_decode stops at the end of the
        #    first complete value, so trailing text or further objects after
        #    it do not break extraction.
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                _, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find('{', start + 1)
            else:
                return text[start:end]

        raise LLMAdapterError(f"无法从响应中提取JSON: {text[:200]}")

    def _normalize_citation_fields(self, response: dict) -> dict:
        """Validate and normalise the raw model reply into plain field values.

        Returns:
            Dict with keys cited, position, citation_text, sentiment and
            confidence, ready to be passed to ``CitationResult``.

        Raises:
            LLMAdapterError: The reply is not an object, lacks a required
                field, or holds a value that cannot be converted.
        """
        if not isinstance(response, dict):
            raise LLMAdapterError(f"响应不是JSON对象: {type(response).__name__}")

        for field in ('cited', 'sentiment', 'confidence'):
            if field not in response:
                raise LLMAdapterError(f"响应缺少必需字段: {field}")

        try:
            raw_cited = response['cited']
            if isinstance(raw_cited, str):
                # "false"/"no"/"0" as strings must not be read as True.
                cited = raw_cited.strip().lower() not in self._FALSY_STRINGS
            else:
                cited = bool(raw_cited)

            sentiment = str(response['sentiment']).strip().lower()
            if sentiment not in self._VALID_SENTIMENTS:
                sentiment = 'neutral'

            # position is optional; values below 1 are treated as "no position".
            position = response.get('position')
            if position is not None:
                position = int(position)
                if position < 1:
                    position = None

            confidence = float(response['confidence'])
            if confidence != confidence:
                # NaN would otherwise be clamped to 1.0 by min()/max().
                raise ValueError("confidence不是有效数值(NaN)")
            confidence = max(0.0, min(1.0, confidence))

            citation_text = response.get('citation_text')
            if citation_text is not None:
                citation_text = str(citation_text)[:self._MAX_CITATION_TEXT_LENGTH]
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers int(Infinity), which json.loads accepts.
            raise LLMAdapterError(f"解析响应失败: {e}") from e

        return {
            'cited': cited,
            'position': position,
            'citation_text': citation_text,
            'sentiment': sentiment,
            'confidence': confidence,
        }

    def _parse_response(self, response: dict) -> "CitationResult":
        """Convert the parsed model reply into a ``CitationResult``.

        Args:
            response: Dict parsed from the model reply.

        Returns:
            CitationResult built from the normalised fields.

        Raises:
            LLMAdapterError: Validation or conversion failed.
        """
        fields = self._normalize_citation_fields(response)
        try:
            return CitationResult(**fields)
        except (ValueError, TypeError) as e:
            raise LLMAdapterError(f"解析响应失败: {e}") from e

    async def close(self):
        """Close the client, if one was created, and drop the reference."""
        client, self._client = self._client, None
        if client is None:
            return

        # Release pooled HTTP connections when the client exposes close();
        # a failure here is logged rather than raised (best effort).
        close_fn = getattr(client, 'close', None)
        if close_fn is None:
            return
        try:
            close_fn()
        except Exception as e:
            logger.warning(f"关闭LLM客户端时出错: {e}")
|