# geo/backend/app/agent_framework/agents/schema_advisor.py
#
# (Repository view: 399 lines, 14 KiB, Python.)
#
# Note: this file contains ambiguous Unicode characters, i.e. characters that
# might be confused with other characters. If that is intentional, this
# warning can safely be ignored.

import copy
import json
import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
from app.utils.json_extractor import extract_json
# Module logger; the agent logs LLM failures and per-item fill errors here.
logger = logging.getLogger(__name__)

# Skeleton JSON-LD documents keyed by Schema.org type. Empty strings and lists
# are placeholders for brand-specific values. These dicts are shared module
# state, so deep-copy a template before mutating it.
SCHEMA_TEMPLATES = {
    # Brand / company identity, including a customer-service contact point.
    "Organization": {
        "@context": "https://schema.org",
        "@type": "Organization",
        "name": "",
        "description": "",
        "url": "",
        "logo": "",
        "sameAs": [],
        "contactPoint": {
            "@type": "ContactPoint",
            "contactType": "customer service",
            "telephone": "",
        },
    },
    # A sellable product with brand and an in-stock CNY offer.
    "Product": {
        "@context": "https://schema.org",
        "@type": "Product",
        "name": "",
        "description": "",
        "brand": {"@type": "Brand", "name": ""},
        "offers": {
            "@type": "Offer",
            "priceCurrency": "CNY",
            "availability": "https://schema.org/InStock",
        },
    },
    # FAQ page with a single question/answer placeholder pair.
    "FAQPage": {
        "@context": "https://schema.org",
        "@type": "FAQPage",
        "mainEntity": [
            {
                "@type": "Question",
                "name": "",
                "acceptedAnswer": {
                    "@type": "Answer",
                    "text": "",
                },
            }
        ],
    },
    # Article authored by an organization.
    "Article": {
        "@context": "https://schema.org",
        "@type": "Article",
        "headline": "",
        "description": "",
        "author": {"@type": "Organization", "name": ""},
        "datePublished": "",
        "image": "",
    },
    # Physical business location with postal address and coordinates.
    "LocalBusiness": {
        "@context": "https://schema.org",
        "@type": "LocalBusiness",
        "name": "",
        "address": {
            "@type": "PostalAddress",
            "streetAddress": "",
            "addressLocality": "",
            "addressRegion": "",
            "postalCode": "",
            "addressCountry": "CN",
        },
        "geo": {
            "@type": "GeoCoordinates",
            "latitude": "",
            "longitude": "",
        },
        "telephone": "",
        "openingHours": "",
    },
}
# Schema.org template types that address each diagnosis dimension.
# Dimensions not listed here are ignored by the advisor.
DIMENSION_SCHEMA_MAP = {
    "schema_marketing": ["Organization", "LocalBusiness"],
    "entity_clarity": ["Organization", "Product"],
    "citation_readiness": ["FAQPage", "Article"],
    "brand_visibility": ["Organization", "Product"],
    "local_seo": ["LocalBusiness"],
}

# Exclusive upper bounds (score percentage) for suggestion priority:
# below "high" -> high, below "medium" -> medium, otherwise low.
PRIORITY_THRESHOLD = {
    "high": 30.0,
    "medium": 60.0,
}

