feat(rag_platform): U6 — hit processing mode + KB settings
Add hit_processing.py: HitProcessor with model_opt (LLM-generated) and direct (concatenated chunks) modes, with in-process cache
Add settings.py: KBSettings/KBSettingsUpdate models, KBSettingsStore with async CRUD
Add KB settings endpoints to kb_management.py: GET/PUT /kb-management/kbs/{kb_id}/settings with owner-only modification
Tests: 43 new tests (25 hit_processing + 18 settings), 293 total passing
This commit is contained in:
parent
5c562dbff3
commit
d026a91f43
|
|
@ -0,0 +1,181 @@
|
||||||
|
"""命中处理 — 模型优化模式(LLM 生成回答)vs 直接回答模式(返回匹配段落)。
|
||||||
|
|
||||||
|
- model_opt: 调用 LLM Gateway,基于检索结果生成自然语言回答
|
||||||
|
- direct: 直接拼接匹配段落返回(零 LLM 调用,零额外延迟)
|
||||||
|
|
||||||
|
模式选择由 KB 设置 ``default_hit_processing`` 决定,Agent 运行时可通过
|
||||||
|
``RetrievalRequest.hit_processing_mode`` 覆盖。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HIT_PROCESSING_MODEL_OPT = "model_opt"
|
||||||
|
HIT_PROCESSING_DIRECT = "direct"
|
||||||
|
|
||||||
|
# ponytail: 硬编码默认模型 — 升级路径为从 KB 设置 / agentkit.yaml 读取
|
||||||
|
_DEFAULT_LLM_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
# RAG prompt 模板 — 指示 LLM 基于参考资料回答
|
||||||
|
_RAG_PROMPT_TEMPLATE = (
|
||||||
|
"你是一个知识库问答助手。请根据以下检索到的参考资料回答用户问题。\n\n"
|
||||||
|
"参考资料:\n{context}\n\n"
|
||||||
|
"用户问题:{query}\n\n"
|
||||||
|
"请基于参考资料回答,如果资料中没有相关信息请说明。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HitProcessingResult(BaseModel):
|
||||||
|
"""命中处理结果。"""
|
||||||
|
|
||||||
|
model_config = ConfigDict()
|
||||||
|
|
||||||
|
mode: str # "model_opt" | "direct"
|
||||||
|
answer: str # LLM 生成的回答或直接拼接的段落
|
||||||
|
sources: list[QueryResult] # 检索到的来源
|
||||||
|
cached: bool = False # 是否命中缓存
|
||||||
|
|
||||||
|
|
||||||
|
class HitProcessor:
|
||||||
|
"""命中处理器 — 根据 mode 处理检索结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_gateway: LLM 网关实例(model_opt 模式使用),需支持 ``async chat(messages, model)`` 接口
|
||||||
|
cache_enabled: 是否启用命中处理缓存(默认 True)
|
||||||
|
model: LLM 模型名(默认 gpt-4o-mini)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_gateway: Any = None,
|
||||||
|
cache_enabled: bool = True,
|
||||||
|
model: str = _DEFAULT_LLM_MODEL,
|
||||||
|
) -> None:
|
||||||
|
self._llm = llm_gateway
|
||||||
|
self._cache_enabled = cache_enabled
|
||||||
|
self._model = model
|
||||||
|
# ponytail: 进程内 dict 缓存 — 无 TTL,无容量上限
|
||||||
|
# 升级路径:迁移到 Redis,带 TTL + LRU 淘汰
|
||||||
|
self._cache: dict[str, HitProcessingResult] = {}
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
results: list[QueryResult],
|
||||||
|
mode: str = HIT_PROCESSING_MODEL_OPT,
|
||||||
|
) -> HitProcessingResult:
|
||||||
|
"""处理检索结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询文本
|
||||||
|
results: 检索结果列表
|
||||||
|
mode: 命中处理模式 — "model_opt"(LLM 生成)或 "direct"(直接返回段落)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HitProcessingResult 包含回答、来源和缓存标记
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return HitProcessingResult(
|
||||||
|
mode=mode, answer="未找到相关内容。", sources=[], cached=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 缓存检查
|
||||||
|
cache_key: str | None = None
|
||||||
|
if self._cache_enabled:
|
||||||
|
cache_key = self._cache_key(query, results, mode)
|
||||||
|
cached = self._cache.get(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached.model_copy(update={"cached": True})
|
||||||
|
|
||||||
|
if mode == HIT_PROCESSING_DIRECT:
|
||||||
|
result = self._process_direct(results)
|
||||||
|
else:
|
||||||
|
result = await self._process_model_opt(query, results)
|
||||||
|
|
||||||
|
# 写入缓存
|
||||||
|
if cache_key is not None:
|
||||||
|
self._cache[cache_key] = result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _process_direct(self, results: list[QueryResult]) -> HitProcessingResult:
|
||||||
|
"""直接回答模式 — 拼接匹配段落,按分数降序排列。"""
|
||||||
|
sorted_results = sorted(results, key=lambda r: r.score, reverse=True)
|
||||||
|
parts: list[str] = []
|
||||||
|
for i, r in enumerate(sorted_results, start=1):
|
||||||
|
parts.append(f"[{i}] {r.content}")
|
||||||
|
answer = "\n\n".join(parts)
|
||||||
|
return HitProcessingResult(
|
||||||
|
mode=HIT_PROCESSING_DIRECT,
|
||||||
|
answer=answer,
|
||||||
|
sources=sorted_results,
|
||||||
|
cached=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_model_opt(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
results: list[QueryResult],
|
||||||
|
) -> HitProcessingResult:
|
||||||
|
"""模型优化模式 — LLM 基于检索结果生成回答。
|
||||||
|
|
||||||
|
无 LLM 网关或调用失败时降级为 direct 模式(不丢失检索结果)。
|
||||||
|
"""
|
||||||
|
if self._llm is None:
|
||||||
|
logger.warning("No LLM gateway configured, falling back to direct mode")
|
||||||
|
return self._process_direct(results)
|
||||||
|
|
||||||
|
# 构建 context — 按分数降序
|
||||||
|
sorted_results = sorted(results, key=lambda r: r.score, reverse=True)
|
||||||
|
context_parts: list[str] = []
|
||||||
|
for i, r in enumerate(sorted_results, start=1):
|
||||||
|
context_parts.append(f"[{i}] {r.content}")
|
||||||
|
context = "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
prompt = _RAG_PROMPT_TEMPLATE.format(context=context, query=query)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "你是一个知识库问答助手,请基于参考资料回答问题。"},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._llm.chat(messages=messages, model=self._model)
|
||||||
|
answer = response.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("LLM call failed, falling back to direct mode: %s", e)
|
||||||
|
return self._process_direct(results)
|
||||||
|
|
||||||
|
return HitProcessingResult(
|
||||||
|
mode=HIT_PROCESSING_MODEL_OPT,
|
||||||
|
answer=answer,
|
||||||
|
sources=sorted_results,
|
||||||
|
cached=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cache_key(
|
||||||
|
query: str,
|
||||||
|
results: list[QueryResult],
|
||||||
|
mode: str,
|
||||||
|
) -> str:
|
||||||
|
"""生成缓存键 — 基于 mode + query + chunk_ids 的 SHA256 哈希。"""
|
||||||
|
chunk_ids = ",".join(sorted(r.chunk_id for r in results))
|
||||||
|
raw = f"{mode}:{query}:{chunk_ids}"
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HIT_PROCESSING_DIRECT",
|
||||||
|
"HIT_PROCESSING_MODEL_OPT",
|
||||||
|
"HitProcessingResult",
|
||||||
|
"HitProcessor",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,144 @@
|
||||||
|
"""KB 设置模型 — 检索模式默认/命中处理默认/授权用户/caching/rerank。
|
||||||
|
|
||||||
|
KB 级别设置控制检索与命中处理的默认行为,Agent 运行时可通过
|
||||||
|
``RetrievalRequest`` 字段覆盖。
|
||||||
|
|
||||||
|
当前为进程内 dict 存储(重启丢失)。
|
||||||
|
ponytail: 升级路径 — 迁移到 PG(KBModel 已有 default_query_mode /
|
||||||
|
default_hit_processing / caching_disabled 列),rerank 相关设置需新增
|
||||||
|
JSON 列或独立表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from agentkit.rag_platform.models import QueryMode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HIT_PROCESSING_MODEL_OPT = "model_opt"
|
||||||
|
HIT_PROCESSING_DIRECT = "direct"
|
||||||
|
|
||||||
|
|
||||||
|
class KBSettings(BaseModel):
|
||||||
|
"""KB 级别设置 — 可通过 API 修改。
|
||||||
|
|
||||||
|
owner 字段用于 ACL 校验:仅 owner(或 admin)可修改设置,viewer 只读。
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict()
|
||||||
|
|
||||||
|
kb_id: str
|
||||||
|
owner: str | None = None # KB 所有者(仅 owner 可修改设置)
|
||||||
|
default_query_mode: QueryMode = QueryMode.blend
|
||||||
|
default_hit_processing: str = HIT_PROCESSING_MODEL_OPT # model_opt | direct
|
||||||
|
caching_disabled: bool = False
|
||||||
|
rerank_enabled: bool = True
|
||||||
|
rerank_provider: str = "cohere" # cohere | bge | none
|
||||||
|
rerank_api_key: str | None = None
|
||||||
|
rerank_base_url: str | None = None
|
||||||
|
data_export_warning: bool = False # True = 使用云端 rerank 处理敏感数据
|
||||||
|
|
||||||
|
|
||||||
|
class KBSettingsUpdate(BaseModel):
|
||||||
|
"""KB 设置更新请求 — 仅 owner 可修改。
|
||||||
|
|
||||||
|
所有字段可选(部分更新语义),None 表示不更新该字段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict()
|
||||||
|
|
||||||
|
default_query_mode: QueryMode | None = None
|
||||||
|
default_hit_processing: str | None = None
|
||||||
|
caching_disabled: bool | None = None
|
||||||
|
rerank_enabled: bool | None = None
|
||||||
|
rerank_provider: str | None = None
|
||||||
|
rerank_api_key: str | None = None
|
||||||
|
rerank_base_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KBSettingsStore:
|
||||||
|
"""KB 设置存储 — 读写 KB 设置。
|
||||||
|
|
||||||
|
当前为进程内 dict 存储。CRUD 方法为 async 以保持与未来 PG 迁移的接口兼容。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._settings: dict[str, KBSettings] = {}
|
||||||
|
|
||||||
|
async def get_settings(self, kb_id: str) -> KBSettings | None:
|
||||||
|
"""读取 KB 设置。不存在时返回 None。"""
|
||||||
|
return self._settings.get(kb_id)
|
||||||
|
|
||||||
|
async def get_or_create(self, kb_id: str, owner: str | None = None) -> KBSettings:
|
||||||
|
"""读取 KB 设置,不存在时创建默认设置。
|
||||||
|
|
||||||
|
已存在时不覆盖 owner。
|
||||||
|
"""
|
||||||
|
if kb_id not in self._settings:
|
||||||
|
self._settings[kb_id] = KBSettings(kb_id=kb_id, owner=owner)
|
||||||
|
return self._settings[kb_id]
|
||||||
|
|
||||||
|
async def update_settings(
|
||||||
|
self,
|
||||||
|
kb_id: str,
|
||||||
|
update: KBSettingsUpdate,
|
||||||
|
owner: str | None = None,
|
||||||
|
) -> KBSettings:
|
||||||
|
"""更新 KB 设置 — 仅更新非 None 字段。
|
||||||
|
|
||||||
|
首次更新时若提供 owner 则设置所有者。owner 校验由路由层完成。
|
||||||
|
"""
|
||||||
|
existing = self._settings.get(kb_id)
|
||||||
|
if existing is None:
|
||||||
|
existing = KBSettings(kb_id=kb_id, owner=owner)
|
||||||
|
self._settings[kb_id] = existing
|
||||||
|
|
||||||
|
update_data = update.model_dump(exclude_none=True)
|
||||||
|
updated = existing.model_copy(update=update_data)
|
||||||
|
self._settings[kb_id] = updated
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def is_owner(self, kb_id: str, user_id: str | None) -> bool:
|
||||||
|
"""检查用户是否为 KB 所有者。KB 不存在或 owner 未设置时返回 False。"""
|
||||||
|
settings = self._settings.get(kb_id)
|
||||||
|
if settings is None or settings.owner is None:
|
||||||
|
return False
|
||||||
|
return settings.owner == user_id
|
||||||
|
|
||||||
|
def set_owner(self, kb_id: str, owner: str) -> bool:
|
||||||
|
"""设置 KB 所有者。返回 True 如果 KB 存在。"""
|
||||||
|
settings = self._settings.get(kb_id)
|
||||||
|
if settings is None:
|
||||||
|
return False
|
||||||
|
self._settings[kb_id] = settings.model_copy(update={"owner": owner})
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# 模块级单例 — 类似 KnowledgeSourceStore 模式
|
||||||
|
_settings_store = KBSettingsStore()
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings_store() -> KBSettingsStore:
|
||||||
|
"""返回进程级 KBSettingsStore 单例。"""
|
||||||
|
return _settings_store
|
||||||
|
|
||||||
|
|
||||||
|
def set_settings_store(store: KBSettingsStore) -> None:
|
||||||
|
"""注入自定义 KBSettingsStore(测试用)。"""
|
||||||
|
global _settings_store
|
||||||
|
_settings_store = store
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"HIT_PROCESSING_DIRECT",
|
||||||
|
"HIT_PROCESSING_MODEL_OPT",
|
||||||
|
"KBSettings",
|
||||||
|
"KBSettingsStore",
|
||||||
|
"KBSettingsUpdate",
|
||||||
|
"get_settings_store",
|
||||||
|
"set_settings_store",
|
||||||
|
]
|
||||||
|
|
@ -652,3 +652,67 @@ async def check_source_health(source_id: str, req: Request, _auth: None = Depend
|
||||||
# For external sources, try to check connectivity
|
# For external sources, try to check connectivity
|
||||||
# This would use the actual KB adapters in production
|
# This would use the actual KB adapters in production
|
||||||
return {"source_id": source_id, "status": "unknown", "message": "健康检查未实现"}
|
return {"source_id": source_id, "status": "unknown", "message": "健康检查未实现"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB Settings endpoints (U6)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from agentkit.rag_platform.settings import ( # noqa: E402
|
||||||
|
KBSettings,
|
||||||
|
KBSettingsUpdate,
|
||||||
|
get_settings_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/kb-management/kbs/{kb_id}/settings")
|
||||||
|
async def get_kb_settings(
|
||||||
|
kb_id: str,
|
||||||
|
req: Request,
|
||||||
|
_auth: None = Depends(_verify_api_key),
|
||||||
|
_user=Depends(require_permission(Permission.KB_QUERY)),
|
||||||
|
):
|
||||||
|
"""获取 KB 设置 — owner/viewer 可读。
|
||||||
|
|
||||||
|
U6: 返回 KB 级别设置(检索模式默认/命中处理默认/caching/rerank)。
|
||||||
|
设置不存在时返回默认值。
|
||||||
|
"""
|
||||||
|
store = get_settings_store()
|
||||||
|
settings = await store.get_settings(kb_id)
|
||||||
|
if settings is None:
|
||||||
|
# 返回默认设置(不持久化 — 首次 PUT 时创建)
|
||||||
|
settings = KBSettings(kb_id=kb_id)
|
||||||
|
return settings.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/kb-management/kbs/{kb_id}/settings")
|
||||||
|
async def update_kb_settings(
|
||||||
|
kb_id: str,
|
||||||
|
update: KBSettingsUpdate,
|
||||||
|
req: Request,
|
||||||
|
_auth: None = Depends(_verify_api_key),
|
||||||
|
_user=Depends(require_permission(Permission.KB_WRITE)),
|
||||||
|
):
|
||||||
|
"""更新 KB 设置 — 仅 owner(或 admin)可修改。
|
||||||
|
|
||||||
|
U6: per-KB ACL 校验 — 首次更新时设置 owner,后续更新仅 owner/admin 可调用。
|
||||||
|
viewer(有 KB_QUERY 但非 owner)调用返回 403。
|
||||||
|
"""
|
||||||
|
user_id = _user.get("user_id")
|
||||||
|
role = _user.get("role", "")
|
||||||
|
|
||||||
|
store = get_settings_store()
|
||||||
|
existing = await store.get_settings(kb_id)
|
||||||
|
|
||||||
|
# ACL 校验:已有设置且 owner 已设置时,仅 owner/admin 可修改
|
||||||
|
if existing is not None and existing.owner is not None:
|
||||||
|
if existing.owner != user_id and role != "admin":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="仅 KB owner 可修改设置",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 首次创建时设置 owner
|
||||||
|
owner = user_id if (existing is None or existing.owner is None) else None
|
||||||
|
updated = await store.update_settings(kb_id, update, owner=owner)
|
||||||
|
return updated.model_dump()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,319 @@
|
||||||
|
"""U6 测试 — 命中处理(model_opt / direct 模式)。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. model_opt 模式:LLM 基于检索结果生成回答
|
||||||
|
2. direct 模式:直接返回匹配段落
|
||||||
|
3. 空结果返回"未找到相关内容"
|
||||||
|
4. 无 LLM 网关时降级为 direct 模式
|
||||||
|
5. LLM 调用失败时降级为 direct 模式
|
||||||
|
6. 缓存命中返回 cached=True
|
||||||
|
7. KB 默认模式生效
|
||||||
|
8. Agent 运行时覆盖默认模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.rag_platform.hit_processing import (
|
||||||
|
HIT_PROCESSING_DIRECT,
|
||||||
|
HIT_PROCESSING_MODEL_OPT,
|
||||||
|
HitProcessor,
|
||||||
|
)
|
||||||
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
from agentkit.rag_platform.settings import KBSettings
|
||||||
|
|
||||||
|
|
||||||
|
def _make_query_result(
|
||||||
|
chunk_id: str,
|
||||||
|
content: str,
|
||||||
|
score: float = 0.5,
|
||||||
|
document_id: str = "doc1",
|
||||||
|
kb_id: str = "kb1",
|
||||||
|
) -> QueryResult:
|
||||||
|
"""创建测试用 QueryResult。"""
|
||||||
|
return QueryResult(
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
content=content,
|
||||||
|
score=score,
|
||||||
|
metadata={"document_id": document_id, "kb_id": kb_id},
|
||||||
|
document_id=document_id,
|
||||||
|
kb_id=kb_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock:
|
||||||
|
"""创建 mock LLM 网关 — chat() 返回带 content 的 response。"""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = response_content
|
||||||
|
mock.chat = AsyncMock(return_value=mock_response)
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# direct 模式测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDirectMode:
|
||||||
|
"""direct 模式测试。"""
|
||||||
|
|
||||||
|
async def test_direct_returns_concatenated_chunks(self):
|
||||||
|
"""direct 模式返回拼接的匹配段落,按分数降序。"""
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
results = [
|
||||||
|
_make_query_result("c1", "段落一", score=0.9),
|
||||||
|
_make_query_result("c2", "段落二", score=0.7),
|
||||||
|
]
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
|
||||||
|
|
||||||
|
assert result.mode == HIT_PROCESSING_DIRECT
|
||||||
|
assert "段落一" in result.answer
|
||||||
|
assert "段落二" in result.answer
|
||||||
|
assert result.cached is False
|
||||||
|
assert len(result.sources) == 2
|
||||||
|
# 按分数降序
|
||||||
|
assert result.sources[0].chunk_id == "c1"
|
||||||
|
|
||||||
|
async def test_direct_format_includes_index(self):
|
||||||
|
"""direct 模式段落带序号 [1] [2]。"""
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "唯一段落")]
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
|
||||||
|
|
||||||
|
assert "[1]" in result.answer
|
||||||
|
assert "唯一段落" in result.answer
|
||||||
|
|
||||||
|
async def test_direct_empty_results(self):
|
||||||
|
"""空结果返回"未找到相关内容"。"""
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT)
|
||||||
|
|
||||||
|
assert result.answer == "未找到相关内容。"
|
||||||
|
assert result.sources == []
|
||||||
|
assert result.cached is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# model_opt 模式测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelOptMode:
|
||||||
|
"""model_opt 模式测试。"""
|
||||||
|
|
||||||
|
async def test_model_opt_calls_llm(self):
|
||||||
|
"""model_opt 模式调用 LLM 生成回答。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("基于资料的回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "参考资料内容")]
|
||||||
|
result = await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert result.mode == HIT_PROCESSING_MODEL_OPT
|
||||||
|
assert result.answer == "基于资料的回答"
|
||||||
|
assert len(result.sources) == 1
|
||||||
|
mock_llm.chat.assert_awaited_once()
|
||||||
|
|
||||||
|
async def test_model_opt_includes_context_in_prompt(self):
|
||||||
|
"""model_opt 模式将检索结果作为 context 传入 LLM。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "独特的资料文本")]
|
||||||
|
await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
call_args = mock_llm.chat.await_args
|
||||||
|
messages = call_args.kwargs.get("messages") or call_args.args[0]
|
||||||
|
# 检查 context 出现在 prompt 中
|
||||||
|
all_content = " ".join(m["content"] for m in messages)
|
||||||
|
assert "独特的资料文本" in all_content
|
||||||
|
|
||||||
|
async def test_model_opt_passes_query_in_prompt(self):
|
||||||
|
"""model_opt 模式将用户查询传入 LLM prompt。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "资料")]
|
||||||
|
await processor.process("独特查询文本", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
call_args = mock_llm.chat.await_args
|
||||||
|
messages = call_args.kwargs.get("messages") or call_args.args[0]
|
||||||
|
all_content = " ".join(m["content"] for m in messages)
|
||||||
|
assert "独特查询文本" in all_content
|
||||||
|
|
||||||
|
async def test_model_opt_no_llm_falls_back_to_direct(self):
|
||||||
|
"""无 LLM 网关时降级为 direct 模式。"""
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "段落内容")]
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
# 降级为 direct — answer 是拼接的段落
|
||||||
|
assert "段落内容" in result.answer
|
||||||
|
|
||||||
|
async def test_model_opt_llm_failure_falls_back_to_direct(self):
|
||||||
|
"""LLM 调用失败时降级为 direct 模式(不丢失检索结果)。"""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "降级段落")]
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert "降级段落" in result.answer
|
||||||
|
|
||||||
|
async def test_model_opt_empty_results(self):
|
||||||
|
"""空结果返回"未找到相关内容",不调用 LLM。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway()
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
result = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert result.answer == "未找到相关内容。"
|
||||||
|
mock_llm.chat.assert_not_awaited()
|
||||||
|
|
||||||
|
async def test_model_opt_sorts_by_score(self):
|
||||||
|
"""model_opt 模式按分数降序构建 context。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [
|
||||||
|
_make_query_result("c2", "低分段落", score=0.3),
|
||||||
|
_make_query_result("c1", "高分段落", score=0.9),
|
||||||
|
]
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert result.sources[0].chunk_id == "c1"
|
||||||
|
assert result.sources[1].chunk_id == "c2"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 缓存测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCaching:
|
||||||
|
"""缓存测试。"""
|
||||||
|
|
||||||
|
async def test_cache_hit_returns_cached_true(self):
|
||||||
|
"""相同输入第二次返回 cached=True。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert first.cached is False
|
||||||
|
assert second.cached is True
|
||||||
|
# LLM 只调用一次(第二次命中缓存)
|
||||||
|
assert mock_llm.chat.await_count == 1
|
||||||
|
|
||||||
|
async def test_cache_disabled_no_caching(self):
|
||||||
|
"""禁用缓存时每次都调用 LLM。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert mock_llm.chat.await_count == 2
|
||||||
|
|
||||||
|
async def test_cache_key_includes_query(self):
|
||||||
|
"""不同查询不共享缓存。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert mock_llm.chat.await_count == 2
|
||||||
|
|
||||||
|
async def test_cache_key_includes_chunk_ids(self):
|
||||||
|
"""不同检索结果不共享缓存。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
|
||||||
|
|
||||||
|
await processor.process(
|
||||||
|
"query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT
|
||||||
|
)
|
||||||
|
await processor.process(
|
||||||
|
"query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_llm.chat.await_count == 2
|
||||||
|
|
||||||
|
async def test_cache_key_includes_mode(self):
|
||||||
|
"""不同模式不共享缓存。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
|
||||||
|
|
||||||
|
# direct 模式不调用 LLM,但也不应命中 model_opt 的缓存
|
||||||
|
assert mock_llm.chat.await_count == 1
|
||||||
|
|
||||||
|
async def test_cache_not_used_for_empty_results(self):
|
||||||
|
"""空结果不写入缓存。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway()
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
|
||||||
|
|
||||||
|
first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
|
||||||
|
assert first.cached is False
|
||||||
|
assert second.cached is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB 默认模式 + 运行时覆盖测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBDefaultMode:
|
||||||
|
"""KB 默认模式生效测试。"""
|
||||||
|
|
||||||
|
async def test_kb_default_model_opt(self):
|
||||||
|
"""KB 默认 model_opt 模式生效。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
settings = KBSettings(kb_id="kb1") # 默认 model_opt
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
result = await processor.process("query", results, mode=settings.default_hit_processing)
|
||||||
|
assert result.mode == HIT_PROCESSING_MODEL_OPT
|
||||||
|
mock_llm.chat.assert_awaited_once()
|
||||||
|
|
||||||
|
async def test_kb_default_direct(self):
|
||||||
|
"""KB 默认 direct 模式生效。"""
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
|
||||||
|
results = [_make_query_result("c1", "内容")]
|
||||||
|
|
||||||
|
result = await processor.process("query", results, mode=settings.default_hit_processing)
|
||||||
|
assert result.mode == HIT_PROCESSING_DIRECT
|
||||||
|
|
||||||
|
|
||||||
|
class TestModeOverride:
|
||||||
|
"""Agent 运行时覆盖默认模式测试。"""
|
||||||
|
|
||||||
|
async def test_override_to_direct(self):
|
||||||
|
"""运行时覆盖为 direct 模式 — 不调用 LLM。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("LLM 回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "段落")]
|
||||||
|
|
||||||
|
# KB 默认 model_opt,运行时覆盖为 direct
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
|
||||||
|
assert result.mode == HIT_PROCESSING_DIRECT
|
||||||
|
mock_llm.chat.assert_not_awaited()
|
||||||
|
|
||||||
|
async def test_override_to_model_opt(self):
|
||||||
|
"""运行时覆盖为 model_opt 模式 — 调用 LLM。"""
|
||||||
|
mock_llm = _make_mock_llm_gateway("LLM 回答")
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [_make_query_result("c1", "段落")]
|
||||||
|
|
||||||
|
# KB 默认 direct,运行时覆盖为 model_opt
|
||||||
|
result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
|
||||||
|
assert result.mode == HIT_PROCESSING_MODEL_OPT
|
||||||
|
mock_llm.chat.assert_awaited_once()
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
"""U6 测试 — KB 设置模型与存储。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. KBSettings 默认值
|
||||||
|
2. KBSettingsUpdate 部分更新语义
|
||||||
|
3. KBSettingsStore CRUD(get / get_or_create / update)
|
||||||
|
4. owner 校验(is_owner / set_owner)
|
||||||
|
5. KB 设置默认模式生效(与 HitProcessor 集成)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from agentkit.rag_platform.hit_processing import (
|
||||||
|
HIT_PROCESSING_DIRECT,
|
||||||
|
HIT_PROCESSING_MODEL_OPT,
|
||||||
|
)
|
||||||
|
from agentkit.rag_platform.models import QueryMode
|
||||||
|
from agentkit.rag_platform.settings import (
|
||||||
|
KBSettings,
|
||||||
|
KBSettingsStore,
|
||||||
|
KBSettingsUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KBSettings 模型测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBSettings:
|
||||||
|
"""KBSettings 模型测试。"""
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
"""默认值正确。"""
|
||||||
|
settings = KBSettings(kb_id="kb1")
|
||||||
|
assert settings.kb_id == "kb1"
|
||||||
|
assert settings.owner is None
|
||||||
|
assert settings.default_query_mode == QueryMode.blend
|
||||||
|
assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
|
||||||
|
assert settings.caching_disabled is False
|
||||||
|
assert settings.rerank_enabled is True
|
||||||
|
assert settings.rerank_provider == "cohere"
|
||||||
|
assert settings.rerank_api_key is None
|
||||||
|
assert settings.rerank_base_url is None
|
||||||
|
assert settings.data_export_warning is False
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""自定义值正确。"""
|
||||||
|
settings = KBSettings(
|
||||||
|
kb_id="kb1",
|
||||||
|
owner="user1",
|
||||||
|
default_query_mode=QueryMode.keywords,
|
||||||
|
default_hit_processing=HIT_PROCESSING_DIRECT,
|
||||||
|
caching_disabled=True,
|
||||||
|
rerank_enabled=False,
|
||||||
|
rerank_provider="bge",
|
||||||
|
rerank_api_key="key-123",
|
||||||
|
rerank_base_url="http://localhost:9997",
|
||||||
|
data_export_warning=True,
|
||||||
|
)
|
||||||
|
assert settings.owner == "user1"
|
||||||
|
assert settings.default_query_mode == QueryMode.keywords
|
||||||
|
assert settings.default_hit_processing == HIT_PROCESSING_DIRECT
|
||||||
|
assert settings.caching_disabled is True
|
||||||
|
assert settings.rerank_enabled is False
|
||||||
|
assert settings.rerank_provider == "bge"
|
||||||
|
assert settings.rerank_api_key == "key-123"
|
||||||
|
assert settings.rerank_base_url == "http://localhost:9997"
|
||||||
|
assert settings.data_export_warning is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KBSettingsUpdate 模型测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBSettingsUpdate:
|
||||||
|
"""KBSettingsUpdate 模型测试。"""
|
||||||
|
|
||||||
|
def test_all_none_by_default(self):
|
||||||
|
"""默认所有字段为 None(部分更新语义)。"""
|
||||||
|
update = KBSettingsUpdate()
|
||||||
|
assert update.default_query_mode is None
|
||||||
|
assert update.default_hit_processing is None
|
||||||
|
assert update.caching_disabled is None
|
||||||
|
assert update.rerank_enabled is None
|
||||||
|
assert update.rerank_provider is None
|
||||||
|
assert update.rerank_api_key is None
|
||||||
|
assert update.rerank_base_url is None
|
||||||
|
|
||||||
|
def test_partial_update(self):
|
||||||
|
"""部分字段赋值。"""
|
||||||
|
update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
|
||||||
|
assert update.default_hit_processing == HIT_PROCESSING_DIRECT
|
||||||
|
assert update.default_query_mode is None
|
||||||
|
|
||||||
|
def test_exclude_none_dumps(self):
|
||||||
|
"""model_dump(exclude_none=True) 仅包含已设置字段。"""
|
||||||
|
update = KBSettingsUpdate(caching_disabled=True)
|
||||||
|
data = update.model_dump(exclude_none=True)
|
||||||
|
assert data == {"caching_disabled": True}
|
||||||
|
|
||||||
|
def test_full_update(self):
|
||||||
|
"""所有字段同时赋值。"""
|
||||||
|
update = KBSettingsUpdate(
|
||||||
|
default_query_mode=QueryMode.embedding,
|
||||||
|
default_hit_processing=HIT_PROCESSING_DIRECT,
|
||||||
|
caching_disabled=True,
|
||||||
|
rerank_enabled=False,
|
||||||
|
rerank_provider="none",
|
||||||
|
rerank_api_key="key",
|
||||||
|
rerank_base_url="http://x",
|
||||||
|
)
|
||||||
|
data = update.model_dump(exclude_none=True)
|
||||||
|
assert len(data) == 7
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KBSettingsStore CRUD 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBSettingsStoreCRUD:
|
||||||
|
"""KBSettingsStore 存储测试。"""
|
||||||
|
|
||||||
|
async def test_get_settings_not_exist(self):
|
||||||
|
"""不存在的 KB 返回 None。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
assert await store.get_settings("kb1") is None
|
||||||
|
|
||||||
|
async def test_get_or_create_creates_defaults(self):
|
||||||
|
"""get_or_create 创建默认设置。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
settings = await store.get_or_create("kb1", owner="user1")
|
||||||
|
assert settings.kb_id == "kb1"
|
||||||
|
assert settings.owner == "user1"
|
||||||
|
assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
|
||||||
|
|
||||||
|
async def test_get_or_create_idempotent(self):
|
||||||
|
"""get_or_create 幂等 — 已存在时不覆盖 owner。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
second = await store.get_or_create("kb1", owner="user2")
|
||||||
|
assert second.owner == "user1" # 不覆盖已有 owner
|
||||||
|
|
||||||
|
async def test_update_settings_creates_if_not_exist(self):
|
||||||
|
"""update_settings 在设置不存在时创建。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
update = KBSettingsUpdate(caching_disabled=True)
|
||||||
|
result = await store.update_settings("kb1", update, owner="user1")
|
||||||
|
|
||||||
|
assert result.caching_disabled is True
|
||||||
|
assert result.owner == "user1"
|
||||||
|
assert result.kb_id == "kb1"
|
||||||
|
|
||||||
|
async def test_update_settings_partial(self):
|
||||||
|
"""update_settings 仅更新提供的字段,其他字段保持不变。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
# 先创建默认设置
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
# 部分更新
|
||||||
|
update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
|
||||||
|
result = await store.update_settings("kb1", update)
|
||||||
|
|
||||||
|
assert result.default_hit_processing == HIT_PROCESSING_DIRECT
|
||||||
|
# 其他字段保持默认
|
||||||
|
assert result.default_query_mode == QueryMode.blend
|
||||||
|
assert result.caching_disabled is False
|
||||||
|
assert result.owner == "user1"
|
||||||
|
|
||||||
|
async def test_update_settings_multiple_fields(self):
|
||||||
|
"""update_settings 同时更新多个字段。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
update = KBSettingsUpdate(
|
||||||
|
caching_disabled=True,
|
||||||
|
rerank_provider="bge",
|
||||||
|
rerank_enabled=False,
|
||||||
|
)
|
||||||
|
result = await store.update_settings("kb1", update)
|
||||||
|
|
||||||
|
assert result.caching_disabled is True
|
||||||
|
assert result.rerank_provider == "bge"
|
||||||
|
assert result.rerank_enabled is False
|
||||||
|
|
||||||
|
async def test_update_settings_persists(self):
|
||||||
|
"""update_settings 后 get_settings 返回更新后的值。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.update_settings("kb1", KBSettingsUpdate(caching_disabled=True), owner="user1")
|
||||||
|
retrieved = await store.get_settings("kb1")
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.caching_disabled is True
|
||||||
|
|
||||||
|
async def test_update_settings_none_field_ignored(self):
|
||||||
|
"""update_settings 中 None 字段被忽略。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.update_settings(
|
||||||
|
"kb1",
|
||||||
|
KBSettingsUpdate(caching_disabled=True, default_hit_processing=HIT_PROCESSING_DIRECT),
|
||||||
|
owner="user1",
|
||||||
|
)
|
||||||
|
# 第二次更新 — 只改 caching_disabled,default_hit_processing 应保持
|
||||||
|
await store.update_settings(
|
||||||
|
"kb1",
|
||||||
|
KBSettingsUpdate(caching_disabled=False),
|
||||||
|
)
|
||||||
|
result = await store.get_settings("kb1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.caching_disabled is False
|
||||||
|
assert result.default_hit_processing == HIT_PROCESSING_DIRECT
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# owner 校验测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestOwnerCheck:
|
||||||
|
"""owner 校验测试。"""
|
||||||
|
|
||||||
|
async def test_is_owner_true(self):
|
||||||
|
"""owner 匹配返回 True。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
assert store.is_owner("kb1", "user1") is True
|
||||||
|
|
||||||
|
async def test_is_owner_false_wrong_user(self):
|
||||||
|
"""非 owner 用户返回 False。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
assert store.is_owner("kb1", "user2") is False
|
||||||
|
|
||||||
|
async def test_is_owner_false_none_user(self):
|
||||||
|
"""None 用户返回 False。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
assert store.is_owner("kb1", None) is False
|
||||||
|
|
||||||
|
async def test_is_owner_false_nonexistent_kb(self):
|
||||||
|
"""不存在的 KB 返回 False。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
assert store.is_owner("nonexistent", "user1") is False
|
||||||
|
|
||||||
|
async def test_is_owner_false_no_owner_set(self):
|
||||||
|
"""owner 未设置时返回 False。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1") # 不传 owner
|
||||||
|
assert store.is_owner("kb1", "user1") is False
|
||||||
|
|
||||||
|
async def test_set_owner(self):
|
||||||
|
"""set_owner 设置所有者。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.get_or_create("kb1", owner="user1")
|
||||||
|
assert store.set_owner("kb1", "user2") is True
|
||||||
|
assert store.is_owner("kb1", "user2") is True
|
||||||
|
assert store.is_owner("kb1", "user1") is False
|
||||||
|
|
||||||
|
async def test_set_owner_nonexistent(self):
|
||||||
|
"""set_owner 对不存在的 KB 返回 False。"""
|
||||||
|
store = KBSettingsStore()
|
||||||
|
assert store.set_owner("nonexistent", "user1") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KB 默认模式生效测试(与 HitProcessor 集成)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBSettingsDefaultModeIntegration:
|
||||||
|
"""KB 设置默认模式与 HitProcessor 集成测试。"""
|
||||||
|
|
||||||
|
async def test_default_model_opt_flows_to_processor(self):
|
||||||
|
"""KB 默认 model_opt 模式传递给 HitProcessor。"""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.rag_platform.hit_processing import HitProcessor
|
||||||
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.content = "LLM 回答"
|
||||||
|
mock_llm.chat = AsyncMock(return_value=mock_resp)
|
||||||
|
|
||||||
|
store = KBSettingsStore()
|
||||||
|
settings = await store.get_or_create("kb1", owner="user1")
|
||||||
|
|
||||||
|
processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
|
||||||
|
results = [
|
||||||
|
QueryResult(
|
||||||
|
chunk_id="c1",
|
||||||
|
content="内容",
|
||||||
|
score=0.9,
|
||||||
|
metadata={},
|
||||||
|
document_id="d1",
|
||||||
|
kb_id="kb1",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = await processor.process("query", results, mode=settings.default_hit_processing)
|
||||||
|
assert result.mode == HIT_PROCESSING_MODEL_OPT
|
||||||
|
mock_llm.chat.assert_awaited_once()
|
||||||
|
|
||||||
|
async def test_default_direct_flows_to_processor(self):
|
||||||
|
"""KB 默认 direct 模式传递给 HitProcessor。"""
|
||||||
|
from agentkit.rag_platform.hit_processing import HitProcessor
|
||||||
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
|
||||||
|
store = KBSettingsStore()
|
||||||
|
await store.update_settings(
|
||||||
|
"kb1",
|
||||||
|
KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT),
|
||||||
|
owner="user1",
|
||||||
|
)
|
||||||
|
settings = await store.get_settings("kb1")
|
||||||
|
assert settings is not None
|
||||||
|
|
||||||
|
processor = HitProcessor(llm_gateway=None, cache_enabled=False)
|
||||||
|
results = [
|
||||||
|
QueryResult(
|
||||||
|
chunk_id="c1",
|
||||||
|
content="直接段落",
|
||||||
|
score=0.9,
|
||||||
|
metadata={},
|
||||||
|
document_id="d1",
|
||||||
|
kb_id="kb1",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = await processor.process("query", results, mode=settings.default_hit_processing)
|
||||||
|
assert result.mode == HIT_PROCESSING_DIRECT
|
||||||
|
assert "直接段落" in result.answer
|
||||||
Loading…
Reference in New Issue