From d026a91f438c368743c00e0639ce6deeaf0c7f15 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 12:44:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(rag=5Fplatform):=20U6=20=E2=80=94=20hit=20?= =?UTF-8?q?processing=20mode=20+=20KB=20settings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add hit_processing.py: HitProcessor with model_opt (LLM-generated) and direct (concatenated chunks) modes, with in-process cache Add settings.py: KBSettings/KBSettingsUpdate models, KBSettingsStore with async CRUD Add KB settings endpoints to kb_management.py: GET/PUT /kb-management/kbs/{kb_id}/settings with owner-only modification Tests: 43 new tests (25 hit_processing + 18 settings), 293 total passing --- src/agentkit/rag_platform/hit_processing.py | 181 ++++++++++ src/agentkit/rag_platform/settings.py | 144 ++++++++ src/agentkit/server/routes/kb_management.py | 64 ++++ .../unit/rag_platform/test_hit_processing.py | 319 +++++++++++++++++ tests/unit/rag_platform/test_settings.py | 329 ++++++++++++++++++ 5 files changed, 1037 insertions(+) create mode 100644 src/agentkit/rag_platform/hit_processing.py create mode 100644 src/agentkit/rag_platform/settings.py create mode 100644 tests/unit/rag_platform/test_hit_processing.py create mode 100644 tests/unit/rag_platform/test_settings.py diff --git a/src/agentkit/rag_platform/hit_processing.py b/src/agentkit/rag_platform/hit_processing.py new file mode 100644 index 0000000..1505a01 --- /dev/null +++ b/src/agentkit/rag_platform/hit_processing.py @@ -0,0 +1,181 @@ +"""命中处理 — 模型优化模式(LLM 生成回答)vs 直接回答模式(返回匹配段落)。 + +- model_opt: 调用 LLM Gateway,基于检索结果生成自然语言回答 +- direct: 直接拼接匹配段落返回(零 LLM 调用,零额外延迟) + +模式选择由 KB 设置 ``default_hit_processing`` 决定,Agent 运行时可通过 +``RetrievalRequest.hit_processing_mode`` 覆盖。 +""" + +from __future__ import annotations + +import hashlib +import logging +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from agentkit.rag_platform.models import QueryResult + +logger = logging.getLogger(__name__) + +HIT_PROCESSING_MODEL_OPT = "model_opt" +HIT_PROCESSING_DIRECT = "direct" + +# ponytail: 硬编码默认模型 — 升级路径为从 KB 设置 / agentkit.yaml 读取 +_DEFAULT_LLM_MODEL = "gpt-4o-mini" + +# RAG prompt 模板 — 指示 LLM 基于参考资料回答 +_RAG_PROMPT_TEMPLATE = ( + "你是一个知识库问答助手。请根据以下检索到的参考资料回答用户问题。\n\n" + "参考资料:\n{context}\n\n" + "用户问题:{query}\n\n" + "请基于参考资料回答,如果资料中没有相关信息请说明。" +) + + +class HitProcessingResult(BaseModel): + """命中处理结果。""" + + model_config = ConfigDict() + + mode: str # "model_opt" | "direct" + answer: str # LLM 生成的回答或直接拼接的段落 + sources: list[QueryResult] # 检索到的来源 + cached: bool = False # 是否命中缓存 + + +class HitProcessor: + """命中处理器 — 根据 mode 处理检索结果。 + + Args: + llm_gateway: LLM 网关实例(model_opt 模式使用),需支持 ``async chat(messages, model)`` 接口 + cache_enabled: 是否启用命中处理缓存(默认 True) + model: LLM 模型名(默认 gpt-4o-mini) + """ + + def __init__( + self, + llm_gateway: Any = None, + cache_enabled: bool = True, + model: str = _DEFAULT_LLM_MODEL, + ) -> None: + self._llm = llm_gateway + self._cache_enabled = cache_enabled + self._model = model + # ponytail: 进程内 dict 缓存 — 无 TTL,无容量上限 + # 升级路径:迁移到 Redis,带 TTL + LRU 淘汰 + self._cache: dict[str, HitProcessingResult] = {} + + async def process( + self, + query: str, + results: list[QueryResult], + mode: str = HIT_PROCESSING_MODEL_OPT, + ) -> HitProcessingResult: + """处理检索结果。 + + Args: + query: 用户查询文本 + results: 检索结果列表 + mode: 命中处理模式 — "model_opt"(LLM 生成)或 "direct"(直接返回段落) + + Returns: + HitProcessingResult 包含回答、来源和缓存标记 + """ + if not results: + return HitProcessingResult( + mode=mode, answer="未找到相关内容。", sources=[], cached=False + ) + + # 缓存检查 + cache_key: str | None = None + if self._cache_enabled: + cache_key = self._cache_key(query, results, mode) + cached = self._cache.get(cache_key) + if cached is not None: + return cached.model_copy(update={"cached": True}) + + if mode == HIT_PROCESSING_DIRECT: + result = self._process_direct(results) + else: + result = await self._process_model_opt(query, results) + + # 写入缓存 + if cache_key is not None: + self._cache[cache_key] = result + + return result + + def _process_direct(self, results: list[QueryResult]) -> HitProcessingResult: + """直接回答模式 — 拼接匹配段落,按分数降序排列。""" + sorted_results = sorted(results, key=lambda r: r.score, reverse=True) + parts: list[str] = [] + for i, r in enumerate(sorted_results, start=1): + parts.append(f"[{i}] {r.content}") + answer = "\n\n".join(parts) + return HitProcessingResult( + mode=HIT_PROCESSING_DIRECT, + answer=answer, + sources=sorted_results, + cached=False, + ) + + async def _process_model_opt( + self, + query: str, + results: list[QueryResult], + ) -> HitProcessingResult: + """模型优化模式 — LLM 基于检索结果生成回答。 + + 无 LLM 网关或调用失败时降级为 direct 模式(不丢失检索结果)。 + """ + if self._llm is None: + logger.warning("No LLM gateway configured, falling back to direct mode") + return self._process_direct(results) + + # 构建 context — 按分数降序 + sorted_results = sorted(results, key=lambda r: r.score, reverse=True) + context_parts: list[str] = [] + for i, r in enumerate(sorted_results, start=1): + context_parts.append(f"[{i}] {r.content}") + context = "\n\n".join(context_parts) + + prompt = _RAG_PROMPT_TEMPLATE.format(context=context, query=query) + messages = [ + {"role": "system", "content": "你是一个知识库问答助手,请基于参考资料回答问题。"}, + {"role": "user", "content": prompt}, + ] + + try: + response = await self._llm.chat(messages=messages, model=self._model) + answer = response.content + except Exception as e: + logger.warning("LLM call failed, falling back to direct mode: %s", e) + return self._process_direct(results) + + return HitProcessingResult( + mode=HIT_PROCESSING_MODEL_OPT, + answer=answer, + sources=sorted_results, + cached=False, + ) + + @staticmethod + def _cache_key( + query: str, + results: list[QueryResult], + mode: str, + ) -> str: + """生成缓存键 — 基于 mode + query + chunk_ids 的 SHA256 哈希。""" + chunk_ids = ",".join(sorted(r.chunk_id for r in results)) + raw = f"{mode}:{query}:{chunk_ids}" + return hashlib.sha256(raw.encode()).hexdigest() + + +__all__ = [ + "HIT_PROCESSING_DIRECT", + "HIT_PROCESSING_MODEL_OPT", + "HitProcessingResult", + "HitProcessor", +] diff --git a/src/agentkit/rag_platform/settings.py b/src/agentkit/rag_platform/settings.py new file mode 100644 index 0000000..039dce1 --- /dev/null +++ b/src/agentkit/rag_platform/settings.py @@ -0,0 +1,144 @@ +"""KB 设置模型 — 检索模式默认/命中处理默认/授权用户/caching/rerank。 + +KB 级别设置控制检索与命中处理的默认行为,Agent 运行时可通过 +``RetrievalRequest`` 字段覆盖。 + +当前为进程内 dict 存储(重启丢失)。 +ponytail: 升级路径 — 迁移到 PG(KBModel 已有 default_query_mode / +default_hit_processing / caching_disabled 列),rerank 相关设置需新增 +JSON 列或独立表。 +""" + +from __future__ import annotations + +import logging + +from pydantic import BaseModel, ConfigDict + +from agentkit.rag_platform.models import QueryMode + +logger = logging.getLogger(__name__) + +HIT_PROCESSING_MODEL_OPT = "model_opt" +HIT_PROCESSING_DIRECT = "direct" + + +class KBSettings(BaseModel): + """KB 级别设置 — 可通过 API 修改。 + + owner 字段用于 ACL 校验:仅 owner(或 admin)可修改设置,viewer 只读。 + """ + + model_config = ConfigDict() + + kb_id: str + owner: str | None = None # KB 所有者(仅 owner 可修改设置) + default_query_mode: QueryMode = QueryMode.blend + default_hit_processing: str = HIT_PROCESSING_MODEL_OPT # model_opt | direct + caching_disabled: bool = False + rerank_enabled: bool = True + rerank_provider: str = "cohere" # cohere | bge | none + rerank_api_key: str | None = None + rerank_base_url: str | None = None + data_export_warning: bool = False # True = 使用云端 rerank 处理敏感数据 + + +class KBSettingsUpdate(BaseModel): + """KB 设置更新请求 — 仅 owner 可修改。 + + 所有字段可选(部分更新语义),None 表示不更新该字段。 + """ + + model_config = ConfigDict() + + default_query_mode: QueryMode | None = None + default_hit_processing: str | None = None + caching_disabled: bool | None = None + rerank_enabled: bool | None = None + rerank_provider: str | None = None + rerank_api_key: str | None = None + rerank_base_url: str | None = None + + +class KBSettingsStore: + """KB 设置存储 — 读写 KB 设置。 + + 当前为进程内 dict 存储。CRUD 方法为 async 以保持与未来 PG 迁移的接口兼容。 + """ + + def __init__(self) -> None: + self._settings: dict[str, KBSettings] = {} + + async def get_settings(self, kb_id: str) -> KBSettings | None: + """读取 KB 设置。不存在时返回 None。""" + return self._settings.get(kb_id) + + async def get_or_create(self, kb_id: str, owner: str | None = None) -> KBSettings: + """读取 KB 设置,不存在时创建默认设置。 + + 已存在时不覆盖 owner。 + """ + if kb_id not in self._settings: + self._settings[kb_id] = KBSettings(kb_id=kb_id, owner=owner) + return self._settings[kb_id] + + async def update_settings( + self, + kb_id: str, + update: KBSettingsUpdate, + owner: str | None = None, + ) -> KBSettings: + """更新 KB 设置 — 仅更新非 None 字段。 + + 首次更新时若提供 owner 则设置所有者。owner 校验由路由层完成。 + """ + existing = self._settings.get(kb_id) + if existing is None: + existing = KBSettings(kb_id=kb_id, owner=owner) + self._settings[kb_id] = existing + + update_data = update.model_dump(exclude_none=True) + updated = existing.model_copy(update=update_data) + self._settings[kb_id] = updated + return updated + + def is_owner(self, kb_id: str, user_id: str | None) -> bool: + """检查用户是否为 KB 所有者。KB 不存在或 owner 未设置时返回 False。""" + settings = self._settings.get(kb_id) + if settings is None or settings.owner is None: + return False + return settings.owner == user_id + + def set_owner(self, kb_id: str, owner: str) -> bool: + """设置 KB 所有者。返回 True 如果 KB 存在。""" + settings = self._settings.get(kb_id) + if settings is None: + return False + self._settings[kb_id] = settings.model_copy(update={"owner": owner}) + return True + + +# 模块级单例 — 类似 KnowledgeSourceStore 模式 +_settings_store = KBSettingsStore() + + +def get_settings_store() -> KBSettingsStore: + """返回进程级 KBSettingsStore 单例。""" + return _settings_store + + +def set_settings_store(store: KBSettingsStore) -> None: + """注入自定义 KBSettingsStore(测试用)。""" + global _settings_store + _settings_store = store + + +__all__ = [ + "HIT_PROCESSING_DIRECT", + "HIT_PROCESSING_MODEL_OPT", + "KBSettings", + "KBSettingsStore", + "KBSettingsUpdate", + "get_settings_store", + "set_settings_store", +] diff --git a/src/agentkit/server/routes/kb_management.py b/src/agentkit/server/routes/kb_management.py index 0a66be2..1bc570d 100644 --- a/src/agentkit/server/routes/kb_management.py +++ b/src/agentkit/server/routes/kb_management.py @@ -652,3 +652,67 @@ async def check_source_health(source_id: str, req: Request, _auth: None = Depend # For external sources, try to check connectivity # This would use the actual KB adapters in production return {"source_id": source_id, "status": "unknown", "message": "健康检查未实现"} + + +# --------------------------------------------------------------------------- +# KB Settings endpoints (U6) +# --------------------------------------------------------------------------- + +from agentkit.rag_platform.settings import ( # noqa: E402 + KBSettings, + KBSettingsUpdate, + get_settings_store, +) + + +@router.get("/kb-management/kbs/{kb_id}/settings") +async def get_kb_settings( + kb_id: str, + req: Request, + _auth: None = Depends(_verify_api_key), + _user=Depends(require_permission(Permission.KB_QUERY)), +): + """获取 KB 设置 — owner/viewer 可读。 + + U6: 返回 KB 级别设置(检索模式默认/命中处理默认/caching/rerank)。 + 设置不存在时返回默认值。 + """ + store = get_settings_store() + settings = await store.get_settings(kb_id) + if settings is None: + # 返回默认设置(不持久化 — 首次 PUT 时创建) + settings = KBSettings(kb_id=kb_id) + return settings.model_dump() + + +@router.put("/kb-management/kbs/{kb_id}/settings") +async def update_kb_settings( + kb_id: str, + update: KBSettingsUpdate, + req: Request, + _auth: None = Depends(_verify_api_key), + _user=Depends(require_permission(Permission.KB_WRITE)), +): + """更新 KB 设置 — 仅 owner(或 admin)可修改。 + + U6: per-KB ACL 校验 — 首次更新时设置 owner,后续更新仅 owner/admin 可调用。 + viewer(有 KB_QUERY 但非 owner)调用返回 403。 + """ + user_id = _user.get("user_id") + role = _user.get("role", "") + + store = get_settings_store() + existing = await store.get_settings(kb_id) + + # ACL 校验:已有设置且 owner 已设置时,仅 owner/admin 可修改 + if existing is not None and existing.owner is not None: + if existing.owner != user_id and role != "admin": + raise HTTPException( + status_code=403, + detail="仅 KB owner 可修改设置", + ) + + # 首次创建时设置 owner + owner = user_id if (existing is None or existing.owner is None) else None + updated = await store.update_settings(kb_id, update, owner=owner) + return updated.model_dump() diff --git a/tests/unit/rag_platform/test_hit_processing.py b/tests/unit/rag_platform/test_hit_processing.py new file mode 100644 index 0000000..7e6e446 --- /dev/null +++ b/tests/unit/rag_platform/test_hit_processing.py @@ -0,0 +1,319 @@ +"""U6 测试 — 命中处理(model_opt / direct 模式)。 + +测试场景: +1. model_opt 模式:LLM 基于检索结果生成回答 +2. direct 模式:直接返回匹配段落 +3. 空结果返回"未找到相关内容" +4. 无 LLM 网关时降级为 direct 模式 +5. LLM 调用失败时降级为 direct 模式 +6. 缓存命中返回 cached=True +7. KB 默认模式生效 +8. Agent 运行时覆盖默认模式 +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +from agentkit.rag_platform.hit_processing import ( + HIT_PROCESSING_DIRECT, + HIT_PROCESSING_MODEL_OPT, + HitProcessor, +) +from agentkit.rag_platform.models import QueryResult +from agentkit.rag_platform.settings import KBSettings + + +def _make_query_result( + chunk_id: str, + content: str, + score: float = 0.5, + document_id: str = "doc1", + kb_id: str = "kb1", +) -> QueryResult: + """创建测试用 QueryResult。""" + return QueryResult( + chunk_id=chunk_id, + content=content, + score=score, + metadata={"document_id": document_id, "kb_id": kb_id}, + document_id=document_id, + kb_id=kb_id, + ) + + +def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock: + """创建 mock LLM 网关 — chat() 返回带 content 的 response。""" + mock = MagicMock() + mock_response = MagicMock() + mock_response.content = response_content + mock.chat = AsyncMock(return_value=mock_response) + return mock + + +# --------------------------------------------------------------------------- +# direct 模式测试 +# --------------------------------------------------------------------------- + + +class TestDirectMode: + """direct 模式测试。""" + + async def test_direct_returns_concatenated_chunks(self): + """direct 模式返回拼接的匹配段落,按分数降序。""" + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + results = [ + _make_query_result("c1", "段落一", score=0.9), + _make_query_result("c2", "段落二", score=0.7), + ] + result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT) + + assert result.mode == HIT_PROCESSING_DIRECT + assert "段落一" in result.answer + assert "段落二" in result.answer + assert result.cached is False + assert len(result.sources) == 2 + # 按分数降序 + assert result.sources[0].chunk_id == "c1" + + async def test_direct_format_includes_index(self): + """direct 模式段落带序号 [1] [2]。""" + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + results = [_make_query_result("c1", "唯一段落")] + result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT) + + assert "[1]" in result.answer + assert "唯一段落" in result.answer + + async def test_direct_empty_results(self): + """空结果返回"未找到相关内容"。""" + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT) + + assert result.answer == "未找到相关内容。" + assert result.sources == [] + assert result.cached is False + + +# --------------------------------------------------------------------------- +# model_opt 模式测试 +# --------------------------------------------------------------------------- + + +class TestModelOptMode: + """model_opt 模式测试。""" + + async def test_model_opt_calls_llm(self): + """model_opt 模式调用 LLM 生成回答。""" + mock_llm = _make_mock_llm_gateway("基于资料的回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "参考资料内容")] + result = await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert result.mode == HIT_PROCESSING_MODEL_OPT + assert result.answer == "基于资料的回答" + assert len(result.sources) == 1 + mock_llm.chat.assert_awaited_once() + + async def test_model_opt_includes_context_in_prompt(self): + """model_opt 模式将检索结果作为 context 传入 LLM。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "独特的资料文本")] + await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT) + + call_args = mock_llm.chat.await_args + messages = call_args.kwargs.get("messages") or call_args.args[0] + # 检查 context 出现在 prompt 中 + all_content = " ".join(m["content"] for m in messages) + assert "独特的资料文本" in all_content + + async def test_model_opt_passes_query_in_prompt(self): + """model_opt 模式将用户查询传入 LLM prompt。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "资料")] + await processor.process("独特查询文本", results, mode=HIT_PROCESSING_MODEL_OPT) + + call_args = mock_llm.chat.await_args + messages = call_args.kwargs.get("messages") or call_args.args[0] + all_content = " ".join(m["content"] for m in messages) + assert "独特查询文本" in all_content + + async def test_model_opt_no_llm_falls_back_to_direct(self): + """无 LLM 网关时降级为 direct 模式。""" + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + results = [_make_query_result("c1", "段落内容")] + result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + + # 降级为 direct — answer 是拼接的段落 + assert "段落内容" in result.answer + + async def test_model_opt_llm_failure_falls_back_to_direct(self): + """LLM 调用失败时降级为 direct 模式(不丢失检索结果)。""" + mock_llm = MagicMock() + mock_llm.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用")) + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "降级段落")] + result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert "降级段落" in result.answer + + async def test_model_opt_empty_results(self): + """空结果返回"未找到相关内容",不调用 LLM。""" + mock_llm = _make_mock_llm_gateway() + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + result = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT) + + assert result.answer == "未找到相关内容。" + mock_llm.chat.assert_not_awaited() + + async def test_model_opt_sorts_by_score(self): + """model_opt 模式按分数降序构建 context。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [ + _make_query_result("c2", "低分段落", score=0.3), + _make_query_result("c1", "高分段落", score=0.9), + ] + result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert result.sources[0].chunk_id == "c1" + assert result.sources[1].chunk_id == "c2" + + +# --------------------------------------------------------------------------- +# 缓存测试 +# --------------------------------------------------------------------------- + + +class TestCaching: + """缓存测试。""" + + async def test_cache_hit_returns_cached_true(self): + """相同输入第二次返回 cached=True。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True) + results = [_make_query_result("c1", "内容")] + + first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert first.cached is False + assert second.cached is True + # LLM 只调用一次(第二次命中缓存) + assert mock_llm.chat.await_count == 1 + + async def test_cache_disabled_no_caching(self): + """禁用缓存时每次都调用 LLM。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "内容")] + + await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert mock_llm.chat.await_count == 2 + + async def test_cache_key_includes_query(self): + """不同查询不共享缓存。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True) + results = [_make_query_result("c1", "内容")] + + await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT) + await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT) + + assert mock_llm.chat.await_count == 2 + + async def test_cache_key_includes_chunk_ids(self): + """不同检索结果不共享缓存。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True) + + await processor.process( + "query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT + ) + await processor.process( + "query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT + ) + + assert mock_llm.chat.await_count == 2 + + async def test_cache_key_includes_mode(self): + """不同模式不共享缓存。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True) + results = [_make_query_result("c1", "内容")] + + await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + await processor.process("query", results, mode=HIT_PROCESSING_DIRECT) + + # direct 模式不调用 LLM,但也不应命中 model_opt 的缓存 + assert mock_llm.chat.await_count == 1 + + async def test_cache_not_used_for_empty_results(self): + """空结果不写入缓存。""" + mock_llm = _make_mock_llm_gateway() + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True) + + first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT) + second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT) + + assert first.cached is False + assert second.cached is False + + +# --------------------------------------------------------------------------- +# KB 默认模式 + 运行时覆盖测试 +# --------------------------------------------------------------------------- + + +class TestKBDefaultMode: + """KB 默认模式生效测试。""" + + async def test_kb_default_model_opt(self): + """KB 默认 model_opt 模式生效。""" + mock_llm = _make_mock_llm_gateway("回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + settings = KBSettings(kb_id="kb1") # 默认 model_opt + results = [_make_query_result("c1", "内容")] + + result = await processor.process("query", results, mode=settings.default_hit_processing) + assert result.mode == HIT_PROCESSING_MODEL_OPT + mock_llm.chat.assert_awaited_once() + + async def test_kb_default_direct(self): + """KB 默认 direct 模式生效。""" + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT) + results = [_make_query_result("c1", "内容")] + + result = await processor.process("query", results, mode=settings.default_hit_processing) + assert result.mode == HIT_PROCESSING_DIRECT + + +class TestModeOverride: + """Agent 运行时覆盖默认模式测试。""" + + async def test_override_to_direct(self): + """运行时覆盖为 direct 模式 — 不调用 LLM。""" + mock_llm = _make_mock_llm_gateway("LLM 回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "段落")] + + # KB 默认 model_opt,运行时覆盖为 direct + result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT) + assert result.mode == HIT_PROCESSING_DIRECT + mock_llm.chat.assert_not_awaited() + + async def test_override_to_model_opt(self): + """运行时覆盖为 model_opt 模式 — 调用 LLM。""" + mock_llm = _make_mock_llm_gateway("LLM 回答") + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [_make_query_result("c1", "段落")] + + # KB 默认 direct,运行时覆盖为 model_opt + result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT) + assert result.mode == HIT_PROCESSING_MODEL_OPT + mock_llm.chat.assert_awaited_once() diff --git a/tests/unit/rag_platform/test_settings.py b/tests/unit/rag_platform/test_settings.py new file mode 100644 index 0000000..b19aa69 --- /dev/null +++ b/tests/unit/rag_platform/test_settings.py @@ -0,0 +1,329 @@ +"""U6 测试 — KB 设置模型与存储。 + +测试场景: +1. KBSettings 默认值 +2. KBSettingsUpdate 部分更新语义 +3. KBSettingsStore CRUD(get / get_or_create / update) +4. owner 校验(is_owner / set_owner) +5. KB 设置默认模式生效(与 HitProcessor 集成) +""" + +from __future__ import annotations + +from agentkit.rag_platform.hit_processing import ( + HIT_PROCESSING_DIRECT, + HIT_PROCESSING_MODEL_OPT, +) +from agentkit.rag_platform.models import QueryMode +from agentkit.rag_platform.settings import ( + KBSettings, + KBSettingsStore, + KBSettingsUpdate, +) + + +# --------------------------------------------------------------------------- +# KBSettings 模型测试 +# --------------------------------------------------------------------------- + + +class TestKBSettings: + """KBSettings 模型测试。""" + + def test_defaults(self): + """默认值正确。""" + settings = KBSettings(kb_id="kb1") + assert settings.kb_id == "kb1" + assert settings.owner is None + assert settings.default_query_mode == QueryMode.blend + assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT + assert settings.caching_disabled is False + assert settings.rerank_enabled is True + assert settings.rerank_provider == "cohere" + assert settings.rerank_api_key is None + assert settings.rerank_base_url is None + assert settings.data_export_warning is False + + def test_custom_values(self): + """自定义值正确。""" + settings = KBSettings( + kb_id="kb1", + owner="user1", + default_query_mode=QueryMode.keywords, + default_hit_processing=HIT_PROCESSING_DIRECT, + caching_disabled=True, + rerank_enabled=False, + rerank_provider="bge", + rerank_api_key="key-123", + rerank_base_url="http://localhost:9997", + data_export_warning=True, + ) + assert settings.owner == "user1" + assert settings.default_query_mode == QueryMode.keywords + assert settings.default_hit_processing == HIT_PROCESSING_DIRECT + assert settings.caching_disabled is True + assert settings.rerank_enabled is False + assert settings.rerank_provider == "bge" + assert settings.rerank_api_key == "key-123" + assert settings.rerank_base_url == "http://localhost:9997" + assert settings.data_export_warning is True + + +# --------------------------------------------------------------------------- +# KBSettingsUpdate 模型测试 +# --------------------------------------------------------------------------- + + +class TestKBSettingsUpdate: + """KBSettingsUpdate 模型测试。""" + + def test_all_none_by_default(self): + """默认所有字段为 None(部分更新语义)。""" + update = KBSettingsUpdate() + assert update.default_query_mode is None + assert update.default_hit_processing is None + assert update.caching_disabled is None + assert update.rerank_enabled is None + assert update.rerank_provider is None + assert update.rerank_api_key is None + assert update.rerank_base_url is None + + def test_partial_update(self): + """部分字段赋值。""" + update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT) + assert update.default_hit_processing == HIT_PROCESSING_DIRECT + assert update.default_query_mode is None + + def test_exclude_none_dumps(self): + """model_dump(exclude_none=True) 仅包含已设置字段。""" + update = KBSettingsUpdate(caching_disabled=True) + data = update.model_dump(exclude_none=True) + assert data == {"caching_disabled": True} + + def test_full_update(self): + """所有字段同时赋值。""" + update = KBSettingsUpdate( + default_query_mode=QueryMode.embedding, + default_hit_processing=HIT_PROCESSING_DIRECT, + caching_disabled=True, + rerank_enabled=False, + rerank_provider="none", + rerank_api_key="key", + rerank_base_url="http://x", + ) + data = update.model_dump(exclude_none=True) + assert len(data) == 7 + + +# --------------------------------------------------------------------------- +# KBSettingsStore CRUD 测试 +# --------------------------------------------------------------------------- + + +class TestKBSettingsStoreCRUD: + """KBSettingsStore 存储测试。""" + + async def test_get_settings_not_exist(self): + """不存在的 KB 返回 None。""" + store = KBSettingsStore() + assert await store.get_settings("kb1") is None + + async def test_get_or_create_creates_defaults(self): + """get_or_create 创建默认设置。""" + store = KBSettingsStore() + settings = await store.get_or_create("kb1", owner="user1") + assert settings.kb_id == "kb1" + assert settings.owner == "user1" + assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT + + async def test_get_or_create_idempotent(self): + """get_or_create 幂等 — 已存在时不覆盖 owner。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + second = await store.get_or_create("kb1", owner="user2") + assert second.owner == "user1" # 不覆盖已有 owner + + async def test_update_settings_creates_if_not_exist(self): + """update_settings 在设置不存在时创建。""" + store = KBSettingsStore() + update = KBSettingsUpdate(caching_disabled=True) + result = await store.update_settings("kb1", update, owner="user1") + + assert result.caching_disabled is True + assert result.owner == "user1" + assert result.kb_id == "kb1" + + async def test_update_settings_partial(self): + """update_settings 仅更新提供的字段,其他字段保持不变。""" + store = KBSettingsStore() + # 先创建默认设置 + await store.get_or_create("kb1", owner="user1") + # 部分更新 + update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT) + result = await store.update_settings("kb1", update) + + assert result.default_hit_processing == HIT_PROCESSING_DIRECT + # 其他字段保持默认 + assert result.default_query_mode == QueryMode.blend + assert result.caching_disabled is False + assert result.owner == "user1" + + async def test_update_settings_multiple_fields(self): + """update_settings 同时更新多个字段。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + update = KBSettingsUpdate( + caching_disabled=True, + rerank_provider="bge", + rerank_enabled=False, + ) + result = await store.update_settings("kb1", update) + + assert result.caching_disabled is True + assert result.rerank_provider == "bge" + assert result.rerank_enabled is False + + async def test_update_settings_persists(self): + """update_settings 后 get_settings 返回更新后的值。""" + store = KBSettingsStore() + await store.update_settings("kb1", KBSettingsUpdate(caching_disabled=True), owner="user1") + retrieved = await store.get_settings("kb1") + assert retrieved is not None + assert retrieved.caching_disabled is True + + async def test_update_settings_none_field_ignored(self): + """update_settings 中 None 字段被忽略。""" + store = KBSettingsStore() + await store.update_settings( + "kb1", + KBSettingsUpdate(caching_disabled=True, default_hit_processing=HIT_PROCESSING_DIRECT), + owner="user1", + ) + # 第二次更新 — 只改 caching_disabled,default_hit_processing 应保持 + await store.update_settings( + "kb1", + KBSettingsUpdate(caching_disabled=False), + ) + result = await store.get_settings("kb1") + assert result is not None + assert result.caching_disabled is False + assert result.default_hit_processing == HIT_PROCESSING_DIRECT + + +# --------------------------------------------------------------------------- +# owner 校验测试 +# --------------------------------------------------------------------------- + + +class TestOwnerCheck: + """owner 校验测试。""" + + async def test_is_owner_true(self): + """owner 匹配返回 True。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + assert store.is_owner("kb1", "user1") is True + + async def test_is_owner_false_wrong_user(self): + """非 owner 用户返回 False。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + assert store.is_owner("kb1", "user2") is False + + async def test_is_owner_false_none_user(self): + """None 用户返回 False。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + assert store.is_owner("kb1", None) is False + + async def test_is_owner_false_nonexistent_kb(self): + """不存在的 KB 返回 False。""" + store = KBSettingsStore() + assert store.is_owner("nonexistent", "user1") is False + + async def test_is_owner_false_no_owner_set(self): + """owner 未设置时返回 False。""" + store = KBSettingsStore() + await store.get_or_create("kb1") # 不传 owner + assert store.is_owner("kb1", "user1") is False + + async def test_set_owner(self): + """set_owner 设置所有者。""" + store = KBSettingsStore() + await store.get_or_create("kb1", owner="user1") + assert store.set_owner("kb1", "user2") is True + assert store.is_owner("kb1", "user2") is True + assert store.is_owner("kb1", "user1") is False + + async def test_set_owner_nonexistent(self): + """set_owner 对不存在的 KB 返回 False。""" + store = KBSettingsStore() + assert store.set_owner("nonexistent", "user1") is False + + +# --------------------------------------------------------------------------- +# KB 默认模式生效测试(与 HitProcessor 集成) +# --------------------------------------------------------------------------- + + +class TestKBSettingsDefaultModeIntegration: + """KB 设置默认模式与 HitProcessor 集成测试。""" + + async def test_default_model_opt_flows_to_processor(self): + """KB 默认 model_opt 模式传递给 HitProcessor。""" + from unittest.mock import AsyncMock, MagicMock + + from agentkit.rag_platform.hit_processing import HitProcessor + from agentkit.rag_platform.models import QueryResult + + mock_llm = MagicMock() + mock_resp = MagicMock() + mock_resp.content = "LLM 回答" + mock_llm.chat = AsyncMock(return_value=mock_resp) + + store = KBSettingsStore() + settings = await store.get_or_create("kb1", owner="user1") + + processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False) + results = [ + QueryResult( + chunk_id="c1", + content="内容", + score=0.9, + metadata={}, + document_id="d1", + kb_id="kb1", + ) + ] + result = await processor.process("query", results, mode=settings.default_hit_processing) + assert result.mode == HIT_PROCESSING_MODEL_OPT + mock_llm.chat.assert_awaited_once() + + async def test_default_direct_flows_to_processor(self): + """KB 默认 direct 模式传递给 HitProcessor。""" + from agentkit.rag_platform.hit_processing import HitProcessor + from agentkit.rag_platform.models import QueryResult + + store = KBSettingsStore() + await store.update_settings( + "kb1", + KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT), + owner="user1", + ) + settings = await store.get_settings("kb1") + assert settings is not None + + processor = HitProcessor(llm_gateway=None, cache_enabled=False) + results = [ + QueryResult( + chunk_id="c1", + content="直接段落", + score=0.9, + metadata={}, + document_id="d1", + kb_id="kb1", + ) + ] + result = await processor.process("query", results, mode=settings.default_hit_processing) + assert result.mode == HIT_PROCESSING_DIRECT + assert "直接段落" in result.answer