geo/tests/test_prompt_template.py

265 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Prompt 模板单元测试"""
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def simple_template():
    """Return a PromptTemplate with every section filled in.

    The context and instructions sections reference the ``${brand_name}``
    and ``${topic}`` placeholders that the injection tests substitute.
    """
    from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

    section_text = {
        "identity": "你是一个AI助手",
        "context": "品牌:${brand_name}",
        "instructions": "请为 ${topic} 生成内容",
        "constraints": "字数不超过500字",
        "output_format": "输出 JSON 格式",
        "examples": "示例:{ 'title': '...' }",
    }
    return PromptTemplate(sections=PromptSection(**section_text))
# ---------------------------------------------------------------------------
# PromptTemplate 基本测试
# ---------------------------------------------------------------------------
class TestPromptTemplate:
    """Behavioural tests for ``PromptTemplate.render``."""

    def test_render_returns_messages(self, simple_template):
        """render() returns a list of chat messages, each a dict with role and content."""
        messages = simple_template.render()
        assert isinstance(messages, list)
        assert len(messages) >= 1
        for msg in messages:
            assert "role" in msg
            assert "content" in msg
            assert msg["role"] in ("system", "user")

    def test_system_user_message_split(self, simple_template):
        """identity + context go to the system message; instructions go to the user message."""
        messages = simple_template.render(variables={"brand_name": "TestBrand", "topic": "AI"})
        roles = [m["role"] for m in messages]
        assert "system" in roles
        assert "user" in roles
        system_msg = next(m for m in messages if m["role"] == "system")
        user_msg = next(m for m in messages if m["role"] == "user")
        # identity and context are rendered into the system message
        assert "你是一个AI助手" in system_msg["content"]
        assert "TestBrand" in system_msg["content"]
        # instructions are rendered into the user message
        assert "生成内容" in user_msg["content"]

    def test_variable_injection_simple(self, simple_template):
        """${var} placeholders are replaced by the supplied values."""
        # The topic value must not occur anywhere else in the template.
        # "AI" already appears in the identity text ("你是一个AI助手"), so an
        # assertion on it would pass even if ${topic} were never substituted.
        messages = simple_template.render(variables={"brand_name": "MyBrand", "topic": "量子计算"})
        all_content = " ".join(m["content"] for m in messages)
        assert "MyBrand" in all_content
        assert "量子计算" in all_content
        # the raw placeholders must be gone
        assert "${brand_name}" not in all_content
        assert "${topic}" not in all_content

    def test_variable_injection_nested(self):
        """Dotted placeholders such as ${a.b} resolve through nested dicts."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(
            instructions="平台:${platform.name},目标:${goal.type}",
        )
        template = PromptTemplate(sections=sections)
        messages = template.render(variables={
            "platform": {"name": "微信公众号"},
            "goal": {"type": "品牌曝光"},
        })
        user_msg = next(m for m in messages if m["role"] == "user")
        assert "微信公众号" in user_msg["content"]
        assert "品牌曝光" in user_msg["content"]

    def test_unresolved_variable_kept(self):
        """A placeholder with no matching variable is left verbatim as ${var}."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(
            instructions="主题是 ${undefined_var}",
        )
        template = PromptTemplate(sections=sections)
        messages = template.render(variables={})
        user_msg = next(m for m in messages if m["role"] == "user")
        assert "${undefined_var}" in user_msg["content"]

    def test_truncation_within_budget(self):
        """Context that fits inside the budget is rendered in full, without a truncation marker."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        short_context = "简短的上下文内容"
        sections = PromptSection(
            identity="身份",
            context=short_context,
        )
        template = PromptTemplate(sections=sections)
        messages = template.render(context_budget=3000)
        system_msg = next(m for m in messages if m["role"] == "system")
        assert short_context in system_msg["content"]
        assert "中间内容已省略" not in system_msg["content"]

    def test_truncation_exceeds_budget(self):
        """Context far larger than the budget is truncated and marked as such."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        # Several thousand characters of context, far beyond the 50-token
        # budget passed to render() below.
        long_context = "这是一段非常长的上下文内容,包含大量文字。" * 200
        sections = PromptSection(
            identity="我是AI助手",
            context=long_context,
        )
        template = PromptTemplate(sections=sections)
        # a deliberately tiny budget forces truncation
        messages = template.render(context_budget=50)
        system_msg = next(m for m in messages if m["role"] == "system")
        # the truncation marker must appear ...
        assert "中间内容已省略" in system_msg["content"]
        # ... and the full context must not survive verbatim
        assert long_context not in system_msg["content"]

    def test_render_with_no_variables(self, simple_template):
        """render() works without variables; unresolved placeholders are kept verbatim."""
        messages = simple_template.render()
        assert isinstance(messages, list)
        assert len(messages) > 0
        all_content = " ".join(m["content"] for m in messages)
        assert "${brand_name}" in all_content
        assert "${topic}" in all_content

    def test_render_empty_sections(self):
        """Empty sections never produce a message with blank content."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(instructions="做些事情")
        template = PromptTemplate(sections=sections)
        messages = template.render()
        roles = [m["role"] for m in messages]
        assert "user" in roles
        # identity/context/constraints are all empty here, so ideally only a
        # user message is produced; at minimum, any system message that is
        # emitted must not be blank.
        system_msgs = [m for m in messages if m["role"] == "system"]
        for sm in system_msgs:
            assert sm["content"].strip() != ""