# Implementation effort per schema type, exposed as
# ``implementation_difficulty`` on each suggestion.
DIFFICULTY_MAP = {
    "Organization": "easy",
    "Product": "medium",
    "FAQPage": "medium",
    "Article": "easy",
    "LocalBusiness": "hard",
}
class SchemaAdvisorAgent(BaseAgent):
    """Agent that recommends Schema.org JSON-LD markup for a brand.

    Pipeline (driven by ``_advise``):

    1. Pick schema-related diagnosis dimensions scoring below 80%.
    2. Map them to ``SCHEMA_TEMPLATES``.
    3. Have the default LLM provider fill each template with brand data.
    4. Validate the filled JSON-LD.
    5. Return the suggestions sorted by priority.
    """

    # Impact blurb attached to each suggestion, keyed by schema type.
    _IMPACT_DESCRIPTIONS = {
        "Organization": "增强品牌实体识别提升AI搜索引擎对品牌的理解和引用概率",
        "Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率",
        "FAQPage": "增加FAQ富摘要展示机会提升在AI回答中的直接引用概率",
        "Article": "优化文章内容的结构化表达提升AI搜索引擎的内容理解和引用",
        "LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率",
    }

    # Sort rank per priority label; unknown labels rank like "medium".
    _PRIORITY_ORDER = {"high": 0, "medium": 1, "low": 2}

    def __init__(self):
        super().__init__(
            name="schema_advisor",
            agent_type=AgentType.SCHEMA_ADVISOR,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: identity, supported task types and concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["schema_advise"],
            max_concurrency=2,
            description="Schema优化建议Agent识别Schema缺失维度生成JSON-LD结构化数据建议",
        )

    def _build_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        *,
        status: TaskStatus,
        output_data: dict | None = None,
        error_message: str | None = None,
    ) -> TaskResult:
        """Build a TaskResult carrying the timing metrics shared by every outcome.

        ``start_time`` is a ``time.monotonic()`` reading used for the elapsed
        metric. ``started_at`` is the wall-clock start stored in the result.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run the schema-advice pipeline for ``task``.

        Returns a COMPLETED result holding the ``_advise`` output on success.
        ``LLMError`` and any other ``Exception`` are logged and converted into
        a FAILED result instead of propagating.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        try:
            output = await self._advise(task)
            # Kept inside the try so a failure while assembling the result is
            # also reported as FAILED.
            return self._build_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.COMPLETED,
                output_data=output,
            )
        except LLMError as e:
            logger.error("SchemaAdvisor LLM error on task %s: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=f"LLM调用失败: {e}",
            )
        except Exception as e:
            # Unexpected failure (bad input or a bug): keep the traceback.
            logger.exception("SchemaAdvisor task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=str(e),
            )

    async def _advise(self, task: TaskMessage) -> dict:
        """Produce schema suggestions for the brand described by ``task.input_data``.

        Input keys: ``brand_id`` (required), plus optional ``diagnosis_data``,
        ``brand_info`` and ``focus_dimensions``. Progress is reported after
        each pipeline stage.

        Returns:
            A dict with keys ``brand_id``, ``suggestions`` and ``total``
            (the number of suggestions).

        Raises:
            ValueError: if ``brand_id`` is missing or empty.
        """
        input_data = task.input_data or {}
        brand_id = input_data.get("brand_id")
        if not brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")
        # ``or {}`` also covers keys that are present but explicitly None.
        diagnosis_data = input_data.get("diagnosis_data") or {}
        brand_info = input_data.get("brand_info") or {}
        focus_dimensions = input_data.get("focus_dimensions")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始Schema建议分析...",
        )
        missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions)
        await self.report_progress(
            task_id=task.task_id,
            progress=0.3,
            message=f"识别到{len(missing_dimensions)}个Schema缺失维度...",
        )
        matched = self._match_templates(missing_dimensions)
        await self.report_progress(
            task_id=task.task_id,
            progress=0.5,
            message="匹配预定义模板完成开始LLM填充...",
        )
        filled = await self._fill_with_llm(matched, brand_info)
        await self.report_progress(
            task_id=task.task_id,
            progress=0.8,
            message="LLM填充完成验证JSON-LD格式...",
        )
        validated = self._validate_and_sort(filled)
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message="Schema建议生成完成",
        )
        return {
            "brand_id": brand_id,
            "suggestions": validated,
            "total": len(validated),
        }

    def _identify_missing_dimensions(
        self,
        diagnosis_data: dict,
        focus_dimensions: list[str] | None = None,
    ) -> list[dict]:
        """Return schema-related dimensions scoring below 80% of their maximum.

        Only dimensions listed in ``DIMENSION_SCHEMA_MAP`` are considered, and
        only those in ``focus_dimensions`` when that list is non-empty. A value
        under ``diagnosis_data["dimensions"]`` may be a dict with ``score`` /
        ``max_score`` (default 100) or a bare score out of 100. A missing or
        None score counts as 0.

        If no considered dimension carries a score at all, but the diagnosis
        is non-empty and its ``overall_score`` is below 80, every considered
        dimension is returned with a placeholder score of 0.

        Returns:
            Dicts with ``dimension``, ``current_score``, ``max_score`` and
            ``percentage``.
        """
        diagnosis_data = diagnosis_data or {}
        dimension_scores = diagnosis_data.get("dimensions") or {}
        dimensions = []
        scored_any = False
        for dim_name, dim_info in dimension_scores.items():
            if dim_name not in DIMENSION_SCHEMA_MAP:
                continue
            if focus_dimensions and dim_name not in focus_dimensions:
                continue
            if isinstance(dim_info, dict):
                score = dim_info.get("score") or 0
                max_score = dim_info.get("max_score")
                if max_score is None:
                    max_score = 100
            else:
                score = dim_info or 0
                max_score = 100
            scored_any = True
            percentage = (score / max_score * 100) if max_score > 0 else 0
            if percentage < 80:
                dimensions.append({
                    "dimension": dim_name,
                    "current_score": round(score, 2),
                    "max_score": max_score,
                    "percentage": round(percentage, 2),
                })
        # Only fall back to the overall score when no per-dimension data was
        # available. Otherwise, dimensions that genuinely scored >= 80% would
        # be reported with a fabricated score of 0.
        if not scored_any and diagnosis_data:
            overall = diagnosis_data.get("overall_score") or 0
            if overall < 80:
                for dim_name in DIMENSION_SCHEMA_MAP:
                    if focus_dimensions and dim_name not in focus_dimensions:
                        continue
                    dimensions.append({
                        "dimension": dim_name,
                        "current_score": 0,
                        "max_score": 100,
                        "percentage": 0,
                    })
        return dimensions

    @staticmethod
    def _priority_for(percentage: float) -> str:
        """Map a score percentage to "high" / "medium" / "low" via PRIORITY_THRESHOLD."""
        if percentage < PRIORITY_THRESHOLD["high"]:
            return "high"
        if percentage < PRIORITY_THRESHOLD["medium"]:
            return "medium"
        return "low"

    def _match_templates(self, missing_dimensions: list[dict]) -> list[dict]:
        """Turn missing dimensions into one suggestion per schema type.

        Several dimensions can map to the same schema type. Each type is
        emitted once and attributed to its lowest-scoring dimension, which
        also determines its priority. Templates are deep-copied so the shared
        ``SCHEMA_TEMPLATES`` entries are never mutated.
        """
        matched = []
        seen_types = set()
        # Weakest dimension first; the sort is stable, so ties keep input order.
        for dim in sorted(missing_dimensions, key=lambda d: d["percentage"]):
            for schema_type in DIMENSION_SCHEMA_MAP.get(dim["dimension"], []):
                if schema_type in seen_types:
                    continue
                seen_types.add(schema_type)
                template = SCHEMA_TEMPLATES.get(schema_type)
                if not template:
                    continue
                matched.append({
                    "schema_type": schema_type,
                    "priority": self._priority_for(dim["percentage"]),
                    "diagnosis_dimensions": {
                        "dimension": dim["dimension"],
                        "current_score": dim["current_score"],
                        "max_score": dim["max_score"],
                        "percentage": dim["percentage"],
                    },
                    "json_ld_template": copy.deepcopy(template),
                    "implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
                })
        return matched

    async def _fill_one(
        self,
        provider,
        schema_type: str,
        diagnosis_dims: dict,
        brand_info: dict,
    ) -> dict | None:
        """Ask the LLM for a filled JSON-LD object for ``schema_type``.

        Returns the parsed JSON object, or None when the LLM call fails or the
        reply is not a JSON object. Such failures are logged as warnings so one
        bad item does not fail the whole task.
        """
        try:
            variables = {
                "brand_name": brand_info.get("name", ""),
                "brand_website": brand_info.get("website", ""),
                "brand_industry": brand_info.get("industry", ""),
                "schema_type": schema_type,
                "diagnosis_data": json.dumps(diagnosis_dims, ensure_ascii=False),
                "existing_schemas": "",
            }
            messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
            response = await provider.chat(
                messages,
                temperature=0.3,
                max_tokens=2048,
            )
            filled = json.loads(extract_json(response.content))
            # JSON-LD must be an object; a list/string/number would break
            # validation downstream.
            if not isinstance(filled, dict):
                raise ValueError(f"LLM返回的JSON-LD不是对象: {type(filled).__name__}")
        except (LLMError, ValueError) as e:  # json.JSONDecodeError is a ValueError
            logger.warning("LLM填充Schema %s 失败: %s", schema_type, e)
            return None
        return filled

    async def _fill_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]:
        """Fill each matched template via the default LLM provider.

        Each item gains ``json_ld_filled`` (a dict, or None on failure) and
        ``estimated_impact``. Items are updated in place, and the same
        objects are returned in order. No provider is requested when there is
        nothing to fill.
        """
        if not matched:
            return []
        brand_info = brand_info or {}
        provider = LLMFactory.get_default()
        results = []
        for item in matched:
            schema_type = item["schema_type"]
            diagnosis_dims = item.get("diagnosis_dimensions", {})
            item["json_ld_filled"] = await self._fill_one(
                provider, schema_type, diagnosis_dims, brand_info
            )
            item["estimated_impact"] = self._generate_impact_description(
                schema_type, diagnosis_dims.get("dimension", "")
            )
            results.append(item)
        return results

    def _validate_json_ld(self, json_ld: dict) -> dict:
        """Check a JSON-LD object for the basics.

        The object must be non-empty, a dict, have ``@context`` and ``@type``,
        and be JSON-serializable. These are errors. A non-standard
        ``@context`` or an ``@type`` outside ``SCHEMA_TEMPLATES`` only produces
        a warning. ``@type`` may be a string or, as JSON-LD allows, a list of
        types.

        Returns:
            ``{"is_valid": bool, "errors": [...], "warnings": [...]}``.
        """
        if not json_ld:
            return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []}
        if not isinstance(json_ld, dict):
            return {
                "is_valid": False,
                "errors": [f"JSON-LD必须是对象: {type(json_ld).__name__}"],
                "warnings": [],
            }
        errors = []
        warnings = []
        if "@context" not in json_ld:
            errors.append("缺少@context字段")
        if "@type" not in json_ld:
            errors.append("缺少@type字段")
        if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
            warnings.append(f"@context值非标准: {json_ld.get('@context')}")
        if "@type" in json_ld:
            declared = json_ld["@type"]
            types = declared if isinstance(declared, list) else [declared]
            # Only hashable strings can be looked up; lists/dicts used to raise
            # "unhashable type" here.
            if not any(isinstance(t, str) and t in SCHEMA_TEMPLATES for t in types):
                warnings.append(f"@type非推荐类型: {json_ld.get('@type')}")
        try:
            json.dumps(json_ld)
        except (TypeError, ValueError) as e:  # unserializable value / circular ref
            errors.append(f"JSON序列化失败: {e}")
        return {
            "is_valid": len(errors) == 0,
            "errors": errors,
            "warnings": warnings,
        }

    def _validate_and_sort(self, items: list[dict]) -> list[dict]:
        """Attach ``validation_errors`` to each item and return them sorted by priority.

        ``validation_errors`` is None for valid JSON-LD. Otherwise it is a dict
        with ``errors`` and ``warnings`` lists. Warnings alone do not make an
        item invalid. The items are updated in place, and a new list is
        returned. The sort is stable, so equal priorities keep input order.
        """
        for item in items:
            json_ld_filled = item.get("json_ld_filled")
            if not json_ld_filled:
                item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []}
                continue
            validation = self._validate_json_ld(json_ld_filled)
            item["validation_errors"] = (
                None
                if validation["is_valid"]
                else {"errors": validation["errors"], "warnings": validation["warnings"]}
            )
        return sorted(
            items,
            key=lambda x: self._PRIORITY_ORDER.get(x.get("priority", "medium"), 1),
        )

    def _generate_impact_description(self, schema_type: str, dimension: str) -> str:
        """Return the impact blurb for ``schema_type``, or a generic one naming ``dimension``."""
        return self._IMPACT_DESCRIPTIONS.get(schema_type, f"提升{dimension}维度的得分和AI引用率")