feat(rag_platform): U6 — hit processing mode + KB settings

Add hit_processing.py: HitProcessor with model_opt (LLM-generated) and direct (concatenated chunks) modes, with in-process cache
Add settings.py: KBSettings/KBSettingsUpdate models, KBSettingsStore with async CRUD
Add KB settings endpoints to kb_management.py: GET/PUT /kb-management/kbs/{kb_id}/settings with owner-only modification

Tests: 43 new tests (25 hit_processing + 18 settings), 293 total passing
This commit is contained in:
chiguyong 2026-06-25 12:44:47 +08:00
parent 5c562dbff3
commit d026a91f43
5 changed files with 1037 additions and 0 deletions

View File

@ -0,0 +1,181 @@
"""命中处理 — 模型优化模式LLM 生成回答vs 直接回答模式(返回匹配段落)。
- model_opt: 调用 LLM Gateway基于检索结果生成自然语言回答
- direct: 直接拼接匹配段落返回 LLM 调用零额外延迟
模式选择由 KB 设置 ``default_hit_processing`` 决定Agent 运行时可通过
``RetrievalRequest.hit_processing_mode`` 覆盖
"""
from __future__ import annotations
import hashlib
import logging
from typing import Any
from pydantic import BaseModel, ConfigDict
from agentkit.rag_platform.models import QueryResult
logger = logging.getLogger(__name__)
HIT_PROCESSING_MODEL_OPT = "model_opt"
HIT_PROCESSING_DIRECT = "direct"
# ponytail: 硬编码默认模型 — 升级路径为从 KB 设置 / agentkit.yaml 读取
_DEFAULT_LLM_MODEL = "gpt-4o-mini"
# RAG prompt 模板 — 指示 LLM 基于参考资料回答
_RAG_PROMPT_TEMPLATE = (
"你是一个知识库问答助手。请根据以下检索到的参考资料回答用户问题。\n\n"
"参考资料:\n{context}\n\n"
"用户问题:{query}\n\n"
"请基于参考资料回答,如果资料中没有相关信息请说明。"
)
class HitProcessingResult(BaseModel):
"""命中处理结果。"""
model_config = ConfigDict()
mode: str # "model_opt" | "direct"
answer: str # LLM 生成的回答或直接拼接的段落
sources: list[QueryResult] # 检索到的来源
cached: bool = False # 是否命中缓存
class HitProcessor:
"""命中处理器 — 根据 mode 处理检索结果。
Args:
llm_gateway: LLM 网关实例model_opt 模式使用需支持 ``async chat(messages, model)`` 接口
cache_enabled: 是否启用命中处理缓存默认 True
model: LLM 模型名默认 gpt-4o-mini
"""
def __init__(
self,
llm_gateway: Any = None,
cache_enabled: bool = True,
model: str = _DEFAULT_LLM_MODEL,
) -> None:
self._llm = llm_gateway
self._cache_enabled = cache_enabled
self._model = model
# ponytail: 进程内 dict 缓存 — 无 TTL无容量上限
# 升级路径:迁移到 Redis带 TTL + LRU 淘汰
self._cache: dict[str, HitProcessingResult] = {}
async def process(
self,
query: str,
results: list[QueryResult],
mode: str = HIT_PROCESSING_MODEL_OPT,
) -> HitProcessingResult:
"""处理检索结果。
Args:
query: 用户查询文本
results: 检索结果列表
mode: 命中处理模式 "model_opt"LLM 生成 "direct"直接返回段落
Returns:
HitProcessingResult 包含回答来源和缓存标记
"""
if not results:
return HitProcessingResult(
mode=mode, answer="未找到相关内容。", sources=[], cached=False
)
# 缓存检查
cache_key: str | None = None
if self._cache_enabled:
cache_key = self._cache_key(query, results, mode)
cached = self._cache.get(cache_key)
if cached is not None:
return cached.model_copy(update={"cached": True})
if mode == HIT_PROCESSING_DIRECT:
result = self._process_direct(results)
else:
result = await self._process_model_opt(query, results)
# 写入缓存
if cache_key is not None:
self._cache[cache_key] = result
return result
def _process_direct(self, results: list[QueryResult]) -> HitProcessingResult:
"""直接回答模式 — 拼接匹配段落,按分数降序排列。"""
sorted_results = sorted(results, key=lambda r: r.score, reverse=True)
parts: list[str] = []
for i, r in enumerate(sorted_results, start=1):
parts.append(f"[{i}] {r.content}")
answer = "\n\n".join(parts)
return HitProcessingResult(
mode=HIT_PROCESSING_DIRECT,
answer=answer,
sources=sorted_results,
cached=False,
)
async def _process_model_opt(
self,
query: str,
results: list[QueryResult],
) -> HitProcessingResult:
"""模型优化模式 — LLM 基于检索结果生成回答。
LLM 网关或调用失败时降级为 direct 模式不丢失检索结果
"""
if self._llm is None:
logger.warning("No LLM gateway configured, falling back to direct mode")
return self._process_direct(results)
# 构建 context — 按分数降序
sorted_results = sorted(results, key=lambda r: r.score, reverse=True)
context_parts: list[str] = []
for i, r in enumerate(sorted_results, start=1):
context_parts.append(f"[{i}] {r.content}")
context = "\n\n".join(context_parts)
prompt = _RAG_PROMPT_TEMPLATE.format(context=context, query=query)
messages = [
{"role": "system", "content": "你是一个知识库问答助手,请基于参考资料回答问题。"},
{"role": "user", "content": prompt},
]
try:
response = await self._llm.chat(messages=messages, model=self._model)
answer = response.content
except Exception as e:
logger.warning("LLM call failed, falling back to direct mode: %s", e)
return self._process_direct(results)
return HitProcessingResult(
mode=HIT_PROCESSING_MODEL_OPT,
answer=answer,
sources=sorted_results,
cached=False,
)
@staticmethod
def _cache_key(
query: str,
results: list[QueryResult],
mode: str,
) -> str:
"""生成缓存键 — 基于 mode + query + chunk_ids 的 SHA256 哈希。"""
chunk_ids = ",".join(sorted(r.chunk_id for r in results))
raw = f"{mode}:{query}:{chunk_ids}"
return hashlib.sha256(raw.encode()).hexdigest()
__all__ = [
"HIT_PROCESSING_DIRECT",
"HIT_PROCESSING_MODEL_OPT",
"HitProcessingResult",
"HitProcessor",
]

