312 lines
10 KiB
Python
312 lines
10 KiB
Python
"""阿里云百炼图片生成服务"""
|
||
|
||
import os
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
|
||
# Per-platform image dimensions.
# Each platform maps an image type ("cover", optionally "inline") to a spec
# dict with "width", "height" and a human-readable "ratio" label.
PLATFORM_IMAGE_SPECS = {
    platform_name: {
        image_kind: {"width": width, "height": height, "ratio": ratio}
        for image_kind, (width, height, ratio) in kinds.items()
    }
    for platform_name, kinds in {
        "zhihu": {
            "cover": (690, 280, "2.5:1"),
            "inline": (500, 375, "4:3"),
        },
        "wechat": {
            "cover": (900, 383, "2.35:1"),
            "inline": (800, 600, "4:3"),
        },
        "xiaohongshu": {
            "cover": (1080, 1080, "1:1"),  # square
            "inline": (1242, 1660, "3:4"),  # portrait
        },
        "toutiao": {"cover": (1024, 678, "1.5:1")},
        "baijiahao": {"cover": (600, 400, "3:2")},
        "weibo": {"cover": (980, 560, "1.75:1")},
        "bilibili": {"cover": (1920, 1080, "16:9")},
        "jianshu": {"cover": (800, 600, "4:3")},
        "juejin": {"cover": (1024, 768, "4:3")},
        "douyin": {"cover": (1080, 1920, "9:16")},  # portrait short-video cover
    }.items()
}
|
||
|
||
# Visual style presets: a display name plus the English prompt fragment
# that is appended to the generation prompt for that style.
IMAGE_STYLES = {
    "modern": {
        "name": "现代简约",
        "prompt": "modern minimalist style, clean design, professional",
    },
    "tech": {
        "name": "科技感",
        "prompt": "tech style, futuristic, digital, blue tones",
    },
    "elegant": {
        "name": "优雅商务",
        "prompt": "elegant business style, sophisticated, premium",
    },
    "creative": {
        "name": "创意活力",
        "prompt": "creative vibrant style, colorful, dynamic",
    },
    "minimal": {
        "name": "极简主义",
        "prompt": "ultra minimal, white space, typography focus",
    },
}
|
||
|
||
# Layout presets: a display name plus the English prompt fragment that
# describes where text and visuals are placed.
LAYOUT_OPTIONS = {
    "centered": {
        "name": "居中排版",
        "prompt": "centered composition, text in middle",
    },
    "left_text": {
        "name": "左文右图",
        "prompt": "left side text, right side visual",
    },
    "top_text": {
        "name": "上文下图",
        "prompt": "text on top, visual below",
    },
    "text_overlay": {
        "name": "文字叠加",
        "prompt": "text overlay on background image",
    },
}
|
||
|
||
|
||
@dataclass
class ImageResult:
    """Result of one image generation request."""

    # URL of the generated image as reported by the generation task.
    url: str
    # Width and height copied from the platform size spec used for the
    # request (not measured from the returned image).
    width: int
    height: int
    # Prompt text that was submitted for generation.
    prompt: str
    # Platform identifier the caller asked for.
    platform: str
    # ID of the asynchronous generation task.
    task_id: str
|
||
|
||
|
||
class ImageGenerationError(Exception):
    """Error raised when an image generation request fails."""
