182 lines
5.5 KiB
Python
182 lines
5.5 KiB
Python
"""Agent 通信协议定义 - 统一消息格式"""
|
||
|
||
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
|
||
|
||
|
||
class AgentType(str, Enum):
    """Enumeration of agent types.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    CITATION_DETECTOR = "citation_detector"
    CONTENT_GENERATOR = "content_generator"
    DEAI_AGENT = "deai_agent"
    GEO_OPTIMIZER = "geo_optimizer"
    RULE_CHECKER = "rule_checker"
    COMPETITOR_ANALYZER = "competitor_analyzer"
    PERFORMANCE_TRACKER = "performance_tracker"
|
||
|
||
|
||
class TaskStatus(str, Enum):
    """Enumeration of task states.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||
|
||
|
||
class AgentStatus(str, Enum):
    """Enumeration of agent availability states.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    ONLINE = "online"
    OFFLINE = "offline"
    BUSY = "busy"
|
||
|
||
|
||
@dataclass
class AgentCapability:
    """Capability declaration published by an agent."""

    agent_name: str
    agent_type: str  # expected to carry an AgentType value; not validated here
    version: str
    supported_tasks: list[str]
    max_concurrency: int
    description: str

    def to_dict(self) -> dict:
        """Return the capability as a plain dict keyed by field name.

        Values are passed through unchanged, so ``supported_tasks`` in the
        result is the same list object held by this instance.
        """
        return dict(
            agent_name=self.agent_name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=self.supported_tasks,
            max_concurrency=self.max_concurrency,
            description=self.description,
        )

    @classmethod
    def from_dict(cls, data: dict) -> "AgentCapability":
        """Build an ``AgentCapability`` from a dict produced by ``to_dict``.

        Every field is required.

        Raises:
            KeyError: if any field name is missing from ``data``.
        """
        required = (
            "agent_name",
            "agent_type",
            "version",
            "supported_tasks",
            "max_concurrency",
            "description",
        )
        # Keys are looked up in declaration order, so the first missing
        # field is the one reported by the KeyError.
        return cls(**{name: data[name] for name in required})
|
||
|
||
|
||
@dataclass
|
||
class TaskMessage:
|
||
"""任务消息 - 从调度器发往 Agent"""
|
||
task_id: str # UUID
|
||
agent_name: str
|
||
task_type: str
|
||
priority: int
|
||
input_data: dict
|
||
callback_url: str | None
|
||
created_at: datetime
|
||
timeout_seconds: int = 300
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"task_id": self.task_id,
|
||
"agent_name": self.agent_name,
|
||
"task_type": self.task_type,
|
||
"priority": self.priority,
|
||
"input_data": self.input_data,
|
||
"callback_url": self.callback_url,
|
||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||
"timeout_seconds": self.timeout_seconds,
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: dict) -> "TaskMessage":
|
||
created_at = data.get("created_at")
|
||
if isinstance(created_at, str):
|
||
created_at = datetime.fromisoformat(created_at)
|
||
return cls(
|
||
task_id=data["task_id"],
|
||
agent_name=data["agent_name"],
|
||
task_type=data["task_type"],
|
||
priority=data.get("priority", 0),
|
||
input_data=data.get("input_data", {}),
|
||
callback_url=data.get("callback_url"),
|
||
created_at=created_at or datetime.utcnow(),
|
||
timeout_seconds=data.get("timeout_seconds", 300),
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class TaskResult:
|
||
"""任务结果 - 从 Agent 返回"""
|
||
task_id: str
|
||
agent_name: str
|
||
status: str # TaskStatus value: completed/failed/cancelled
|
||
output_data: dict | None
|
||
error_message: str | None
|
||
started_at: datetime
|
||
completed_at: datetime
|
||
metrics: dict | None # 执行指标(耗时、token消耗等)
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"task_id": self.task_id,
|
||
"agent_name": self.agent_name,
|
||
"status": self.status,
|
||
"output_data": self.output_data,
|
||
"error_message": self.error_message,
|
||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||
"metrics": self.metrics,
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: dict) -> "TaskResult":
|
||
started_at = data.get("started_at")
|
||
if isinstance(started_at, str):
|
||
started_at = datetime.fromisoformat(started_at)
|
||
completed_at = data.get("completed_at")
|
||
if isinstance(completed_at, str):
|
||
completed_at = datetime.fromisoformat(completed_at)
|
||
return cls(
|
||
task_id=data["task_id"],
|
||
agent_name=data["agent_name"],
|
||
status=data["status"],
|
||
output_data=data.get("output_data"),
|
||
error_message=data.get("error_message"),
|
||
started_at=started_at or datetime.utcnow(),
|
||
completed_at=completed_at or datetime.utcnow(),
|
||
metrics=data.get("metrics"),
|
||
)
|
||
|
||
|
||
@dataclass
class TaskProgress:
    """Progress report emitted by an agent while a task is executing.

    ``to_dict`` / ``from_dict`` convert to and from a plain dict in which
    ``updated_at`` is carried as an ISO 8601 string.
    """

    task_id: str
    agent_name: str
    progress: float  # 0.0 - 1.0 (documented range; not enforced here)
    message: str
    updated_at: datetime

    def to_dict(self) -> dict:
        """Return the progress report as a plain dict.

        ``updated_at`` is rendered with ``isoformat()``, or ``None`` when
        unset.
        """
        return {
            "task_id": self.task_id,
            "agent_name": self.agent_name,
            "progress": self.progress,
            "message": self.message,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "TaskProgress":
        """Build a ``TaskProgress`` from a dict.

        ``task_id`` and ``agent_name`` are required. ``progress`` defaults
        to 0.0, ``message`` to ``""``, and ``updated_at`` to the current
        UTC time as a naive ``datetime``. A string ``updated_at`` is parsed
        with ``datetime.fromisoformat``.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``updated_at`` is a string that is not valid
                ISO 8601.
        """
        updated_at = data.get("updated_at")
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        if not updated_at:
            # datetime.utcnow() is deprecated since Python 3.12. This gives
            # the same naive-UTC value, so the serialized format is
            # unchanged.
            updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        return cls(
            task_id=data["task_id"],
            agent_name=data["agent_name"],
            progress=data.get("progress", 0.0),
            message=data.get("message", ""),
            updated_at=updated_at,
        )
|