geo/backend/app/agent_framework/agents/deai_agent.py

222 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""去AI化Agent - 消除AI生成痕迹"""
import logging
import time
from datetime import datetime, timezone
from typing import Optional
from app.agent_framework.base import BaseAgent
from app.agent_framework.prompts import DEAI_TEMPLATE
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
from app.services.distribution.platform_rules import PLATFORM_RULES, rule_engine
from app.services.distribution.rule_service import platform_rule_service
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class DeAIAgent(BaseAgent):
    """Rewrite content to remove traces of AI-generated text ("de-AI" processing).

    Supported task types:
        - deai_process: run de-AI rewriting on a piece of content.

    ``input_data`` fields:
        - content: str (required) -- the article text to process.
        - platform: str (optional) -- target platform ID, e.g. ``zhihu``, ``wechat``.
        - style: str (optional) -- target writing style.
        - preserve_structure: bool (optional, default True) -- whether to keep
          the original structure.
    """

    def __init__(self):
        super().__init__(
            name="deai_agent",
            agent_type=AgentType.DEAI_AGENT,
            version="1.1.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Return this agent's name, type, version, supported tasks and concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["deai_process"],
            max_concurrency=2,
            description="内容去AI化Agent消除AI生成特征使文章更自然流畅",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run a de-AI task and wrap the outcome in a ``TaskResult``.

        Args:
            task: The incoming task; its ``input_data`` is forwarded to ``_process``.

        Returns:
            A COMPLETED result carrying the ``_process`` output. If an ``LLMError``
            or any other ``Exception`` occurs, it is logged and a FAILED result
            with an error message is returned instead.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        try:
            output = await self._process(task)
            # Built inside the try so a failure while assembling the result is
            # still reported as a FAILED task rather than propagating.
            return self._make_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.COMPLETED,
                output_data=output,
            )
        except LLMError as e:
            logger.error("DeAIAgent LLM error on task %s: %s", task.task_id, e)
            return self._make_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=f"LLM调用失败: {e}",
            )
        except Exception as e:
            # logger.exception keeps the traceback so unexpected failures can be debugged.
            logger.exception("DeAIAgent task %s failed: %s", task.task_id, e)
            return self._make_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=str(e),
            )

    def _make_task_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        *,
        status: TaskStatus,
        output_data: Optional[dict] = None,
        error_message: Optional[str] = None,
    ) -> TaskResult:
        """Build a ``TaskResult`` for ``task`` with the shared timing metrics.

        Args:
            task: The task being reported on.
            started_at: Wall-clock (UTC) time at which execution started.
            start_time: ``time.monotonic()`` reading taken at start, used for elapsed time.
            status: Final task status.
            output_data: Output payload on success, otherwise ``None``.
            error_message: Error description on failure, otherwise ``None``.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def _process(self, task: TaskMessage) -> dict:
        """Render the de-AI prompt, call the default LLM, and collect results.

        Reads ``content``, ``platform``, ``style`` and ``preserve_structure``
        from ``task.input_data`` and reports progress at 0.1, 0.3 and 1.0.

        Returns:
            A dict with the rewritten ``content``, the original and processed
            lengths, the LLM ``usage``, the ``platform_id``, and any AI patterns
            detected in the output (only when a platform was given).

        Raises:
            ValueError: If ``content`` is missing or empty.
            LLMError: Propagated from the LLM provider.
        """
        # A missing payload is treated like an empty one, so callers get the
        # explicit "content required" error instead of an AttributeError.
        input_data = task.input_data or {}
        content = input_data.get("content", "")
        if not content:
            raise ValueError("input_data必须包含非空的'content'字段")
        platform_id = input_data.get("platform", "")

        # Progress: started.
        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始去AI化处理...",
        )

        platform_config = self._get_platform_config(platform_id)

        variables = {
            "original_content": content,
            "target_style": input_data.get("style", "自然流畅"),
            # Previously both branches yielded "", so the flag never reached the
            # prompt; emit distinct yes/no values instead.
            "preserve_structure": "是" if input_data.get("preserve_structure", True) else "否",
            "platform_info": platform_config.get("platform_info", "通用"),
            "ai_sensitivity": platform_config.get("ai_sensitivity", ""),
            "banned_patterns": platform_config.get("banned_patterns", ""),
            "safe_patterns": platform_config.get("safe_patterns", ""),
        }
        messages = DEAI_TEMPLATE.render(variables)

        # Progress: calling the LLM.
        await self.report_progress(
            task_id=task.task_id,
            progress=0.3,
            message="正在调用LLM进行去AI化改写...",
        )

        provider = LLMFactory.get_default()
        response = await provider.chat(
            messages,
            temperature=0.9,
            # Output budget scales with input length (3 tokens per input char).
            # NOTE(review): no floor/ceiling -- confirm against provider limits.
            max_tokens=len(content) * 3,
        )

        # Check the rewritten text for remaining AI patterns (platform-specific).
        detected_patterns = []
        if platform_id:
            detected_patterns = platform_rule_service.detect_ai_patterns(
                response.content, platform_id
            )

        # Progress: done.
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=f"去AI化处理完成原文{len(content)}字 -> 处理后{len(response.content)}",
        )

        return {
            "content": response.content,
            "original_word_count": len(content),
            "processed_word_count": len(response.content),
            "usage": response.usage,
            "platform_id": platform_id,
            "detected_ai_patterns": detected_patterns,
        }

    def _get_platform_config(self, platform_id: str) -> dict:
        """Build the platform-specific prompt variables.

        Args:
            platform_id: Platform identifier; empty or unknown IDs fall back to
                a generic configuration.

        Returns:
            A dict with ``platform_info``, ``ai_sensitivity``, ``banned_patterns``
            (up to 10 entries) and ``safe_patterns`` (up to 5 entries), the
            pattern lists rendered as "、"-separated strings.
        """
        if not platform_id or platform_id not in PLATFORM_RULES:
            return {
                "platform_info": "通用",
                "ai_sensitivity": "",
                "banned_patterns": "总之、综上所述、首先其次最后等",
                "safe_patterns": "根据研究表明、事实上、说实话",
            }
        rules = PLATFORM_RULES[platform_id]
        # `or {}` also tolerates an explicit None value in the rule table.
        ai_config = rules.get("ai_sensitivity") or {}
        platform_name = rules.get("name", platform_id)
        detection_level = ai_config.get("detection_level", "medium")
        banned = ai_config.get("banned_patterns", [])
        safe = ai_config.get("safe_patterns", [])
        return {
            "platform_info": f"{platform_name} (检测级别: {detection_level})",
            "ai_sensitivity": detection_level,
            # Use the same "、" separator as the generic fallback; joining with ""
            # fused the patterns into one unreadable string in the prompt.
            "banned_patterns": "、".join(banned[:10]) if banned else "",
            "safe_patterns": "、".join(safe[:5]) if safe else "",
        }
# Module-level singleton instance exported for use by other modules.
deai_agent: DeAIAgent = DeAIAgent()