geo/backend/app/services/llm/base.py

92 lines
2.0 KiB
Python

from abc import ABC, abstractmethod
from typing import AsyncIterator
from dataclasses import dataclass, field
@dataclass
class LLMResponse:
    """Result of a single, non-streaming LLM chat call.

    Attributes:
        content: The generated text.
        model: Model name associated with this response.
        usage: Token accounting dict. When not supplied, each instance gets
            its own fresh dict with ``prompt_tokens``, ``completion_tokens``
            and ``total_tokens`` all set to 0.
    """

    content: str  # generated text
    model: str  # model name
    # default_factory (not a literal default) so instances never share one dict
    usage: dict = field(
        default_factory=lambda: dict.fromkeys(
            ("prompt_tokens", "completion_tokens", "total_tokens"), 0
        )
    )
class LLMError(Exception):
"""LLM调用异常"""
def __init__(self, message: str, provider: str, status_code: int | None = None):
self.provider = provider
self.status_code = status_code
super().__init__(f"[{provider}] {message}")
class LLMProvider(ABC):
    """Abstract base class for LLM service providers.

    Concrete providers expose identifying metadata (provider name, model name,
    context limit) and implement a non-streaming ``chat`` coroutine plus a
    streaming ``chat_stream`` async iterator.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Name of the provider."""
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Name of the model."""
        ...

    @property
    @abstractmethod
    def max_context_length(self) -> int:
        """Maximum context length (presumably in tokens — confirm per provider)."""
        ...

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """Non-streaming chat: await the complete response in one piece.

        Args:
            messages: Message list in the form
                [{"role": "user"|"assistant"|"system", "content": "..."}].
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            stop: Optional list of stop sequences.

        Returns:
            LLMResponse
        """
        ...

    @abstractmethod
    def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Streaming chat: return an async iterator of text fragments.

        Declared as a plain ``def`` (not ``async def``) on purpose. An
        ``async def`` without ``yield`` is typed as a coroutine that returns
        an AsyncIterator, which makes async-generator overrides
        (``async def ...: yield ...``) incompatible for type checkers.
        Implement this as an async generator and consume it with
        ``async for chunk in provider.chat_stream(...)``.

        Args:
            messages: Message list (same format as ``chat``).
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Incremental text fragments (e.g. one per token, provider-dependent).
        """
        ...