View File

@ -0,0 +1,144 @@
"""KB 设置模型 — 检索模式默认/命中处理默认/授权用户/caching/rerank。
KB 级别设置控制检索与命中处理的默认行为Agent 运行时可通过
``RetrievalRequest`` 字段覆盖
当前为进程内 dict 存储重启丢失
ponytail: 升级路径 迁移到 PGKBModel 已有 default_query_mode /
default_hit_processing / caching_disabled rerank 相关设置需新增
JSON 列或独立表
"""
from __future__ import annotations
import logging
from pydantic import BaseModel, ConfigDict
from agentkit.rag_platform.models import QueryMode
logger = logging.getLogger(__name__)
HIT_PROCESSING_MODEL_OPT = "model_opt"
HIT_PROCESSING_DIRECT = "direct"
class KBSettings(BaseModel):
"""KB 级别设置 — 可通过 API 修改。
owner 字段用于 ACL 校验 owner admin可修改设置viewer 只读
"""
model_config = ConfigDict()
kb_id: str
owner: str | None = None # KB 所有者(仅 owner 可修改设置)
default_query_mode: QueryMode = QueryMode.blend
default_hit_processing: str = HIT_PROCESSING_MODEL_OPT # model_opt | direct
caching_disabled: bool = False
rerank_enabled: bool = True
rerank_provider: str = "cohere" # cohere | bge | none
rerank_api_key: str | None = None
rerank_base_url: str | None = None
data_export_warning: bool = False # True = 使用云端 rerank 处理敏感数据
class KBSettingsUpdate(BaseModel):
"""KB 设置更新请求 — 仅 owner 可修改。
所有字段可选部分更新语义None 表示不更新该字段
"""
model_config = ConfigDict()
default_query_mode: QueryMode | None = None
default_hit_processing: str | None = None
caching_disabled: bool | None = None
rerank_enabled: bool | None = None
rerank_provider: str | None = None
rerank_api_key: str | None = None
rerank_base_url: str | None = None
class KBSettingsStore:
"""KB 设置存储 — 读写 KB 设置。
当前为进程内 dict 存储CRUD 方法为 async 以保持与未来 PG 迁移的接口兼容
"""
def __init__(self) -> None:
self._settings: dict[str, KBSettings] = {}
async def get_settings(self, kb_id: str) -> KBSettings | None:
"""读取 KB 设置。不存在时返回 None。"""
return self._settings.get(kb_id)
async def get_or_create(self, kb_id: str, owner: str | None = None) -> KBSettings:
"""读取 KB 设置,不存在时创建默认设置。
已存在时不覆盖 owner
"""
if kb_id not in self._settings:
self._settings[kb_id] = KBSettings(kb_id=kb_id, owner=owner)
return self._settings[kb_id]
async def update_settings(
self,
kb_id: str,
update: KBSettingsUpdate,
owner: str | None = None,
) -> KBSettings:
"""更新 KB 设置 — 仅更新非 None 字段。
首次更新时若提供 owner 则设置所有者owner 校验由路由层完成
"""
existing = self._settings.get(kb_id)
if existing is None:
existing = KBSettings(kb_id=kb_id, owner=owner)
self._settings[kb_id] = existing
update_data = update.model_dump(exclude_none=True)
updated = existing.model_copy(update=update_data)
self._settings[kb_id] = updated
return updated
def is_owner(self, kb_id: str, user_id: str | None) -> bool:
"""检查用户是否为 KB 所有者。KB 不存在或 owner 未设置时返回 False。"""
settings = self._settings.get(kb_id)
if settings is None or settings.owner is None:
return False
return settings.owner == user_id
def set_owner(self, kb_id: str, owner: str) -> bool:
"""设置 KB 所有者。返回 True 如果 KB 存在。"""
settings = self._settings.get(kb_id)
if settings is None:
return False
self._settings[kb_id] = settings.model_copy(update={"owner": owner})
return True
# 模块级单例 — 类似 KnowledgeSourceStore 模式
_settings_store = KBSettingsStore()
def get_settings_store() -> KBSettingsStore:
"""返回进程级 KBSettingsStore 单例。"""
return _settings_store
def set_settings_store(store: KBSettingsStore) -> None:
"""注入自定义 KBSettingsStore测试用"""
global _settings_store
_settings_store = store
__all__ = [
"HIT_PROCESSING_DIRECT",
"HIT_PROCESSING_MODEL_OPT",
"KBSettings",
"KBSettingsStore",
"KBSettingsUpdate",
"get_settings_store",
"set_settings_store",
]

