142 lines
5.0 KiB
Python
142 lines
5.0 KiB
Python
"""Prompt template engine: a modular, Context-Engineering-style prompt system.

Core design:
- PromptSection: modular section definitions (identity / context /
  instructions / constraints / output_format / examples).
- PromptTemplate: ${variable} injection, token-budget management, and
  rendering to the OpenAI messages format.
"""
|
||
|
||
import re
|
||
from dataclasses import dataclass
|
||
|
||
|
||
@dataclass
class PromptSection:
    """Named text sections that together make up one modular prompt.

    Each field holds the text of one semantic part of the prompt. Every
    field is optional and defaults to the empty string. At render time
    the sections are assembled into a system message and a user message.
    """

    # Agent identity definition.
    identity: str = ""
    # Dynamic context (knowledge base, brand information, etc.).
    context: str = ""
    # Task instructions.
    instructions: str = ""
    # Constraints the answer must respect.
    constraints: str = ""
    # Required output format.
    output_format: str = ""
    # Few-shot examples (optional).
    examples: str = ""
|
||
|
||
|
||
class PromptTemplate:
|
||
"""Context Engineering风格的Prompt模板
|
||
|
||
特性:
|
||
- 支持 ${variable} 和 ${nested.path} 变量注入
|
||
- 支持 token 预算管理(智能截断context section,避免Lost-in-the-Middle)
|
||
- 渲染为 OpenAI messages 格式
|
||
|
||
渲染规则:
|
||
- system message = identity + context(截断到budget) + constraints
|
||
- user message = instructions + output_format + examples
|
||
"""
|
||
|
||
def __init__(self, sections: PromptSection):
|
||
self.sections = sections
|
||
|
||
def render(
|
||
self, variables: dict | None = None, context_budget: int = 3000
|
||
) -> list[dict]:
|
||
"""渲染为messages列表
|
||
|
||
Args:
|
||
variables: 变量字典,支持嵌套路径如 {"brand": {"name": "xxx"}}
|
||
context_budget: context section的token预算上限
|
||
|
||
Returns:
|
||
[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
|
||
"""
|
||
variables = variables or {}
|
||
|
||
# 1. 注入变量到所有section
|
||
identity = self._inject(self.sections.identity, variables)
|
||
context = self._inject(self.sections.context, variables)
|
||
instructions = self._inject(self.sections.instructions, variables)
|
||
constraints = self._inject(self.sections.constraints, variables)
|
||
output_format = self._inject(self.sections.output_format, variables)
|
||
examples = self._inject(self.sections.examples, variables)
|
||
|
||
# 2. 截断context到token预算
|
||
context = self._truncate_smart(context, context_budget)
|
||
|
||
# 3. 组装messages
|
||
system_parts = [identity, context, constraints]
|
||
system_content = "\n\n".join(p for p in system_parts if p.strip())
|
||
|
||
user_parts = [instructions, output_format, examples]
|
||
user_content = "\n\n".join(p for p in user_parts if p.strip())
|
||
|
||
messages = []
|
||
if system_content:
|
||
messages.append({"role": "system", "content": system_content})
|
||
if user_content:
|
||
messages.append({"role": "user", "content": user_content})
|
||
|
||
return messages
|
||
|
||
def _inject(self, text: str, variables: dict) -> str:
|
||
"""替换 ${var} 和 ${var.nested.path} 占位符
|
||
|
||
支持多层嵌套字典路径解析,如 ${brand.name} 会从
|
||
{"brand": {"name": "xxx"}} 中取出 "xxx"。
|
||
未匹配的变量保持原样。
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
def replacer(match):
|
||
key = match.group(1)
|
||
parts = key.split(".")
|
||
value = variables
|
||
for part in parts:
|
||
if isinstance(value, dict):
|
||
value = value.get(part, f"${{{key}}}")
|
||
else:
|
||
return f"${{{key}}}"
|
||
return str(value) if not isinstance(value, dict) else f"${{{key}}}"
|
||
|
||
return re.sub(r"\$\{([^}]+)\}", replacer, text)
|
||
|
||
def _truncate_smart(self, text: str, max_tokens: int) -> str:
|
||
"""智能截断:保留开头和结尾,中间省略
|
||
|
||
策略:保留开头50%和结尾30%的字符预算,中间部分省略。
|
||
这种策略基于Lost-in-the-Middle研究,LLM对首尾信息的
|
||
注意力更强,中间部分容易被忽略。
|
||
"""
|
||
estimated_tokens = self._estimate_tokens(text)
|
||
if estimated_tokens <= max_tokens:
|
||
return text
|
||
|
||
if not text:
|
||
return text
|
||
|
||
# 按字符近似截断(1 token ≈ 1.5 chars for中英混合)
|
||
char_budget = int(max_tokens * 1.5)
|
||
head_budget = int(char_budget * 0.5)
|
||
tail_budget = int(char_budget * 0.3)
|
||
|
||
head = text[:head_budget]
|
||
tail = text[-tail_budget:]
|
||
return f"{head}\n\n[... 中间内容已省略以控制上下文长度 ...]\n\n{tail}"
|
||
|
||
def _estimate_tokens(self, text: str) -> int:
|
||
"""粗略估算token数
|
||
|
||
中文字符约1 token/字,英文约4字符/token。
|
||
这是一个粗略估计,实际tokenizer会更精确,
|
||
但对于预算管理来说足够使用。
|
||
"""
|
||
if not text:
|
||
return 0
|
||
chinese_chars = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
|
||
other_chars = len(text) - chinese_chars
|
||
return chinese_chars + int(other_chars / 4)
|