# ---------------------------------------------------------------------------
# 5个 Template 全部能正常 render
# ---------------------------------------------------------------------------
class TestAllTemplatesRender:
    """Smoke tests: each of the five shipped templates renders to well-formed chat messages."""

    @staticmethod
    def _assert_valid_messages(messages):
        """Assert ``messages`` is a non-empty list whose items carry role and content keys.

        Shared by every test below so that all five templates are held to the
        same standard (previously only the topic-selector test checked keys).
        """
        assert isinstance(messages, list)
        assert len(messages) > 0
        for m in messages:
            assert "role" in m
            assert "content" in m

    def test_topic_selector_template_renders(self):
        """TOPIC_SELECTOR_TEMPLATE renders."""
        from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE

        messages = TOPIC_SELECTOR_TEMPLATE.render(variables={
            "target_keyword": "AI营销",
            "brand_name": "示例品牌",
            "target_platform": "微信公众号",
        })
        self._assert_valid_messages(messages)

    def test_content_generator_template_renders(self):
        """CONTENT_GENERATOR_TEMPLATE renders."""
        from app.agent_framework.prompts import CONTENT_GENERATOR_TEMPLATE

        messages = CONTENT_GENERATOR_TEMPLATE.render(variables={
            "topic": "AI发展趋势",
            "platform": "知乎",
        })
        self._assert_valid_messages(messages)

    def test_deai_template_renders(self):
        """DEAI_TEMPLATE renders."""
        from app.agent_framework.prompts import DEAI_TEMPLATE

        messages = DEAI_TEMPLATE.render(variables={
            "content": "测试AI生成内容",
        })
        self._assert_valid_messages(messages)

    def test_geo_optimizer_template_renders(self):
        """GEO_OPTIMIZER_TEMPLATE renders."""
        from app.agent_framework.prompts import GEO_OPTIMIZER_TEMPLATE

        messages = GEO_OPTIMIZER_TEMPLATE.render(variables={
            "content": "待优化的内容",
            "platform": "小红书",
        })
        self._assert_valid_messages(messages)

    def test_rule_checker_template_renders(self):
        """RULE_CHECKER_TEMPLATE renders."""
        from app.agent_framework.prompts import RULE_CHECKER_TEMPLATE

        messages = RULE_CHECKER_TEMPLATE.render(variables={
            "content": "待检查的内容",
            "platform": "微信公众号",
        })
        self._assert_valid_messages(messages)
# ---------------------------------------------------------------------------
# PromptTemplate._inject 直接测试
# ---------------------------------------------------------------------------
class TestPromptTemplateInject:
    """Direct tests of the private helpers ``_inject`` and ``_estimate_tokens``."""

    @pytest.fixture
    def template(self):
        """An empty PromptTemplate, used only to reach its private helpers."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        empty_sections = PromptSection()
        return PromptTemplate(sections=empty_sections)

    def test_inject_simple_var(self, template):
        """A single ${name} placeholder is substituted."""
        assert template._inject("Hello ${name}", {"name": "World"}) == "Hello World"

    def test_inject_multiple_vars(self, template):
        """Every placeholder in the string is substituted."""
        assert template._inject("${a} and ${b}", {"a": "foo", "b": "bar"}) == "foo and bar"

    def test_inject_nested_path(self, template):
        """A dotted placeholder walks into nested dicts."""
        nested_vars = {"x": {"y": "deep_value"}}
        assert template._inject("${x.y}", nested_vars) == "deep_value"

    def test_inject_missing_var_kept(self, template):
        """A placeholder with no matching variable is returned unchanged."""
        assert template._inject("${missing}", {}) == "${missing}"

    def test_inject_empty_text(self, template):
        """Empty text comes back empty regardless of the variables supplied."""
        assert template._inject("", {"x": "val"}) == ""

    def test_estimate_tokens_chinese(self, template):
        """Each CJK character is estimated as roughly one token."""
        four_cjk_chars = "你好世界"
        assert template._estimate_tokens(four_cjk_chars) == 4

    def test_estimate_tokens_empty(self, template):
        """Empty text is estimated at zero tokens."""
        assert template._estimate_tokens("") == 0