View File

@ -652,3 +652,67 @@ async def check_source_health(source_id: str, req: Request, _auth: None = Depend
# For external sources, try to check connectivity
# This would use the actual KB adapters in production
return {"source_id": source_id, "status": "unknown", "message": "健康检查未实现"}
# ---------------------------------------------------------------------------
# KB Settings endpoints (U6)
# ---------------------------------------------------------------------------
from agentkit.rag_platform.settings import ( # noqa: E402
KBSettings,
KBSettingsUpdate,
get_settings_store,
)
@router.get("/kb-management/kbs/{kb_id}/settings")
async def get_kb_settings(
kb_id: str,
req: Request,
_auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_QUERY)),
):
"""获取 KB 设置 — owner/viewer 可读。
U6: 返回 KB 级别设置检索模式默认/命中处理默认/caching/rerank
设置不存在时返回默认值
"""
store = get_settings_store()
settings = await store.get_settings(kb_id)
if settings is None:
# 返回默认设置(不持久化 — 首次 PUT 时创建)
settings = KBSettings(kb_id=kb_id)
return settings.model_dump()
@router.put("/kb-management/kbs/{kb_id}/settings")
async def update_kb_settings(
kb_id: str,
update: KBSettingsUpdate,
req: Request,
_auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_WRITE)),
):
"""更新 KB 设置 — 仅 owner或 admin可修改。
U6: per-KB ACL 校验 首次更新时设置 owner后续更新仅 owner/admin 可调用
viewer KB_QUERY 但非 owner调用返回 403
"""
user_id = _user.get("user_id")
role = _user.get("role", "")
store = get_settings_store()
existing = await store.get_settings(kb_id)
# ACL 校验:已有设置且 owner 已设置时,仅 owner/admin 可修改
if existing is not None and existing.owner is not None:
if existing.owner != user_id and role != "admin":
raise HTTPException(
status_code=403,
detail="仅 KB owner 可修改设置",
)
# 首次创建时设置 owner
owner = user_id if (existing is None or existing.owner is None) else None
updated = await store.update_settings(kb_id, update, owner=owner)
return updated.model_dump()

