"""U6 测试 — KB 设置模型与存储。

测试场景:
1. KBSettings 默认值
2. KBSettingsUpdate 部分更新语义
3. KBSettingsStore CRUD(get / get_or_create / update)
4. owner 校验(is_owner / set_owner)
5. KB 设置默认模式生效(与 HitProcessor 集成)
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

from agentkit.rag_platform.hit_processing import (
    HIT_PROCESSING_DIRECT,
    HIT_PROCESSING_MODEL_OPT,
    HitProcessor,
)
from agentkit.rag_platform.models import QueryMode, QueryResult
from agentkit.rag_platform.settings import (
    KBSettings,
    KBSettingsStore,
    KBSettingsUpdate,
)


# ---------------------------------------------------------------------------
# KBSettings model tests
# ---------------------------------------------------------------------------


class TestKBSettings:
    """Field defaults and explicit overrides of the ``KBSettings`` model."""

    def test_defaults(self):
        """Only ``kb_id`` is given; every other field takes its declared default."""
        cfg = KBSettings(kb_id="kb1")

        # Value-typed fields, compared by equality.
        assert cfg.kb_id == "kb1"
        assert cfg.default_query_mode == QueryMode.blend
        assert cfg.default_hit_processing == HIT_PROCESSING_MODEL_OPT
        assert cfg.rerank_provider == "cohere"

        # Boolean flags, compared by identity so a non-bool truthy/falsy
        # value would not slip through.
        assert cfg.caching_disabled is False
        assert cfg.rerank_enabled is True
        assert cfg.data_export_warning is False

        # Optional fields that start out unset.
        assert cfg.owner is None
        assert cfg.rerank_api_key is None
        assert cfg.rerank_base_url is None

    def test_custom_values(self):
        """Every explicitly supplied value is stored on the model as given."""
        overrides = dict(
            owner="user1",
            default_query_mode=QueryMode.keywords,
            default_hit_processing=HIT_PROCESSING_DIRECT,
            caching_disabled=True,
            rerank_enabled=False,
            rerank_provider="bge",
            rerank_api_key="key-123",
            rerank_base_url="http://localhost:9997",
            data_export_warning=True,
        )
        cfg = KBSettings(kb_id="kb1", **overrides)

        assert cfg.owner == "user1"
        assert cfg.default_query_mode == QueryMode.keywords
        assert cfg.default_hit_processing == HIT_PROCESSING_DIRECT
        assert cfg.rerank_provider == "bge"
        assert cfg.rerank_api_key == "key-123"
        assert cfg.rerank_base_url == "http://localhost:9997"

        assert cfg.caching_disabled is True
        assert cfg.rerank_enabled is False
        assert cfg.data_export_warning is True


# ---------------------------------------------------------------------------
# KBSettingsUpdate model tests
# ---------------------------------------------------------------------------


class TestKBSettingsUpdate:
    """Tests for ``KBSettingsUpdate``, the partial-update payload model."""

    def test_all_none_by_default(self):
        """With no arguments every field is None (partial-update semantics)."""
        update = KBSettingsUpdate()
        assert update.default_query_mode is None
        assert update.default_hit_processing is None
        assert update.caching_disabled is None
        assert update.rerank_enabled is None
        assert update.rerank_provider is None
        assert update.rerank_api_key is None
        assert update.rerank_base_url is None

    def test_partial_update(self):
        """Only the supplied field is set; the others stay None."""
        update = KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
        assert update.default_hit_processing == HIT_PROCESSING_DIRECT
        assert update.default_query_mode is None

    def test_exclude_none_dumps(self):
        """``model_dump(exclude_none=True)`` contains only the fields that were set."""
        update = KBSettingsUpdate(caching_disabled=True)
        data = update.model_dump(exclude_none=True)
        assert data == {"caching_disabled": True}

    def test_full_update(self):
        """Setting every field yields a dump containing exactly those fields.

        Falsy-but-set values such as ``rerank_enabled=False`` must survive
        ``exclude_none=True``; only ``None`` is dropped.
        """
        update = KBSettingsUpdate(
            default_query_mode=QueryMode.embedding,
            default_hit_processing=HIT_PROCESSING_DIRECT,
            caching_disabled=True,
            rerank_enabled=False,
            rerank_provider="none",
            rerank_api_key="key",
            rerank_base_url="http://x",
        )
        data = update.model_dump(exclude_none=True)

        # Pin *which* keys survive, not just how many: a count-only check
        # would pass if one field were dropped while another leaked in.
        assert set(data) == {
            "default_query_mode",
            "default_hit_processing",
            "caching_disabled",
            "rerank_enabled",
            "rerank_provider",
            "rerank_api_key",
            "rerank_base_url",
        }
        assert data["default_hit_processing"] == HIT_PROCESSING_DIRECT
        assert data["caching_disabled"] is True
        # A False value is "set", so exclude_none must keep it.
        assert data["rerank_enabled"] is False
        assert data["rerank_provider"] == "none"


# ---------------------------------------------------------------------------
# KBSettingsStore CRUD tests
# ---------------------------------------------------------------------------


class TestKBSettingsStoreCRUD:
    """Read / create / update behaviour of ``KBSettingsStore``."""

    async def test_get_settings_not_exist(self):
        """Looking up a KB that was never stored yields None."""
        kb_store = KBSettingsStore()
        missing = await kb_store.get_settings("kb1")
        assert missing is None

    async def test_get_or_create_creates_defaults(self):
        """get_or_create on an unknown KB creates settings with the given owner."""
        kb_store = KBSettingsStore()
        created = await kb_store.get_or_create("kb1", owner="user1")

        assert created.kb_id == "kb1"
        assert created.owner == "user1"
        assert created.default_hit_processing == HIT_PROCESSING_MODEL_OPT

    async def test_get_or_create_idempotent(self):
        """A second get_or_create for the same KB keeps the first owner."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")
        again = await kb_store.get_or_create("kb1", owner="user2")
        # The later caller's owner must not replace the existing one.
        assert again.owner == "user1"

    async def test_update_settings_creates_if_not_exist(self):
        """update_settings on an unknown KB creates it with the given owner."""
        kb_store = KBSettingsStore()
        patch = KBSettingsUpdate(caching_disabled=True)
        created = await kb_store.update_settings("kb1", patch, owner="user1")

        assert created.caching_disabled is True
        assert created.owner == "user1"
        assert created.kb_id == "kb1"

    async def test_update_settings_partial(self):
        """update_settings changes only the supplied field and leaves the rest."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")  # seed defaults

        updated = await kb_store.update_settings(
            "kb1", KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT)
        )

        assert updated.default_hit_processing == HIT_PROCESSING_DIRECT
        # Fields not named in the patch keep their default values.
        assert updated.default_query_mode == QueryMode.blend
        assert updated.caching_disabled is False
        assert updated.owner == "user1"

    async def test_update_settings_multiple_fields(self):
        """Several fields in one patch are all applied."""
        kb_store = KBSettingsStore()
        await kb_store.get_or_create("kb1", owner="user1")

        patch = KBSettingsUpdate(
            caching_disabled=True,
            rerank_provider="bge",
            rerank_enabled=False,
        )
        updated = await kb_store.update_settings("kb1", patch)

        assert updated.caching_disabled is True
        assert updated.rerank_provider == "bge"
        assert updated.rerank_enabled is False

    async def test_update_settings_persists(self):
        """A later get_settings call observes values written by update_settings."""
        kb_store = KBSettingsStore()
        patch = KBSettingsUpdate(caching_disabled=True)
        await kb_store.update_settings("kb1", patch, owner="user1")

        stored = await kb_store.get_settings("kb1")
        assert stored is not None
        assert stored.caching_disabled is True

    async def test_update_settings_none_field_ignored(self):
        """Fields left as None in a patch do not overwrite stored values."""
        kb_store = KBSettingsStore()
        first = KBSettingsUpdate(
            caching_disabled=True, default_hit_processing=HIT_PROCESSING_DIRECT
        )
        await kb_store.update_settings("kb1", first, owner="user1")

        # The second patch only touches caching_disabled; its None-valued
        # default_hit_processing must leave the earlier value in place.
        second = KBSettingsUpdate(caching_disabled=False)
        await kb_store.update_settings("kb1", second)

        stored = await kb_store.get_settings("kb1")
        assert stored is not None
        assert stored.caching_disabled is False
        assert stored.default_hit_processing == HIT_PROCESSING_DIRECT


# ---------------------------------------------------------------------------
# Ownership checks
# ---------------------------------------------------------------------------


class TestOwnerCheck:
    """Tests for ``KBSettingsStore.is_owner`` and ``KBSettingsStore.set_owner``."""

    async def test_is_owner_true(self):
        """The recorded owner is recognised as the owner."""
        registry = KBSettingsStore()
        await registry.get_or_create("kb1", owner="user1")
        verdict = registry.is_owner("kb1", "user1")
        assert verdict is True

    async def test_is_owner_false_wrong_user(self):
        """A different user is not reported as the owner."""
        registry = KBSettingsStore()
        await registry.get_or_create("kb1", owner="user1")
        verdict = registry.is_owner("kb1", "user2")
        assert verdict is False

    async def test_is_owner_false_none_user(self):
        """Passing None as the user id returns False."""
        registry = KBSettingsStore()
        await registry.get_or_create("kb1", owner="user1")
        verdict = registry.is_owner("kb1", None)
        assert verdict is False

    async def test_is_owner_false_nonexistent_kb(self):
        """A KB the store has never seen has no owner."""
        registry = KBSettingsStore()
        verdict = registry.is_owner("nonexistent", "user1")
        assert verdict is False

    async def test_is_owner_false_no_owner_set(self):
        """A KB created without an owner matches no user."""
        registry = KBSettingsStore()
        await registry.get_or_create("kb1")  # owner deliberately omitted
        verdict = registry.is_owner("kb1", "user1")
        assert verdict is False

    async def test_set_owner(self):
        """set_owner moves ownership: the new user gains it, the old one loses it."""
        registry = KBSettingsStore()
        await registry.get_or_create("kb1", owner="user1")

        transferred = registry.set_owner("kb1", "user2")
        assert transferred is True

        assert registry.is_owner("kb1", "user2") is True
        assert registry.is_owner("kb1", "user1") is False

    async def test_set_owner_nonexistent(self):
        """set_owner reports failure (False) for an unknown KB."""
        registry = KBSettingsStore()
        transferred = registry.set_owner("nonexistent", "user1")
        assert transferred is False


# ---------------------------------------------------------------------------
# KB default-mode integration tests (with HitProcessor)
# ---------------------------------------------------------------------------


def _single_hit(content: str) -> list[QueryResult]:
    """Return a one-element retrieval result list for chunk c1 / doc d1 / kb1.

    Shared fixture for the integration tests below; only the chunk text varies.
    """
    return [
        QueryResult(
            chunk_id="c1",
            content=content,
            score=0.9,
            metadata={},
            document_id="d1",
            kb_id="kb1",
        )
    ]


class TestKBSettingsDefaultModeIntegration:
    """The KB's stored default hit-processing mode is what HitProcessor runs."""

    async def test_default_model_opt_flows_to_processor(self):
        """A freshly created KB defaults to model_opt, which calls the LLM once."""
        mock_resp = MagicMock()
        mock_resp.content = "LLM 回答"
        mock_llm = MagicMock()
        mock_llm.chat = AsyncMock(return_value=mock_resp)

        store = KBSettingsStore()
        settings = await store.get_or_create("kb1", owner="user1")
        # Precondition: the stored default is the mode this test targets.
        assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT

        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        result = await processor.process(
            "query", _single_hit("内容"), mode=settings.default_hit_processing
        )

        assert result.mode == HIT_PROCESSING_MODEL_OPT
        mock_llm.chat.assert_awaited_once()

    async def test_default_direct_flows_to_processor(self):
        """A KB switched to direct mode returns the raw passage without an LLM."""
        store = KBSettingsStore()
        await store.update_settings(
            "kb1",
            KBSettingsUpdate(default_hit_processing=HIT_PROCESSING_DIRECT),
            owner="user1",
        )
        settings = await store.get_settings("kb1")
        assert settings is not None
        # Precondition: the stored default is the mode this test targets.
        assert settings.default_hit_processing == HIT_PROCESSING_DIRECT

        # No LLM gateway: direct mode must succeed without one.
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        result = await processor.process(
            "query", _single_hit("直接段落"), mode=settings.default_hit_processing
        )

        assert result.mode == HIT_PROCESSING_DIRECT
        assert "直接段落" in result.answer