|
||
|
||
|
||
class ImageGenerator:
    """Text-to-image client for Alibaba Cloud Bailian (DashScope), Wanx text-to-image V1.

    Generation is asynchronous: one POST request creates a task, then the
    task status is polled until it succeeds, fails, or ``timeout`` seconds
    have elapsed.
    """

    # Task-query endpoint used when ``base_url`` has no "/services/" segment
    # from which the API root could be derived.
    _DEFAULT_TASKS_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
    # Per-HTTP-request timeout in seconds (the overall wait is ``self.timeout``).
    _REQUEST_TIMEOUT = 30.0
    # Delay between two status polls, in seconds.
    _POLL_INTERVAL = 2.0
    # Terminal task states that mean no image will be produced.
    _FAILED_STATUSES = frozenset({"FAILED", "CANCELED"})

    # Extra prompt fragment per platform; platforms not listed get none.
    _PLATFORM_NOTES = {
        "xiaohongshu": "warm tones, lifestyle, lifestyle photography",
        "wechat": "professional, clean, suitable for WeChat article cover",
        "zhihu": "intellectual, professional, suitable for long-form content",
        "toutiao": "eye-catching, clear hierarchy, news style",
        "baijiahao": "clear, professional, suitable for news media",
        "weibo": "social media friendly, engaging",
        "bilibili": "anime style friendly, vibrant, suitable for video platform",
        "jianshu": "literary, elegant, clean layout",
        "juejin": "tech blog style, developer friendly",
        "douyin": "vertical video style, eye-catching, short form content",
    }

    def __init__(self):
        """Read the API key from the environment and set endpoint defaults."""
        self.api_key = os.getenv("ALIYUN_DASHSCOPE_API_KEY")
        self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
        self.timeout = 120.0  # overall wait for the async task, in seconds

    async def generate_cover(
        self,
        title: str,
        platform: str,
        image_type: str = "cover",
        style: str = "modern",
        layout: str = "centered",
        custom_prompt: Optional[str] = None,
    ) -> ImageResult:
        """Generate an image for an article and wait for the result.

        Args:
            title: Article title; used as the leading part of the built prompt.
            platform: Target platform key. Unknown platforms fall back to the
                "zhihu" size specs.
            image_type: Image type ("cover" or "inline"). Types the platform
                does not define fall back to its "cover" spec.
            style: Key into ``IMAGE_STYLES``.
            layout: Key into ``LAYOUT_OPTIONS``.
            custom_prompt: When non-empty, used verbatim instead of the
                prompt built from title/style/layout.

        Returns:
            ImageResult describing the generated image.

        Raises:
            ImageGenerationError: If the task cannot be created, fails, or
                does not finish within ``self.timeout`` seconds.
        """
        # 1. Resolve the target dimensions.
        specs = PLATFORM_IMAGE_SPECS.get(platform, PLATFORM_IMAGE_SPECS["zhihu"])
        size_spec = specs.get(image_type, specs["cover"])

        # 2. Build the prompt unless the caller supplied one.
        if custom_prompt:
            prompt = custom_prompt
        else:
            prompt = self._build_prompt(title, platform, style, layout)

        # 3. Submit the asynchronous task, then 4. poll until it finishes.
        task_id = await self._create_task(prompt, size_spec)
        result = await self._wait_for_result(task_id)

        return ImageResult(
            url=result["image_url"],
            width=size_spec["width"],
            height=size_spec["height"],
            prompt=prompt,
            platform=platform,
            task_id=task_id,
        )

    def _build_prompt(self, title: str, platform: str, style: str, layout: str) -> str:
        """Assemble the generation prompt.

        Args:
            title: Article title.
            platform: Target platform key; selects an optional platform note.
            style: Key into ``IMAGE_STYLES`` (falls back to "modern").
            layout: Key into ``LAYOUT_OPTIONS`` (falls back to "centered").

        Returns:
            Comma-separated prompt string.
        """
        style_prompt = IMAGE_STYLES.get(style, IMAGE_STYLES["modern"])["prompt"]
        layout_prompt = LAYOUT_OPTIONS.get(layout, LAYOUT_OPTIONS["centered"])["prompt"]
        platform_note = self._PLATFORM_NOTES.get(platform, "")

        # Skip the platform note when there is none, so the prompt does not
        # contain an empty ", , " segment.
        parts = [title, style_prompt, layout_prompt]
        if platform_note:
            parts.append(platform_note)
        parts.append("high quality, 4K")
        return ", ".join(parts)

    def _task_status_url(self, task_id: str) -> str:
        """Return the task-query URL for ``task_id``.

        Tasks are queried under ``<api root>/tasks/<task_id>``, not under the
        service URL. The API root is derived from ``base_url`` so that an
        overridden host is respected.
        """
        api_root, found, _ = self.base_url.partition("/services/")
        tasks_url = f"{api_root}/tasks" if found else self._DEFAULT_TASKS_URL
        return f"{tasks_url}/{task_id}"

    async def _create_task(self, prompt: str, size_spec: dict) -> str:
        """Create the asynchronous generation task.

        Args:
            prompt: Prompt text to submit.
            size_spec: Spec dict with integer "width" and "height" entries.

        Returns:
            The task ID reported by the API.

        Raises:
            ImageGenerationError: If the API key is missing, the request
                fails, the response cannot be parsed, the API reports an
                error code, or no task ID is returned.
        """
        if not self.api_key:
            raise ImageGenerationError("ALIYUN_DASHSCOPE_API_KEY 未设置")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # Asynchronous task submission has to be requested explicitly.
            "X-DashScope-Async": "enable",
        }
        payload = {
            "model": "wanx-v1",  # Wanx text-to-image V1
            "input": {"prompt": prompt},
            "parameters": {
                # The API expects "<width>*<height>".
                # NOTE(review): wanx-v1 appears to accept only a fixed set of
                # resolutions (e.g. 1024*1024, 720*1280, 1280*720, 768*1152);
                # confirm whether platform sizes need mapping to one of those.
                "size": f"{size_spec['width']}*{size_spec['height']}",
                "n": 1,
            },
        }

        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            try:
                response = await client.post(self.base_url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
            except httpx.HTTPStatusError as e:
                raise ImageGenerationError(f"HTTP错误: {e.response.status_code}") from e
            except httpx.RequestError as e:
                raise ImageGenerationError(f"请求错误: {e}") from e
            except ValueError as e:  # body is not valid JSON
                raise ImageGenerationError(f"响应解析失败: {e}") from e

        if data.get("code"):
            raise ImageGenerationError(f"API错误: {data.get('message', data.get('code'))}")

        # "output" may be absent or null; treat both as empty.
        task_id = (data.get("output") or {}).get("task_id")
        if not task_id:
            raise ImageGenerationError("未获取到任务ID")
        return task_id

    async def _wait_for_result(self, task_id: str) -> dict:
        """Poll the task until it finishes.

        Args:
            task_id: ID returned by ``_create_task``.

        Returns:
            Dict with a single "image_url" entry.

        Raises:
            ImageGenerationError: If the task fails or is cancelled, the
                result contains no image URL, a request fails, or the task
                is still unfinished after ``self.timeout`` seconds.
        """
        # Local import: asyncio is not imported at the top of this file.
        import asyncio

        status_url = self._task_status_url(task_id)
        headers = {"Authorization": f"Bearer {self.api_key}"}

        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.timeout

        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
            while True:
                try:
                    response = await client.get(status_url, headers=headers)
                    response.raise_for_status()
                    data = response.json()
                except httpx.HTTPStatusError as e:
                    raise ImageGenerationError(f"HTTP错误: {e.response.status_code}") from e
                except httpx.RequestError as e:
                    raise ImageGenerationError(f"请求错误: {e}") from e
                except ValueError as e:  # body is not valid JSON
                    raise ImageGenerationError(f"响应解析失败: {e}") from e

                output = data.get("output") or {}
                # The API reports upper-case states (PENDING, RUNNING,
                # SUCCEEDED, FAILED, ...); normalise so either case matches.
                status = str(output.get("task_status") or "").upper()

                if status == "SUCCEEDED":
                    images = output.get("results") or []
                    image_url = images[0].get("url") if images else None
                    if not image_url:
                        raise ImageGenerationError("未获取到生成图片URL")
                    return {"image_url": image_url}

                if status in self._FAILED_STATUSES:
                    error_msg = output.get("message", "任务失败")
                    raise ImageGenerationError(f"图片生成失败: {error_msg}")

                # Any other state (pending, running, unrecognised): keep
                # polling until the overall deadline is reached.
                if loop.time() > deadline:
                    raise ImageGenerationError("图片生成超时")

                await asyncio.sleep(self._POLL_INTERVAL)

    @staticmethod
    def get_platform_specs(platform: str) -> Optional[dict]:
        """Return the image specs for ``platform``, or None if it is unknown."""
        return PLATFORM_IMAGE_SPECS.get(platform)

    @staticmethod
    def get_supported_platforms() -> list[str]:
        """Return the keys of all supported platforms."""
        return list(PLATFORM_IMAGE_SPECS.keys())

    @staticmethod
    def get_styles() -> dict:
        """Return the style presets (the module-level ``IMAGE_STYLES`` dict itself)."""
        return IMAGE_STYLES

    @staticmethod
    def get_layouts() -> dict:
        """Return the layout presets (the module-level ``LAYOUT_OPTIONS`` dict itself)."""
        return LAYOUT_OPTIONS