View File

@ -0,0 +1,319 @@
"""U6 测试 — 命中处理model_opt / direct 模式)。
测试场景
1. model_opt 模式LLM 基于检索结果生成回答
2. direct 模式直接返回匹配段落
3. 空结果返回"未找到相关内容"
4. LLM 网关时降级为 direct 模式
5. LLM 调用失败时降级为 direct 模式
6. 缓存命中返回 cached=True
7. KB 默认模式生效
8. Agent 运行时覆盖默认模式
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.hit_processing import (
HIT_PROCESSING_DIRECT,
HIT_PROCESSING_MODEL_OPT,
HitProcessor,
)
from agentkit.rag_platform.models import QueryResult
from agentkit.rag_platform.settings import KBSettings
def _make_query_result(
chunk_id: str,
content: str,
score: float = 0.5,
document_id: str = "doc1",
kb_id: str = "kb1",
) -> QueryResult:
"""创建测试用 QueryResult。"""
return QueryResult(
chunk_id=chunk_id,
content=content,
score=score,
metadata={"document_id": document_id, "kb_id": kb_id},
document_id=document_id,
kb_id=kb_id,
)
def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock:
"""创建 mock LLM 网关 — chat() 返回带 content 的 response。"""
mock = MagicMock()
mock_response = MagicMock()
mock_response.content = response_content
mock.chat = AsyncMock(return_value=mock_response)
return mock
# ---------------------------------------------------------------------------
# direct 模式测试
# ---------------------------------------------------------------------------
class TestDirectMode:
"""direct 模式测试。"""
async def test_direct_returns_concatenated_chunks(self):
"""direct 模式返回拼接的匹配段落,按分数降序。"""
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
results = [
_make_query_result("c1", "段落一", score=0.9),
_make_query_result("c2", "段落二", score=0.7),
]
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
assert result.mode == HIT_PROCESSING_DIRECT
assert "段落一" in result.answer
assert "段落二" in result.answer
assert result.cached is False
assert len(result.sources) == 2
# 按分数降序
assert result.sources[0].chunk_id == "c1"
async def test_direct_format_includes_index(self):
"""direct 模式段落带序号 [1] [2]。"""
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
results = [_make_query_result("c1", "唯一段落")]
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
assert "[1]" in result.answer
assert "唯一段落" in result.answer
async def test_direct_empty_results(self):
"""空结果返回"未找到相关内容""""
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT)
assert result.answer == "未找到相关内容。"
assert result.sources == []
assert result.cached is False
# ---------------------------------------------------------------------------
# model_opt 模式测试
# ---------------------------------------------------------------------------
class TestModelOptMode:
"""model_opt 模式测试。"""
async def test_model_opt_calls_llm(self):
"""model_opt 模式调用 LLM 生成回答。"""
mock_llm = _make_mock_llm_gateway("基于资料的回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "参考资料内容")]
result = await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
assert result.mode == HIT_PROCESSING_MODEL_OPT
assert result.answer == "基于资料的回答"
assert len(result.sources) == 1
mock_llm.chat.assert_awaited_once()
async def test_model_opt_includes_context_in_prompt(self):
"""model_opt 模式将检索结果作为 context 传入 LLM。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "独特的资料文本")]
await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
call_args = mock_llm.chat.await_args
messages = call_args.kwargs.get("messages") or call_args.args[0]
# 检查 context 出现在 prompt 中
all_content = " ".join(m["content"] for m in messages)
assert "独特的资料文本" in all_content
async def test_model_opt_passes_query_in_prompt(self):
"""model_opt 模式将用户查询传入 LLM prompt。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "资料")]
await processor.process("独特查询文本", results, mode=HIT_PROCESSING_MODEL_OPT)
call_args = mock_llm.chat.await_args
messages = call_args.kwargs.get("messages") or call_args.args[0]
all_content = " ".join(m["content"] for m in messages)
assert "独特查询文本" in all_content
async def test_model_opt_no_llm_falls_back_to_direct(self):
"""无 LLM 网关时降级为 direct 模式。"""
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
results = [_make_query_result("c1", "段落内容")]
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
# 降级为 direct — answer 是拼接的段落
assert "段落内容" in result.answer
async def test_model_opt_llm_failure_falls_back_to_direct(self):
"""LLM 调用失败时降级为 direct 模式(不丢失检索结果)。"""
mock_llm = MagicMock()
mock_llm.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "降级段落")]
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
assert "降级段落" in result.answer
async def test_model_opt_empty_results(self):
"""空结果返回"未找到相关内容",不调用 LLM。"""
mock_llm = _make_mock_llm_gateway()
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
result = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
assert result.answer == "未找到相关内容。"
mock_llm.chat.assert_not_awaited()
async def test_model_opt_sorts_by_score(self):
"""model_opt 模式按分数降序构建 context。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [
_make_query_result("c2", "低分段落", score=0.3),
_make_query_result("c1", "高分段落", score=0.9),
]
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
assert result.sources[0].chunk_id == "c1"
assert result.sources[1].chunk_id == "c2"
# ---------------------------------------------------------------------------
# 缓存测试
# ---------------------------------------------------------------------------
class TestCaching:
"""缓存测试。"""
async def test_cache_hit_returns_cached_true(self):
"""相同输入第二次返回 cached=True。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
results = [_make_query_result("c1", "内容")]
first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
assert first.cached is False
assert second.cached is True
# LLM 只调用一次(第二次命中缓存)
assert mock_llm.chat.await_count == 1
async def test_cache_disabled_no_caching(self):
"""禁用缓存时每次都调用 LLM。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "内容")]
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
assert mock_llm.chat.await_count == 2
async def test_cache_key_includes_query(self):
"""不同查询不共享缓存。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
results = [_make_query_result("c1", "内容")]
await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT)
await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT)
assert mock_llm.chat.await_count == 2
async def test_cache_key_includes_chunk_ids(self):
"""不同检索结果不共享缓存。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
await processor.process(
"query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT
)
await processor.process(
"query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT
)
assert mock_llm.chat.await_count == 2
async def test_cache_key_includes_mode(self):
"""不同模式不共享缓存。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
results = [_make_query_result("c1", "内容")]
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
# direct 模式不调用 LLM但也不应命中 model_opt 的缓存
assert mock_llm.chat.await_count == 1
async def test_cache_not_used_for_empty_results(self):
"""空结果不写入缓存。"""
mock_llm = _make_mock_llm_gateway()
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
assert first.cached is False
assert second.cached is False
# ---------------------------------------------------------------------------
# KB 默认模式 + 运行时覆盖测试
# ---------------------------------------------------------------------------
class TestKBDefaultMode:
"""KB 默认模式生效测试。"""
async def test_kb_default_model_opt(self):
"""KB 默认 model_opt 模式生效。"""
mock_llm = _make_mock_llm_gateway("回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
settings = KBSettings(kb_id="kb1") # 默认 model_opt
results = [_make_query_result("c1", "内容")]
result = await processor.process("query", results, mode=settings.default_hit_processing)
assert result.mode == HIT_PROCESSING_MODEL_OPT
mock_llm.chat.assert_awaited_once()
async def test_kb_default_direct(self):
"""KB 默认 direct 模式生效。"""
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
results = [_make_query_result("c1", "内容")]
result = await processor.process("query", results, mode=settings.default_hit_processing)
assert result.mode == HIT_PROCESSING_DIRECT
class TestModeOverride:
"""Agent 运行时覆盖默认模式测试。"""
async def test_override_to_direct(self):
"""运行时覆盖为 direct 模式 — 不调用 LLM。"""
mock_llm = _make_mock_llm_gateway("LLM 回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "段落")]
# KB 默认 model_opt运行时覆盖为 direct
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
assert result.mode == HIT_PROCESSING_DIRECT
mock_llm.chat.assert_not_awaited()
async def test_override_to_model_opt(self):
"""运行时覆盖为 model_opt 模式 — 调用 LLM。"""
mock_llm = _make_mock_llm_gateway("LLM 回答")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [_make_query_result("c1", "段落")]
# KB 默认 direct运行时覆盖为 model_opt
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
assert result.mode == HIT_PROCESSING_MODEL_OPT
mock_llm.chat.assert_awaited_once()

View File

@ -0,0 +1,329 @@
"""U6 测试 — KB 设置模型与存储。
测试场景
1. KBSettings 默认值
2. KBSettingsUpdate 部分更新语义
3. KBSettingsStore CRUDget / get_or_create / update
4. owner 校验is_owner / set_owner
5. KB 设置默认模式生效 HitProcessor 集成
"""
from __future__ import annotations
from agentkit.rag_platform.hit_processing import (
HIT_PROCESSING_DIRECT,
HIT_PROCESSING_MODEL_OPT,
)
from agentkit.rag_platform.models import QueryMode
from agentkit.rag_platform.settings import (
KBSettings,
KBSettingsStore,
KBSettingsUpdate,
)
# ---------------------------------------------------------------------------
# KBSettings 模型测试
# ---------------------------------------------------------------------------
class TestKBSettings:
"""KBSettings 模型测试。"""
def test_defaults(self):
"""默认值正确。"""
settings = KBSettings(kb_id="kb1")
assert settings.kb_id == "kb1"
assert settings.owner is None
assert settings.default_query_mode == QueryMode.blend
assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
assert settings.caching_disabled is False
assert settings.rerank_enabled is True
assert settings.rerank_provider == "cohere"
assert settings.rerank_api_key is None
assert settings.rerank_base_url is None
assert settings.data_export_warning is False
def test_custom_values(self):
"""自定义值正确。"""
settings = KBSettings(
kb_id="kb1",
owner="user1",
default_query_mode=QueryMode.keywords,
default_hit_processing=HIT_PROCESSING_DIRECT,
caching_disabled=True,
rerank_enabled=False,
rerank_provider="bge",
rerank_api_key="key-123",
rerank_base_url="http://localhost:9997",
data_export_warning=True,
)
assert settings.owner == "user1"
assert settings.default_query_mode == QueryMode.keywords
assert settings.default_hit_processing == HIT_PROCESSING_DIRECT
assert settings.caching_disabled is True
assert settings.rerank_enabled is False
assert settings.rerank_provider == "bge"
assert settings.rerank_api_key == "key-123"
assert settings.rerank_base_url == "http://localhost:9997"
assert settings.data_export_warning is True
# ---------------------------------------------------------------------------
# KBSettingsUpdate 模型测试
# ---------------------------------------------------------------------------
class TestKBSettingsUpdate:
"""KBSettingsUpdate 模型测试。"""
def test_all_none_by_default(self):
"""默认所有字段为 None部分更新语义"""
update = KBSettingsUpdate()
assert update.default_query_mode is None
assert update.default_hit_processing is None
assert update.caching_disabled is None
assert update.rerank_enabled is None
assert update.rerank_provider is None
assert update.rerank_api_key is None
assert update.rerank_base_url is None
def test_partial_update(self):
"""部分字段赋值。"""
update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
assert update.default_hit_processing == HIT_PROCESSING_DIRECT
assert update.default_query_mode is None
def test_exclude_none_dumps(self):
"""model_dump(exclude_none=True) 仅包含已设置字段。"""
update = KBSettingsUpdate(caching_disabled=True)
data = update.model_dump(exclude_none=True)
assert data == {"caching_disabled": True}
def test_full_update(self):
"""所有字段同时赋值。"""
update = KBSettingsUpdate(
default_query_mode=QueryMode.embedding,
default_hit_processing=HIT_PROCESSING_DIRECT,
caching_disabled=True,
rerank_enabled=False,
rerank_provider="none",
rerank_api_key="key",
rerank_base_url="http://x",
)
data = update.model_dump(exclude_none=True)
assert len(data) == 7
# ---------------------------------------------------------------------------
# KBSettingsStore CRUD 测试
# ---------------------------------------------------------------------------
class TestKBSettingsStoreCRUD:
"""KBSettingsStore 存储测试。"""
async def test_get_settings_not_exist(self):
"""不存在的 KB 返回 None。"""
store = KBSettingsStore()
assert await store.get_settings("kb1") is None
async def test_get_or_create_creates_defaults(self):
"""get_or_create 创建默认设置。"""
store = KBSettingsStore()
settings = await store.get_or_create("kb1", owner="user1")
assert settings.kb_id == "kb1"
assert settings.owner == "user1"
assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
async def test_get_or_create_idempotent(self):
"""get_or_create 幂等 — 已存在时不覆盖 owner。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
second = await store.get_or_create("kb1", owner="user2")
assert second.owner == "user1" # 不覆盖已有 owner
async def test_update_settings_creates_if_not_exist(self):
"""update_settings 在设置不存在时创建。"""
store = KBSettingsStore()
update = KBSettingsUpdate(caching_disabled=True)
result = await store.update_settings("kb1", update, owner="user1")
assert result.caching_disabled is True
assert result.owner == "user1"
assert result.kb_id == "kb1"
async def test_update_settings_partial(self):
"""update_settings 仅更新提供的字段,其他字段保持不变。"""
store = KBSettingsStore()
# 先创建默认设置
await store.get_or_create("kb1", owner="user1")
# 部分更新
update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
result = await store.update_settings("kb1", update)
assert result.default_hit_processing == HIT_PROCESSING_DIRECT
# 其他字段保持默认
assert result.default_query_mode == QueryMode.blend
assert result.caching_disabled is False
assert result.owner == "user1"
async def test_update_settings_multiple_fields(self):
"""update_settings 同时更新多个字段。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
update = KBSettingsUpdate(
caching_disabled=True,
rerank_provider="bge",
rerank_enabled=False,
)
result = await store.update_settings("kb1", update)
assert result.caching_disabled is True
assert result.rerank_provider == "bge"
assert result.rerank_enabled is False
async def test_update_settings_persists(self):
"""update_settings 后 get_settings 返回更新后的值。"""
store = KBSettingsStore()
await store.update_settings("kb1", KBSettingsUpdate(caching_disabled=True), owner="user1")
retrieved = await store.get_settings("kb1")
assert retrieved is not None
assert retrieved.caching_disabled is True
async def test_update_settings_none_field_ignored(self):
"""update_settings 中 None 字段被忽略。"""
store = KBSettingsStore()
await store.update_settings(
"kb1",
KBSettingsUpdate(caching_disabled=True, default_hit_processing=HIT_PROCESSING_DIRECT),
owner="user1",
)
# 第二次更新 — 只改 caching_disableddefault_hit_processing 应保持
await store.update_settings(
"kb1",
KBSettingsUpdate(caching_disabled=False),
)
result = await store.get_settings("kb1")
assert result is not None
assert result.caching_disabled is False
assert result.default_hit_processing == HIT_PROCESSING_DIRECT
# ---------------------------------------------------------------------------
# owner 校验测试
# ---------------------------------------------------------------------------
class TestOwnerCheck:
"""owner 校验测试。"""
async def test_is_owner_true(self):
"""owner 匹配返回 True。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
assert store.is_owner("kb1", "user1") is True
async def test_is_owner_false_wrong_user(self):
"""非 owner 用户返回 False。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
assert store.is_owner("kb1", "user2") is False
async def test_is_owner_false_none_user(self):
"""None 用户返回 False。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
assert store.is_owner("kb1", None) is False
async def test_is_owner_false_nonexistent_kb(self):
"""不存在的 KB 返回 False。"""
store = KBSettingsStore()
assert store.is_owner("nonexistent", "user1") is False
async def test_is_owner_false_no_owner_set(self):
"""owner 未设置时返回 False。"""
store = KBSettingsStore()
await store.get_or_create("kb1") # 不传 owner
assert store.is_owner("kb1", "user1") is False
async def test_set_owner(self):
"""set_owner 设置所有者。"""
store = KBSettingsStore()
await store.get_or_create("kb1", owner="user1")
assert store.set_owner("kb1", "user2") is True
assert store.is_owner("kb1", "user2") is True
assert store.is_owner("kb1", "user1") is False
async def test_set_owner_nonexistent(self):
"""set_owner 对不存在的 KB 返回 False。"""
store = KBSettingsStore()
assert store.set_owner("nonexistent", "user1") is False
# ---------------------------------------------------------------------------
# KB 默认模式生效测试(与 HitProcessor 集成)
# ---------------------------------------------------------------------------
class TestKBSettingsDefaultModeIntegration:
"""KB 设置默认模式与 HitProcessor 集成测试。"""
async def test_default_model_opt_flows_to_processor(self):
"""KB 默认 model_opt 模式传递给 HitProcessor。"""
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.hit_processing import HitProcessor
from agentkit.rag_platform.models import QueryResult
mock_llm = MagicMock()
mock_resp = MagicMock()
mock_resp.content = "LLM 回答"
mock_llm.chat = AsyncMock(return_value=mock_resp)
store = KBSettingsStore()
settings = await store.get_or_create("kb1", owner="user1")
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
results = [
QueryResult(
chunk_id="c1",
content="内容",
score=0.9,
metadata={},
document_id="d1",
kb_id="kb1",
)
]
result = await processor.process("query", results, mode=settings.default_hit_processing)
assert result.mode == HIT_PROCESSING_MODEL_OPT
mock_llm.chat.assert_awaited_once()
async def test_default_direct_flows_to_processor(self):
"""KB 默认 direct 模式传递给 HitProcessor。"""
from agentkit.rag_platform.hit_processing import HitProcessor
from agentkit.rag_platform.models import QueryResult
store = KBSettingsStore()
await store.update_settings(
"kb1",
KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT),
owner="user1",
)
settings = await store.get_settings("kb1")
assert settings is not None
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
results = [
QueryResult(
chunk_id="c1",
content="直接段落",
score=0.9,
metadata={},
document_id="d1",
kb_id="kb1",
)
]
result = await processor.process("query", results, mode=settings.default_hit_processing)
assert result.mode == HIT_PROCESSING_DIRECT
assert "直接段落" in result.answer