92 lines
2.0 KiB
Python
92 lines
2.0 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import AsyncIterator
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
@dataclass
class LLMResponse:
    """Data class holding the result of an LLM call.

    Attributes:
        content: Response text.
        model: Model identifier attached to the response (presumably the
            model that produced it -- confirm against the providers).
        usage: Token accounting dict. When not supplied, every instance
            gets its own fresh dict with ``prompt_tokens``,
            ``completion_tokens`` and ``total_tokens`` all set to 0.
    """

    content: str
    model: str
    # A default_factory is used (rather than a literal default) so that
    # instances never share one mutable dict.
    usage: dict = field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )
|
|
|
|
|
|
class LLMError(Exception):
|
|
"""LLM调用异常"""
|
|
|
|
def __init__(self, message: str, provider: str, status_code: int | None = None):
|
|
self.provider = provider
|
|
self.status_code = status_code
|
|
super().__init__(f"[{provider}] {message}")
|
|
|
|
|
|
class LLMProvider(ABC):
    """Abstract base class for LLM service providers.

    Concrete providers must implement the three read-only properties and
    both chat methods below; the class cannot be instantiated until all
    of them are overridden.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Provider name."""
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Model name."""
        ...

    @property
    @abstractmethod
    def max_context_length(self) -> int:
        """Maximum context length."""
        ...

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """
        Non-streaming chat interface.

        This is a coroutine: await it to obtain the complete response in
        one piece.

        Args:
            messages: Message list, in the format
                [{"role": "user"|"assistant"|"system", "content": "..."}]
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop words.

        Returns:
            LLMResponse
        """
        ...

    @abstractmethod
    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """
        Streaming chat interface.

        Declared as an async generator, so implementations are consumed
        with ``async for chunk in provider.chat_stream(...)``.

        Args:
            messages: Message list.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Yields:
            Text fragments, token by token.
        """
        # The unreachable ``yield`` makes this declaration an async
        # generator function. Without it, ``async def`` plus an
        # ``AsyncIterator`` return annotation is read by type checkers as
        # a coroutine that *returns* an iterator, which contradicts the
        # "Yields" contract and makes generator-based overrides look
        # incompatible with the base class.
        return
        yield ""  # pragma: no cover
|