# fischer-agentkit/tests/unit/rag_platform/test_settings.py
#
# NOTE: this file contains non-ASCII (CJK) text in docstrings and string
# literals; some characters may look like ASCII punctuation in certain viewers.

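"""U6 tests — KB settings model and storage.

Test scenarios:
1. KBSettings default values
2. KBSettingsUpdate partial-update semantics
3. KBSettingsStore CRUD (get / get_or_create / update)
4. Owner checks (is_owner / set_owner)
5. KB default mode takes effect (integration with HitProcessor)
"""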
from __future__ import annotations
from agentkit.rag_platform.hit_processing import (
HIT_PROCESSING_DIRECT,
HIT_PROCESSING_MODEL_OPT,
)
from agentkit.rag_platform.models import QueryMode
from agentkit.rag_platform.settings import (
KBSettings,
KBSettingsStore,
KBSettingsUpdate,
)
# ---------------------------------------------------------------------------
# KBSettings 模型测试
# ---------------------------------------------------------------------------
class TestKBSettings:
    """Tests for the ``KBSettings`` model."""

    def test_defaults(self):
        """Only ``kb_id`` is supplied; every other field takes its default."""
        cfg = KBSettings(kb_id="kb1")

        assert cfg.kb_id == "kb1"
        assert cfg.owner is None
        # Retrieval defaults.
        assert cfg.default_query_mode == QueryMode.blend
        assert cfg.default_hit_processing == HIT_PROCESSING_MODEL_OPT
        # Boolean flags.
        assert cfg.caching_disabled is False
        assert cfg.data_export_warning is False
        # Rerank configuration.
        assert cfg.rerank_enabled is True
        assert cfg.rerank_provider == "cohere"
        assert cfg.rerank_api_key is None
        assert cfg.rerank_base_url is None

    def test_custom_values(self):
        """Every explicitly supplied value is kept exactly as given."""
        overrides = dict(
            owner="user1",
            default_query_mode=QueryMode.keywords,
            default_hit_processing=HIT_PROCESSING_DIRECT,
            caching_disabled=True,
            rerank_enabled=False,
            rerank_provider="bge",
            rerank_api_key="key-123",
            rerank_base_url="http://localhost:9997",
            data_export_warning=True,
        )
        cfg = KBSettings(kb_id="kb1", **overrides)

        assert cfg.owner == "user1"
        # Retrieval settings.
        assert cfg.default_query_mode == QueryMode.keywords
        assert cfg.default_hit_processing == HIT_PROCESSING_DIRECT
        # Boolean flags.
        assert cfg.caching_disabled is True
        assert cfg.data_export_warning is True
        # Rerank configuration.
        assert cfg.rerank_enabled is False
        assert cfg.rerank_provider == "bge"
        assert cfg.rerank_api_key == "key-123"
        assert cfg.rerank_base_url == "http://localhost:9997"
# ---------------------------------------------------------------------------
# KBSettingsUpdate 模型测试
# ---------------------------------------------------------------------------
class TestKBSettingsUpdate:
    """Tests for the ``KBSettingsUpdate`` partial-update model."""

    def test_all_none_by_default(self):
        """All fields default to None (partial-update semantics)."""
        update = KBSettingsUpdate()
        assert update.default_query_mode is None
        assert update.default_hit_processing is None
        assert update.caching_disabled is None
        assert update.rerank_enabled is None
        assert update.rerank_provider is None
        assert update.rerank_api_key is None
        assert update.rerank_base_url is None

    def test_partial_update(self):
        """Setting one field leaves the others as None."""
        update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
        assert update.default_hit_processing == HIT_PROCESSING_DIRECT
        assert update.default_query_mode is None

    def test_exclude_none_dumps(self):
        """``model_dump(exclude_none=True)`` contains only the fields that were set."""
        update = KBSettingsUpdate(caching_disabled=True)
        data = update.model_dump(exclude_none=True)
        assert data == {"caching_disabled": True}

    def test_full_update(self):
        """With every field set, the exclude_none dump contains every field."""
        update = KBSettingsUpdate(
            default_query_mode=QueryMode.embedding,
            default_hit_processing=HIT_PROCESSING_DIRECT,
            caching_disabled=True,
            rerank_enabled=False,
            rerank_provider="none",
            rerank_api_key="key",
            rerank_base_url="http://x",
        )
        data = update.model_dump(exclude_none=True)
        # Check the exact key set rather than only the count, so a renamed or
        # accidentally dropped/duplicated field cannot slip through.
        assert set(data) == {
            "default_query_mode",
            "default_hit_processing",
            "caching_disabled",
            "rerank_enabled",
            "rerank_provider",
            "rerank_api_key",
            "rerank_base_url",
        }
        # exclude_none drops only None: a False value must still be present.
        assert data["rerank_enabled"] is False
        assert data["caching_disabled"] is True
        assert data["default_hit_processing"] == HIT_PROCESSING_DIRECT
        assert data["rerank_provider"] == "none"
        assert data["rerank_api_key"] == "key"
        assert data["rerank_base_url"] == "http://x"
# ---------------------------------------------------------------------------
# KBSettingsStore CRUD 测试
# ---------------------------------------------------------------------------
class TestKBSettingsStoreCRUD:
    """Tests for ``KBSettingsStore`` get_settings / get_or_create / update_settings."""

    async def test_get_settings_not_exist(self):
        """get_settings returns None for a KB that was never stored."""
        kb_store = KBSettingsStore()
        missing = await kb_store.get_settings("kb1")
        assert missing is None

    async def test_get_or_create_creates_defaults(self):
        """get_or_create stores a default settings record with the given owner."""
        kb_store = KBSettingsStore()
        created = await kb_store.get_or_create("kb1", owner="user1")

        assert created.kb_id == "kb1"
        assert created.owner == "user1"
        assert created.default_hit_processing == HIT_PROCESSING_MODEL_OPT

    async def test_get_or_create_idempotent(self):
        """A second get_or_create on an existing KB keeps the original owner."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")

        again = await kb_store.get_or_create("kb1", owner="user2")
        # The later owner argument must not overwrite the stored one.
        assert again.owner == "user1"

    async def test_update_settings_creates_if_not_exist(self):
        """update_settings creates the record (with its owner) when none exists yet."""
        kb_store = KBSettingsStore()
        change = KBSettingsUpdate(caching_disabled=True)

        updated = await kb_store.update_settings("kb1", change, owner="user1")

        assert updated.kb_id == "kb1"
        assert updated.owner == "user1"
        assert updated.caching_disabled is True

    async def test_update_settings_partial(self):
        """Only the supplied field changes; every other field keeps its value."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")  # seed defaults

        change = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
        updated = await kb_store.update_settings("kb1", change)

        assert updated.default_hit_processing == HIT_PROCESSING_DIRECT
        # Untouched fields keep their defaults / existing values.
        assert updated.default_query_mode == QueryMode.blend
        assert updated.caching_disabled is False
        assert updated.owner == "user1"

    async def test_update_settings_multiple_fields(self):
        """Several fields can be changed in a single update_settings call."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")

        change = KBSettingsUpdate(
            caching_disabled=True,
            rerank_provider="bge",
            rerank_enabled=False,
        )
        updated = await kb_store.update_settings("kb1", change)

        assert updated.caching_disabled is True
        assert updated.rerank_provider == "bge"
        assert updated.rerank_enabled is False

    async def test_update_settings_persists(self):
        """Values written by update_settings are visible through get_settings."""
        kb_store = KBSettingsStore()
        change = KBSettingsUpdate(caching_disabled=True)
        await kb_store.update_settings("kb1", change, owner="user1")

        stored = await kb_store.get_settings("kb1")
        assert stored is not None
        assert stored.caching_disabled is True

    async def test_update_settings_none_field_ignored(self):
        """Fields left as None in an update do not overwrite stored values."""
        kb_store = KBSettingsStore()
        first = KBSettingsUpdate(
            caching_disabled=True,
            default_hit_processing=HIT_PROCESSING_DIRECT,
        )
        await kb_store.update_settings("kb1", first, owner="user1")

        # The second update sets caching_disabled only; default_hit_processing is
        # None here and must keep the DIRECT value written by the first update.
        second = KBSettingsUpdate(caching_disabled=False)
        await kb_store.update_settings("kb1", second)

        stored = await kb_store.get_settings("kb1")
        assert stored is not None
        assert stored.caching_disabled is False
        assert stored.default_hit_processing == HIT_PROCESSING_DIRECT
# ---------------------------------------------------------------------------
# owner 校验测试
# ---------------------------------------------------------------------------
class TestOwnerCheck:
    """Tests for ``KBSettingsStore.is_owner`` / ``set_owner``."""

    @staticmethod
    async def _store_with_kb1(**create_kwargs):
        """Return a fresh store where kb1 was created via get_or_create(**create_kwargs)."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", **create_kwargs)
        return kb_store

    async def test_is_owner_true(self):
        """is_owner is True for the recorded owner."""
        kb_store = await self._store_with_kb1(owner="user1")
        assert kb_store.is_owner("kb1", "user1") is True

    async def test_is_owner_false_wrong_user(self):
        """is_owner is False for a user other than the owner."""
        kb_store = await self._store_with_kb1(owner="user1")
        assert kb_store.is_owner("kb1", "user2") is False

    async def test_is_owner_false_none_user(self):
        """is_owner is False when the queried user is None."""
        kb_store = await self._store_with_kb1(owner="user1")
        assert kb_store.is_owner("kb1", None) is False

    async def test_is_owner_false_nonexistent_kb(self):
        """is_owner is False for a KB the store has never seen."""
        kb_store = KBSettingsStore()
        assert kb_store.is_owner("nonexistent", "user1") is False

    async def test_is_owner_false_no_owner_set(self):
        """is_owner is False when the KB was created without an owner."""
        kb_store = await self._store_with_kb1()  # no owner argument passed
        assert kb_store.is_owner("kb1", "user1") is False

    async def test_set_owner(self):
        """set_owner transfers ownership of an existing KB."""
        kb_store = await self._store_with_kb1(owner="user1")
        assert kb_store.set_owner("kb1", "user2") is True
        # The new owner matches; the previous owner no longer does.
        assert kb_store.is_owner("kb1", "user2") is True
        assert kb_store.is_owner("kb1", "user1") is False

    async def test_set_owner_nonexistent(self):
        """set_owner returns False for a KB the store has never seen."""
        kb_store = KBSettingsStore()
        assert kb_store.set_owner("nonexistent", "user1") is False
# ---------------------------------------------------------------------------
# KB 默认模式生效测试(与 HitProcessor 集成)
# ---------------------------------------------------------------------------
class TestKBSettingsDefaultModeIntegration:
    """Integration tests: a KB's default hit-processing mode is fed to HitProcessor."""

    @staticmethod
    def _single_hit(content):
        """Build a one-element QueryResult list (chunk c1, document d1, kb1)."""
        from agentkit.rag_platform.models import QueryResult

        return [
            QueryResult(
                chunk_id="c1",
                content=content,
                score=0.9,
                metadata={},
                document_id="d1",
                kb_id="kb1",
            )
        ]

    async def test_default_model_opt_flows_to_processor(self):
        """A KB left at defaults passes model_opt to HitProcessor, which calls the LLM."""
        from unittest.mock import AsyncMock, MagicMock

        from agentkit.rag_platform.hit_processing import HitProcessor

        mock_llm = MagicMock()
        mock_resp = MagicMock()
        mock_resp.content = "LLM 回答"
        mock_llm.chat = AsyncMock(return_value=mock_resp)

        store = KBSettingsStore()
        settings = await store.get_or_create("kb1", owner="user1")
        # Pin the precondition so a store regression is reported here rather
        # than as a confusing processor-mode mismatch below.
        assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT

        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        result = await processor.process(
            "query", self._single_hit("内容"), mode=settings.default_hit_processing
        )

        assert result.mode == HIT_PROCESSING_MODEL_OPT
        mock_llm.chat.assert_awaited_once()

    async def test_default_direct_flows_to_processor(self):
        """A KB switched to direct passes that mode to HitProcessor, which returns the chunk text."""
        from agentkit.rag_platform.hit_processing import HitProcessor

        store = KBSettingsStore()
        await store.update_settings(
            "kb1",
            KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT),
            owner="user1",
        )
        settings = await store.get_settings("kb1")
        assert settings is not None
        assert settings.default_hit_processing == HIT_PROCESSING_DIRECT

        # No LLM gateway: this test expects direct mode to work without one.
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        result = await processor.process(
            "query", self._single_hit("直接段落"), mode=settings.default_hit_processing
        )

        assert result.mode == HIT_PROCESSING_DIRECT
        assert "直接段落" in result.answer