geo/backend/app/workers/llm_adapter.py

296 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM适配器 - 使用DeepSeek LLM API检测品牌引用
"""
import asyncio
import json
import logging
import random
import re
from typing import Optional
from app.schemas.scoring import CitationResult
from app.config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Prompt template sent to the model. It is filled in with str.format() using
# the placeholders keyword, brand_name and brand_aliases. The doubled braces
# around the JSON example are escaped literal braces, so the model sees a
# single-brace JSON object describing the expected reply fields
# (cited, position, citation_text, sentiment, confidence).
BRAND_CITATION_PROMPT = """分析以下AI搜索查询中是否提到了目标品牌。
查询关键词: {keyword}
目标品牌: {brand_name}
品牌别名: {brand_aliases}
返回JSON格式:
{{"cited": true/false, "position": 1, "citation_text": "...", "sentiment": "positive/neutral/negative", "confidence": 0.95}}
"""
class LLMAdapterError(Exception):
    """Raised when an LLM API call fails or its response cannot be parsed."""
class LLMAdapter:
    """Detect brand citations through an OpenAI-compatible chat API.

    Key, endpoint and model are read from ``settings`` (``OPENAI_*`` first,
    then the DeepSeek settings), so the same adapter can target
    Bailian/DashScope or DeepSeek.
    """

    def __init__(self, api_key: Optional[str] = None, max_retries: int = 3):
        """Initialise the adapter.

        Args:
            api_key: API key. Falls back to ``settings.OPENAI_API_KEY`` and
                then ``settings.DEEPSEEK_API_KEY`` when not given.
            max_retries: Number of attempts made per query before giving up.
        """
        self.api_key = (
            api_key
            or getattr(settings, 'OPENAI_API_KEY', None)
            or getattr(settings, 'DEEPSEEK_API_KEY', None)
        )
        # base_url: OPENAI_BASE_URL first, then DEEPSEEK_BASE_URL.
        self.base_url = (
            getattr(settings, 'OPENAI_BASE_URL', None)
            or getattr(settings, 'DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1')
        )
        # model: OPENAI_MODEL first, then DEFAULT_LLM_MODEL. The trailing
        # literal covers a DEFAULT_LLM_MODEL that exists but is empty.
        self.model = (
            getattr(settings, 'OPENAI_MODEL', None)
            or getattr(settings, 'DEFAULT_LLM_MODEL', 'qwen3-coder-plus')
            or 'qwen3-coder-plus'
        )
        self.max_retries = max_retries
        self._client = None

    @property
    def client(self):
        """Return the OpenAI-compatible client, creating it on first use.

        Raises:
            LLMAdapterError: If the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise LLMAdapterError("请安装openai库: pip install openai") from exc
            self._client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._client

    def _build_prompt(self, keyword: str, brand_name: str, brand_aliases: list[str]) -> str:
        """Fill the prompt template; aliases are joined with ``", "``."""
        aliases_str = ", ".join(brand_aliases) if brand_aliases else ""
        return BRAND_CITATION_PROMPT.format(
            keyword=keyword,
            brand_name=brand_name,
            brand_aliases=aliases_str,
        )

    async def query_brand_citation(
        self,
        keyword: str,
        brand_name: str,
        brand_aliases: list[str],
    ) -> "CitationResult":
        """Ask the LLM whether the brand is cited for ``keyword``.

        When ``settings.ENABLE_LLM`` is falsy no API call is made and a
        randomly generated mock result is returned instead.

        Args:
            keyword: Search keyword to analyse.
            brand_name: Target brand name.
            brand_aliases: Alternative names of the brand.

        Returns:
            CitationResult with cited, position, citation_text, sentiment
            and confidence.

        Raises:
            LLMAdapterError: If every attempt fails (API call or parsing).
        """
        if not settings.ENABLE_LLM:
            logger.info("LLM调用已禁用 (ENABLE_LLM=False),返回模拟数据")
            return self._get_mock_result(keyword, brand_name, brand_aliases)

        prompt = self._build_prompt(keyword, brand_name, brand_aliases)
        # Always try at least once, so max_retries <= 0 does not fail
        # without ever contacting the API.
        attempts = max(1, self.max_retries)
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                response = await self._call_deepseek(prompt)
                return self._parse_response(response)
            except Exception as e:
                last_error = e
                logger.warning(
                    "LLM API调用失败 (尝试 %s/%s): %s", attempt, attempts, e
                )
        raise LLMAdapterError(
            f"LLM API调用失败已重试{attempts}次: {last_error}"
        ) from last_error

    def _get_mock_result(
        self,
        keyword: str,
        brand_name: str,
        brand_aliases: list[str],
    ) -> "CitationResult":
        """Build a random mock result, used when the LLM is disabled.

        The brand is reported as cited with probability 0.6; a cited result
        gets a random position in 1..10 and a canned citation text.
        """
        cited = random.random() < 0.6
        sentiment_options = ["positive", "neutral", "negative"]
        sentiment = random.choice(sentiment_options)
        if cited:
            position = random.randint(1, 10)
            citation_text = f'模拟引用:在搜索"{keyword}"时,提到了{brand_name}品牌及其相关产品。'
        else:
            position = None
            citation_text = ""
        return CitationResult(
            cited=cited,
            position=position,
            citation_text=citation_text,
            sentiment=sentiment,
            confidence=round(random.uniform(0.7, 0.99), 2),
        )

    async def _call_deepseek(self, prompt: str) -> dict:
        """Call the chat API without blocking the event loop.

        Args:
            prompt: Prompt text to send.

        Returns:
            The JSON payload parsed from the model's reply.

        Raises:
            LLMAdapterError: If the call fails or the reply is not valid JSON.
        """
        try:
            # The OpenAI client is synchronous; run it in a worker thread.
            return await asyncio.to_thread(self._sync_call_deepseek, prompt)
        except LLMAdapterError:
            # Already descriptive; do not wrap it in a second message.
            raise
        except json.JSONDecodeError as e:
            raise LLMAdapterError(f"JSON解析失败: {e}") from e
        except Exception as e:
            raise LLMAdapterError(f"API调用失败: {e}") from e

    def _sync_call_deepseek(self, prompt: str) -> dict:
        """Synchronously call the chat API (runs inside a worker thread).

        Args:
            prompt: Prompt text to send.

        Returns:
            The JSON payload parsed from the model's reply.

        Raises:
            LLMAdapterError: If the reply has no choices or empty content,
                or no JSON can be located in it.
            json.JSONDecodeError: If the extracted text is not valid JSON.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            temperature=0.1,
            max_tokens=500,
        )
        # An empty choices list is reported the same way as empty content
        # instead of surfacing as a bare IndexError.
        if not response.choices:
            raise LLMAdapterError("API返回空响应")
        content = response.choices[0].message.content
        if not content:
            raise LLMAdapterError("API返回空响应")
        # The JSON may be wrapped in a ```json code block.
        json_str = self._extract_json(content)
        return json.loads(json_str)

    def _extract_json(self, text: str) -> str:
        """Return the JSON portion of ``text``.

        Tries, in order: the whole text, the contents of the first fenced
        code block, and the span from the first ``{`` to the last ``}``.

        Raises:
            LLMAdapterError: If none of the strategies finds a candidate.
        """
        # 1. The text is already valid JSON.
        try:
            json.loads(text)
            return text
        except json.JSONDecodeError:
            pass
        # 2. Contents of a fenced code block.
        json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
        match = re.search(json_pattern, text)
        if match:
            return match.group(1).strip()
        # 3. Everything between the first '{' and the last '}'.
        first_brace = text.find('{')
        last_brace = text.rfind('}')
        if first_brace != -1 and last_brace > first_brace:
            return text[first_brace:last_brace + 1]
        raise LLMAdapterError(f"无法从响应中提取JSON: {text[:200]}")

    @staticmethod
    def _to_bool(value) -> bool:
        """Interpret a JSON value as a boolean.

        Strings are matched case-insensitively, so a model answering
        ``"false"`` is not treated as truthy the way ``bool("false")`` would.

        Raises:
            ValueError: If a string is not a recognised boolean spelling.
        """
        if isinstance(value, str):
            text = value.strip().lower()
            if text in ('true', 'yes', '1'):
                return True
            if text in ('false', 'no', '0', '', 'null', 'none'):
                return False
            raise ValueError(f"无法解析布尔值: {value!r}")
        return bool(value)

    def _parse_response(self, response: dict) -> "CitationResult":
        """Validate the model's JSON payload and build a CitationResult.

        Args:
            response: Dictionary parsed from the model's reply.

        Returns:
            CitationResult with normalised fields: unknown sentiments become
            ``neutral``, positions below 1 become ``None``, confidence is
            clamped to [0, 1] and citation_text is cut to 500 characters.

        Raises:
            LLMAdapterError: If the payload is not an object, a required
                field is missing, or a value cannot be converted.
        """
        if not isinstance(response, dict):
            raise LLMAdapterError(
                f"响应不是JSON对象: {type(response).__name__}"
            )
        for field in ('cited', 'sentiment', 'confidence'):
            if field not in response:
                raise LLMAdapterError(f"响应缺少必需字段: {field}")
        try:
            cited = self._to_bool(response['cited'])

            sentiment = str(response['sentiment']).strip().lower()
            if sentiment not in ('positive', 'neutral', 'negative'):
                sentiment = 'neutral'

            # position is optional; values below 1 are discarded.
            position = response.get('position')
            if position is not None:
                position = int(position)
                if position < 1:
                    position = None

            # Clamp confidence into [0, 1].
            confidence = float(response['confidence'])
            confidence = max(0.0, min(1.0, confidence))

            citation_text = response.get('citation_text')
            if citation_text and len(citation_text) > 500:
                citation_text = citation_text[:500]

            return CitationResult(
                cited=cited,
                position=position,
                citation_text=citation_text,
                sentiment=sentiment,
                confidence=confidence,
            )
        except (ValueError, TypeError) as e:
            raise LLMAdapterError(f"解析响应失败: {e}") from e

    async def close(self):
        """Release the underlying client, if one was created.

        The client reference is dropped first, so the adapter builds a fresh
        client on next use even if closing fails; close errors are logged
        and not raised.
        """
        client, self._client = self._client, None
        if client is None:
            return
        closer = getattr(client, "close", None)
        if not callable(closer):
            return
        try:
            closer()
        except Exception as e:
            logger.warning(f"关闭LLM客户端时出错: {e}")