"""Unit tests for prompt templates (PromptTemplate and the bundled templates)."""

import pytest


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def simple_template():
    """Build a PromptTemplate with every section populated.

    ``context`` and ``instructions`` carry the ``${brand_name}`` and
    ``${topic}`` placeholders so tests can exercise variable injection.
    """
    from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

    section_text = {
        "identity": "你是一个AI助手",
        "context": "品牌:${brand_name}",
        "instructions": "请为 ${topic} 生成内容",
        "constraints": "字数不超过500字",
        "output_format": "输出 JSON 格式",
        "examples": "示例:{ 'title': '...' }",
    }
    return PromptTemplate(sections=PromptSection(**section_text))


# ---------------------------------------------------------------------------
# PromptTemplate: basic behavior
# ---------------------------------------------------------------------------


class TestPromptTemplate:
    """Tests for ``PromptTemplate.render``: role split, injection, truncation."""

    def test_render_returns_messages(self, simple_template):
        """render() returns a list[dict] whose items carry ``role`` and ``content``."""
        messages = simple_template.render()

        assert isinstance(messages, list)
        assert len(messages) >= 1
        for msg in messages:
            assert "role" in msg
            assert "content" in msg
            assert msg["role"] in ("system", "user")

    def test_system_user_message_split(self, simple_template):
        """identity + context land in the system message; instructions in the user message."""
        messages = simple_template.render(variables={"brand_name": "TestBrand", "topic": "AI"})

        roles = [m["role"] for m in messages]
        assert "system" in roles
        assert "user" in roles

        system_msg = next(m for m in messages if m["role"] == "system")
        user_msg = next(m for m in messages if m["role"] == "user")

        # identity and the (injected) context belong to the system message.
        assert "你是一个AI助手" in system_msg["content"]
        assert "TestBrand" in system_msg["content"]
        # instructions belong to the user message.
        assert "生成内容" in user_msg["content"]

    def test_variable_injection_simple(self, simple_template):
        """${var} placeholders are replaced with the supplied values."""
        # The topic value must not occur anywhere else in the template: the
        # identity section already contains "AI", so asserting on "AI" would
        # pass even if ${topic} were never substituted.
        messages = simple_template.render(variables={"brand_name": "MyBrand", "topic": "量子计算"})

        all_content = " ".join(m["content"] for m in messages)
        assert "MyBrand" in all_content
        assert "量子计算" in all_content
        # The raw placeholders must no longer be present.
        assert "${brand_name}" not in all_content
        assert "${topic}" not in all_content

    def test_variable_injection_nested(self):
        """Dotted ${a.b} paths are resolved through nested dicts."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(
            instructions="平台:${platform.name},目标:${goal.type}",
        )
        template = PromptTemplate(sections=sections)

        messages = template.render(variables={
            "platform": {"name": "微信公众号"},
            "goal": {"type": "品牌曝光"},
        })

        user_msg = next(m for m in messages if m["role"] == "user")
        assert "微信公众号" in user_msg["content"]
        assert "品牌曝光" in user_msg["content"]

    def test_unresolved_variable_kept(self):
        """Variables missing from ``variables`` are left as literal ${var}."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(
            instructions="主题是 ${undefined_var}",
        )
        template = PromptTemplate(sections=sections)
        messages = template.render(variables={})

        user_msg = next(m for m in messages if m["role"] == "user")
        assert "${undefined_var}" in user_msg["content"]

    def test_truncation_within_budget(self):
        """Context that fits within ``context_budget`` is rendered untruncated."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        short_context = "简短的上下文内容"
        sections = PromptSection(
            identity="身份",
            context=short_context,
        )
        template = PromptTemplate(sections=sections)
        messages = template.render(context_budget=3000)

        system_msg = next(m for m in messages if m["role"] == "system")
        assert short_context in system_msg["content"]
        assert "中间内容已省略" not in system_msg["content"]

    def test_truncation_exceeds_budget(self):
        """Context larger than ``context_budget`` is truncated with an omission marker."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        # Text far larger than the budget used below (context_budget=50 tokens).
        long_context = "这是一段非常长的上下文内容,包含大量文字。" * 200

        sections = PromptSection(
            identity="我是AI助手",
            context=long_context,
        )
        template = PromptTemplate(sections=sections)
        # A deliberately tiny budget forces truncation.
        messages = template.render(context_budget=50)

        system_msg = next(m for m in messages if m["role"] == "system")
        # The truncation marker must appear.
        assert "中间内容已省略" in system_msg["content"]

    def test_render_with_no_variables(self, simple_template):
        """render() without variables succeeds and leaves placeholders unresolved."""
        messages = simple_template.render()

        assert isinstance(messages, list)
        assert len(messages) > 0
        # Nothing was injected, so the instructions placeholder must survive verbatim.
        all_content = " ".join(m["content"] for m in messages)
        assert "${topic}" in all_content

    def test_render_empty_sections(self):
        """Sections left empty do not produce a message with blank content."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        sections = PromptSection(instructions="做些事情")
        template = PromptTemplate(sections=sections)
        messages = template.render()

        # instructions alone must still yield a user message.
        roles = [m["role"] for m in messages]
        assert "user" in roles

        # identity/context/constraints are all empty; if a system message is
        # emitted at all, it must not be blank.
        system_msgs = [m for m in messages if m["role"] == "system"]
        for sm in system_msgs:
            assert sm["content"].strip() != ""


# ---------------------------------------------------------------------------
# All five bundled templates render successfully
# ---------------------------------------------------------------------------


class TestAllTemplatesRender:
    """Smoke tests: each bundled template renders into well-formed messages."""

    @staticmethod
    def _assert_valid_messages(messages):
        """Assert *messages* is a non-empty list of dicts with ``role`` and string ``content``."""
        assert isinstance(messages, list)
        assert len(messages) > 0
        for m in messages:
            assert "role" in m
            assert "content" in m
            assert isinstance(m["content"], str)

    def test_topic_selector_template_renders(self):
        """TOPIC_SELECTOR_TEMPLATE renders without error."""
        from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE

        messages = TOPIC_SELECTOR_TEMPLATE.render(variables={
            "target_keyword": "AI营销",
            "brand_name": "示例品牌",
            "target_platform": "微信公众号",
        })
        self._assert_valid_messages(messages)

    def test_content_generator_template_renders(self):
        """CONTENT_GENERATOR_TEMPLATE renders without error."""
        from app.agent_framework.prompts import CONTENT_GENERATOR_TEMPLATE

        messages = CONTENT_GENERATOR_TEMPLATE.render(variables={
            "topic": "AI发展趋势",
            "platform": "知乎",
        })
        self._assert_valid_messages(messages)

    def test_deai_template_renders(self):
        """DEAI_TEMPLATE renders without error."""
        from app.agent_framework.prompts import DEAI_TEMPLATE

        messages = DEAI_TEMPLATE.render(variables={
            "content": "测试AI生成内容",
        })
        self._assert_valid_messages(messages)

    def test_geo_optimizer_template_renders(self):
        """GEO_OPTIMIZER_TEMPLATE renders without error."""
        from app.agent_framework.prompts import GEO_OPTIMIZER_TEMPLATE

        messages = GEO_OPTIMIZER_TEMPLATE.render(variables={
            "content": "待优化的内容",
            "platform": "小红书",
        })
        self._assert_valid_messages(messages)

    def test_rule_checker_template_renders(self):
        """RULE_CHECKER_TEMPLATE renders without error."""
        from app.agent_framework.prompts import RULE_CHECKER_TEMPLATE

        messages = RULE_CHECKER_TEMPLATE.render(variables={
            "content": "待检查的内容",
            "platform": "微信公众号",
        })
        self._assert_valid_messages(messages)


# ---------------------------------------------------------------------------
# Direct tests for PromptTemplate._inject (and _estimate_tokens)
# ---------------------------------------------------------------------------


class TestPromptTemplateInject:
    """Unit tests for the private helpers ``_inject`` and ``_estimate_tokens``."""

    @pytest.fixture
    def template(self):
        """Return a PromptTemplate built from default (empty) sections."""
        from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate

        empty_sections = PromptSection()
        return PromptTemplate(sections=empty_sections)

    def test_inject_simple_var(self, template):
        """A single ${var} is substituted."""
        assert template._inject("Hello ${name}", {"name": "World"}) == "Hello World"

    def test_inject_multiple_vars(self, template):
        """Several placeholders in one string are all substituted."""
        values = {"a": "foo", "b": "bar"}
        assert template._inject("${a} and ${b}", values) == "foo and bar"

    def test_inject_nested_path(self, template):
        """A dotted path walks into nested dicts."""
        nested = {"x": {"y": "deep_value"}}
        assert template._inject("${x.y}", nested) == "deep_value"

    def test_inject_missing_var_kept(self, template):
        """A placeholder with no matching variable is returned unchanged."""
        assert template._inject("${missing}", {}) == "${missing}"

    def test_inject_empty_text(self, template):
        """Empty input text comes back empty."""
        assert template._inject("", {"x": "val"}) == ""

    def test_estimate_tokens_chinese(self, template):
        """Each CJK character is estimated as roughly one token."""
        four_hanzi = "你好世界"  # 4 Chinese characters
        assert template._estimate_tokens(four_hanzi) == 4

    def test_estimate_tokens_empty(self, template):
        """Empty text is estimated at zero tokens."""
        assert template._estimate_tokens("") == 0