geo/backend/app/services/image_generator.py

312 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""阿里云百炼图片生成服务"""
import asyncio
import os
from dataclasses import dataclass
from typing import Optional

import httpx

# Target image sizes per platform, keyed by platform id and then by image type.
# Every platform defines a "cover" size; only some also define an "inline"
# (in-article) size. Each entry holds pixel width/height plus a human-readable,
# approximate aspect-ratio label.
PLATFORM_IMAGE_SPECS = dict(
    zhihu=dict(
        cover=dict(width=690, height=280, ratio="2.5:1"),
        inline=dict(width=500, height=375, ratio="4:3"),
    ),
    wechat=dict(
        cover=dict(width=900, height=383, ratio="2.35:1"),
        inline=dict(width=800, height=600, ratio="4:3"),
    ),
    xiaohongshu=dict(
        cover=dict(width=1080, height=1080, ratio="1:1"),  # square
        inline=dict(width=1242, height=1660, ratio="3:4"),  # portrait
    ),
    toutiao=dict(cover=dict(width=1024, height=678, ratio="1.5:1")),
    baijiahao=dict(cover=dict(width=600, height=400, ratio="3:2")),
    weibo=dict(cover=dict(width=980, height=560, ratio="1.75:1")),
    bilibili=dict(cover=dict(width=1920, height=1080, ratio="16:9")),
    jianshu=dict(cover=dict(width=800, height=600, ratio="4:3")),
    juejin=dict(cover=dict(width=1024, height=768, ratio="4:3")),
    # Portrait cover for short-video posts.
    douyin=dict(cover=dict(width=1080, height=1920, ratio="9:16")),
)
# Visual style presets: key -> display name plus the English prompt fragment
# that is spliced into the generation prompt.
IMAGE_STYLES = dict(
    modern=dict(name="现代简约", prompt="modern minimalist style, clean design, professional"),
    tech=dict(name="科技感", prompt="tech style, futuristic, digital, blue tones"),
    elegant=dict(name="优雅商务", prompt="elegant business style, sophisticated, premium"),
    creative=dict(name="创意活力", prompt="creative vibrant style, colorful, dynamic"),
    minimal=dict(name="极简主义", prompt="ultra minimal, white space, typography focus"),
)
# Composition presets: key -> display name plus the English prompt fragment
# describing where text and visuals are placed.
LAYOUT_OPTIONS = dict(
    centered=dict(name="居中排版", prompt="centered composition, text in middle"),
    left_text=dict(name="左文右图", prompt="left side text, right side visual"),
    top_text=dict(name="上文下图", prompt="text on top, visual below"),
    text_overlay=dict(name="文字叠加", prompt="text overlay on background image"),
)


@dataclass
class ImageResult:
    """Outcome of a single image-generation request.

    Fields are listed in positional-constructor order.
    """

    url: str  # location of the generated image, as returned by the API
    width: int  # width in pixels, taken from the platform size spec
    height: int  # height in pixels, taken from the platform size spec
    prompt: str  # prompt text submitted to the model
    platform: str  # platform id the image was requested for
    task_id: str  # API task id, handy for tracing the request


class ImageGenerationError(Exception):
    """Raised when an image cannot be produced.

    Covers missing credentials, HTTP/API errors, failed tasks, and timeouts.
    """


class ImageGenerator:
    """Text-to-image service on Alibaba Cloud Model Studio (DashScope).

    Uses the Tongyi Wanxiang text-to-image V1 model (``wanx-v1``). Generation is
    asynchronous on the server side: ``_create_task`` submits a job and
    ``_wait_for_result`` polls the task endpoint until the job reaches a
    terminal state or ``self.timeout`` seconds have elapsed.
    """

    # Resolutions wanx-v1 accepts for ``parameters.size``. Platform target
    # sizes are mapped to the closest aspect ratio among these; the generated
    # image may therefore need resizing/cropping to the platform size.
    _SUPPORTED_SIZES = ((1024, 1024), (720, 1280), (1280, 720))

    # Task states after which polling stops with an error.
    _FAILED_STATUSES = frozenset({"FAILED", "CANCELED", "UNKNOWN"})

    # Platform-specific prompt hints appended by _build_prompt.
    _PLATFORM_NOTES = {
        "xiaohongshu": "warm tones, lifestyle, lifestyle photography",
        "wechat": "professional, clean, suitable for WeChat article cover",
        "zhihu": "intellectual, professional, suitable for long-form content",
        "toutiao": "eye-catching, clear hierarchy, news style",
        "baijiahao": "clear, professional, suitable for news media",
        "weibo": "social media friendly, engaging",
        "bilibili": "anime style friendly, vibrant, suitable for video platform",
        "jianshu": "literary, elegant, clean layout",
        "juejin": "tech blog style, developer friendly",
        "douyin": "vertical video style, eye-catching, short form content",
    }

    def __init__(self):
        self.api_key = os.getenv("ALIYUN_DASHSCOPE_API_KEY")
        self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
        # Task status is served by a separate endpoint: GET {task_url}/{task_id}.
        self.task_url = "https://dashscope.aliyuncs.com/api/v1/tasks"
        # Overall wait budget for one task (seconds); also used as the
        # per-request HTTP timeout while polling.
        self.timeout = 120.0
        self.poll_interval = 2.0  # delay between status polls (seconds)

    async def generate_cover(
        self,
        title: str,
        platform: str,
        image_type: str = "cover",
        style: str = "modern",
        layout: str = "centered",
        custom_prompt: Optional[str] = None,
    ) -> "ImageResult":
        """Generate a cover/inline image for an article.

        Args:
            title: Article title, inserted verbatim into the prompt.
            platform: Target platform id; unknown ids fall back to "zhihu" sizes.
            image_type: "cover" or "inline"; falls back to the platform's cover size.
            style: Key of IMAGE_STYLES (unknown keys fall back to "modern").
            layout: Key of LAYOUT_OPTIONS (unknown keys fall back to "centered").
            custom_prompt: If non-empty, used instead of the generated prompt.

        Returns:
            ImageResult whose width/height are the platform target size.

        Raises:
            ImageGenerationError: on missing credentials, API/HTTP errors,
                task failure, or timeout.
        """
        specs = PLATFORM_IMAGE_SPECS.get(platform, PLATFORM_IMAGE_SPECS["zhihu"])
        size_spec = specs.get(image_type, specs["cover"])
        prompt = custom_prompt or self._build_prompt(title, platform, style, layout)

        task_id = await self._create_task(prompt, size_spec)
        result = await self._wait_for_result(task_id)

        return ImageResult(
            url=result["image_url"],
            width=size_spec["width"],
            height=size_spec["height"],
            prompt=prompt,
            platform=platform,
            task_id=task_id,
        )

    def _build_prompt(self, title: str, platform: str, style: str, layout: str) -> str:
        """Build a comma-separated prompt from title, style, layout and platform hints.

        Empty fragments (e.g. no note for an unknown platform) are skipped so the
        prompt never contains dangling ", ," separators.
        """
        style_prompt = IMAGE_STYLES.get(style, IMAGE_STYLES["modern"])["prompt"]
        layout_prompt = LAYOUT_OPTIONS.get(layout, LAYOUT_OPTIONS["centered"])["prompt"]
        platform_note = self._PLATFORM_NOTES.get(platform, "")
        parts = (title, style_prompt, layout_prompt, platform_note, "high quality", "4K")
        return ", ".join(part for part in parts if part)

    async def _create_task(self, prompt: str, size_spec: dict) -> str:
        """Submit an asynchronous generation task and return its task id.

        Raises:
            ImageGenerationError: if the API key is missing, the request fails,
                or the response carries an error or no task id.
        """
        if not self.api_key:
            raise ImageGenerationError("ALIYUN_DASHSCOPE_API_KEY 未设置")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # text2image only supports asynchronous invocation; DashScope
            # rejects the call unless async mode is requested explicitly.
            "X-DashScope-Async": "enable",
        }
        payload = {
            "model": "wanx-v1",  # Tongyi Wanxiang text-to-image V1
            "input": {"prompt": prompt},
            "parameters": {
                # DashScope expects "width*height" from a fixed set of sizes.
                "size": self._to_api_size(size_spec["width"], size_spec["height"]),
                "n": 1,
            },
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.post(self.base_url, headers=headers, json=payload)
            except httpx.RequestError as e:
                raise ImageGenerationError(f"请求错误: {e}") from e

        data = self._decode_response(response)
        if data.get("code"):
            raise ImageGenerationError(f"API错误: {data.get('message', data.get('code'))}")

        task_id = (data.get("output") or {}).get("task_id")
        if not task_id:
            raise ImageGenerationError("未获取到任务ID")
        return task_id

    async def _wait_for_result(self, task_id: str) -> dict:
        """Poll the task until it finishes and return ``{"image_url": url}``.

        Raises:
            ImageGenerationError: if the task fails, the request fails, or
                ``self.timeout`` seconds pass without a result.
        """
        status_url = f"{self.task_url}/{task_id}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.timeout

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            while True:
                try:
                    response = await client.get(status_url, headers=headers)
                except httpx.RequestError as e:
                    raise ImageGenerationError(f"请求错误: {e}") from e

                image_url = self._extract_image_url(self._decode_response(response))
                if image_url is not None:
                    return {"image_url": image_url}

                if loop.time() > deadline:
                    raise ImageGenerationError("图片生成超时")
                await asyncio.sleep(self.poll_interval)

    @staticmethod
    def _decode_response(response) -> dict:
        """Return the JSON object body of an HTTP response.

        Raises:
            ImageGenerationError: on a non-2xx status (including the API's
                code/message when present) or a body that is not a JSON object.
        """
        try:
            data = response.json()
        except ValueError:  # body is not JSON, e.g. a gateway error page
            data = None

        if not response.is_success:
            detail = ""
            if isinstance(data, dict) and (data.get("code") or data.get("message")):
                detail = f" ({data.get('code', '')}: {data.get('message', '')})"
            raise ImageGenerationError(f"HTTP错误: {response.status_code}{detail}")

        if not isinstance(data, dict):
            raise ImageGenerationError("响应格式错误: 非JSON对象")
        return data

    @classmethod
    def _extract_image_url(cls, data: dict) -> Optional[str]:
        """Interpret a task-status payload.

        Returns:
            The first image URL when the task succeeded, or None while it is
            still pending/running (caller should keep polling).

        Raises:
            ImageGenerationError: if the task ended in a failure state, or
                succeeded without any image URL.
        """
        output = data.get("output") or {}
        # DashScope reports upper-case states; normalise defensively.
        status = str(output.get("task_status", "")).upper()

        if status == "SUCCEEDED":
            for item in output.get("results") or []:
                url = item.get("url") if isinstance(item, dict) else None
                if url:
                    return url
            raise ImageGenerationError("未获取到生成图片URL")

        if status in cls._FAILED_STATUSES:
            error_msg = output.get("message", "任务失败")
            raise ImageGenerationError(f"图片生成失败: {error_msg}")

        return None  # PENDING / RUNNING (or not yet reported): keep polling

    @classmethod
    def _to_api_size(cls, width: int, height: int) -> str:
        """Map a target size to the supported wanx-v1 size with the closest aspect ratio.

        Returns the "width*height" string DashScope expects. Non-positive
        dimensions fall back to the first supported (square) size.
        """
        if width <= 0 or height <= 0:
            best_w, best_h = cls._SUPPORTED_SIZES[0]
        else:
            target = width / height
            best_w, best_h = min(
                cls._SUPPORTED_SIZES,
                key=lambda size: abs(size[0] / size[1] - target),
            )
        return f"{best_w}*{best_h}"

    @staticmethod
    def get_platform_specs(platform: str) -> Optional[dict]:
        """Return the image specs for ``platform``, or None if it is unknown."""
        return PLATFORM_IMAGE_SPECS.get(platform)

    @staticmethod
    def get_supported_platforms() -> list[str]:
        """Return the ids of all platforms with image specs."""
        return list(PLATFORM_IMAGE_SPECS.keys())

    @staticmethod
    def get_styles() -> dict:
        """Return all style presets."""
        return IMAGE_STYLES

    @staticmethod
    def get_layouts() -> dict:
        """Return all layout presets."""
        return LAYOUT_OPTIONS