"""Tests for KnowledgeBase adapters — Feishu, Confluence, and generic HTTP adapters."""
|
||
|
||
import pytest
|
||
import time
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
|
||
from agentkit.memory.adapters.base import KBAdapter
|
||
from agentkit.memory.adapters.feishu import FeishuKBAdapter
|
||
from agentkit.memory.adapters.confluence import ConfluenceAdapter
|
||
from agentkit.memory.adapters.generic_http import GenericHTTPAdapter
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# KnowledgeBase Protocol tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestKnowledgeBaseProtocol:
    """Checks for the KnowledgeBase protocol and its plain data containers."""

    def test_document_creation(self):
        """Document keeps the given fields; the test expects empty metadata by default."""
        document = Document(doc_id="d1", content="测试内容", title="测试文档")

        assert document.doc_id == "d1"
        assert document.content == "测试内容"
        assert document.title == "测试文档"
        assert document.metadata == {}

    def test_query_result_creation(self):
        """QueryResult keeps the given fields; doc_id and metadata are expected to be empty."""
        fields = {
            "content": "检索结果",
            "source_id": "feishu-xxx",
            "source_name": "飞书知识库",
            "score": 0.92,
        }
        hit = QueryResult(**fields)

        assert hit.content == "检索结果"
        assert hit.score == 0.92
        assert hit.doc_id == ""
        assert hit.metadata == {}

    def test_source_info_creation(self):
        """SourceInfo exposes the values it was constructed with."""
        fields = {
            "source_id": "feishu-xxx",
            "source_name": "飞书知识库",
            "source_type": "feishu",
            "document_count": 100,
        }
        source = SourceInfo(**fields)

        assert source.source_id == "feishu-xxx"
        assert source.source_type == "feishu"
        assert source.document_count == 100

    def test_knowledge_base_protocol_check(self):
        """Every shipped adapter must pass an isinstance check against KnowledgeBase."""
        feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
        assert isinstance(feishu, KnowledgeBase)

        confluence = ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
        )
        assert isinstance(confluence, KnowledgeBase)

        generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        assert isinstance(generic, KnowledgeBase)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# KBAdapter base class tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestKBAdapterBase:
    """Tests for the default behaviour provided by the KBAdapter base class."""

    def _make_concrete_adapter(self) -> KBAdapter:
        """Build a minimal concrete KBAdapter subclass instance for the tests below."""

        class ConcreteAdapter(KBAdapter):
            # Smallest possible implementation of the three overridden hooks.
            def _make_client(self):
                return MagicMock()

            async def search(self, query, top_k=5):
                return []

            async def health_check(self):
                return True

        init_kwargs = {
            "source_id": "test-adapter",
            "source_name": "Test Adapter",
            "source_type": "test",
        }
        return ConcreteAdapter(**init_kwargs)

    @pytest.mark.asyncio
    async def test_query_delegates_to_search(self):
        """query() must forward its arguments to search() and return its results."""
        adapter = self._make_concrete_adapter()
        hit = QueryResult(content="result", source_id="test", source_name="test", score=0.9)
        adapter.search = AsyncMock(return_value=[hit])

        found = await adapter.query("test query", top_k=3)

        adapter.search.assert_called_once_with("test query", top_k=3)
        assert len(found) == 1

    @pytest.mark.asyncio
    async def test_list_sources_default(self):
        """The default list_sources() reports exactly one source describing the adapter."""
        adapter = self._make_concrete_adapter()

        listed = await adapter.list_sources()

        assert len(listed) == 1
        assert listed[0].source_id == "test-adapter"
        assert listed[0].source_name == "Test Adapter"

    @pytest.mark.asyncio
    async def test_ingest_default_returns_empty(self):
        """The default ingest() returns an empty id list."""
        adapter = self._make_concrete_adapter()
        batch = [Document(doc_id="d1", content="test")]

        stored_ids = await adapter.ingest(batch)

        assert stored_ids == []

    @pytest.mark.asyncio
    async def test_delete_by_id_default_returns_false(self):
        """The default delete_by_id() returns False."""
        adapter = self._make_concrete_adapter()

        assert await adapter.delete_by_id("d1") is False

    @pytest.mark.asyncio
    async def test_get_document_default_returns_none(self):
        """The default get_document() returns None."""
        adapter = self._make_concrete_adapter()

        assert await adapter.get_document("d1") is None

    @pytest.mark.asyncio
    async def test_authenticate_delegates_to_health_check(self):
        """authenticate() must call health_check() once and return True when it does."""
        adapter = self._make_concrete_adapter()
        adapter.health_check = AsyncMock(return_value=True)

        outcome = await adapter.authenticate()

        assert outcome is True
        adapter.health_check.assert_called_once()

    @pytest.mark.asyncio
    async def test_authenticate_failure(self):
        """authenticate() must return False when health_check() raises."""
        adapter = self._make_concrete_adapter()
        adapter.health_check = AsyncMock(side_effect=Exception("connection error"))

        assert await adapter.authenticate() is False

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """`async with` must yield the adapter itself and call close() once on exit."""
        adapter = self._make_concrete_adapter()
        adapter.close = AsyncMock()

        async with adapter as entered:
            assert entered is adapter

        adapter.close.assert_called_once()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# FeishuKBAdapter tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestFeishuKBAdapterInit:
    """Constructor behaviour of FeishuKBAdapter."""

    def test_basic_init(self):
        """With only credentials given, the test expects the public Feishu URL and no spaces."""
        adapter = FeishuKBAdapter(app_id="cli_test1234", app_secret="secret")

        assert adapter._app_id == "cli_test1234"
        assert adapter._app_secret == "secret"
        assert adapter._base_url == "https://open.feishu.cn/open-apis"
        assert adapter._space_ids == []
        assert adapter._source_type == "feishu"

    def test_init_with_custom_base_url(self):
        """A custom base URL is expected without its trailing slash; other options are kept."""
        options = {
            "app_id": "cli_test1234",
            "app_secret": "secret",
            "base_url": "https://internal.feishu.cn/open-apis/",
            "space_ids": ["space-1", "space-2"],
            "timeout": 60,
        }
        adapter = FeishuKBAdapter(**options)

        assert adapter._base_url == "https://internal.feishu.cn/open-apis"
        assert adapter._space_ids == ["space-1", "space-2"]
        assert adapter._timeout == 60
|
||
|
||
|
||
class TestFeishuKBAdapterAuth:
    """Authentication behaviour of FeishuKBAdapter."""

    @pytest.mark.asyncio
    async def test_authenticate_success(self):
        """A code-0 token reply must make authenticate() return True and store the token."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")

        token_reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        token_reply.json.return_value = {
            "code": 0,
            "msg": "ok",
            "tenant_access_token": "t-xxx",
            "expire": 7200,
        }
        http_client = AsyncMock(post=AsyncMock(return_value=token_reply))
        adapter._make_client = MagicMock(return_value=http_client)

        outcome = await adapter.authenticate()

        assert outcome is True
        assert adapter._access_token == "t-xxx"

    @pytest.mark.asyncio
    async def test_authenticate_failure(self):
        """A non-zero code in the token reply must make authenticate() return False."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="wrong_secret")

        rejected_reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        rejected_reply.json.return_value = {
            "code": 10014,
            "msg": "invalid app_id or app_secret",
        }
        http_client = AsyncMock(post=AsyncMock(return_value=rejected_reply))
        adapter._make_client = MagicMock(return_value=http_client)

        outcome = await adapter.authenticate()

        assert outcome is False
|
||
|
||
|
||
class TestFeishuKBAdapterSearch:
    """Search behaviour of FeishuKBAdapter."""

    @pytest.fixture
    def adapter(self):
        """Adapter restricted to a single wiki space."""
        return FeishuKBAdapter(
            app_id="cli_test",
            app_secret="secret",
            space_ids=["space-1"],
        )

    @staticmethod
    def _prime_token(adapter):
        """Give *adapter* a token and an expiry two hours ahead.

        Presumably this lets search() skip the token request — confirm against
        the adapter implementation.
        """
        adapter._access_token = "t-xxx"
        adapter._token_expiry = time.time() + 7200

    @staticmethod
    def _install_post_client(adapter, response):
        """Make the adapter's client factory return a mock whose post() yields *response*."""
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=response)
        adapter._make_client = MagicMock(return_value=http_client)
        return http_client

    @pytest.mark.asyncio
    async def test_search_success(self, adapter):
        """Two items in the API reply must be mapped to two QueryResult objects."""
        self._prime_token(adapter)

        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "code": 0,
            "data": {
                "items": [
                    {
                        "wiki_token": "wikcnxxx",
                        "title": "飞书知识库文档",
                        "content": "这是飞书知识库的内容",
                        "score": 0.92,
                        "space_id": "space-1",
                    },
                    {
                        "wiki_token": "wikcnyyy",
                        "title": "另一个文档",
                        "content": "另一个文档的内容",
                        "score": 0.85,
                        "space_id": "space-1",
                    },
                ],
            },
        }
        self._install_post_client(adapter, reply)

        results = await adapter.search("飞书知识库", top_k=5)

        assert len(results) == 2
        top = results[0]
        assert top.content == "这是飞书知识库的内容"
        assert top.score == 0.92
        assert top.source_id.startswith("feishu-")
        assert top.source_name == "飞书知识库"
        assert top.doc_id == "wikcnxxx"
        assert top.title == "飞书知识库文档"

    @pytest.mark.asyncio
    async def test_search_not_authenticated(self, adapter):
        """Without a token, search() must return an empty list."""
        adapter._access_token = None
        adapter._get_access_token = AsyncMock(return_value=None)

        results = await adapter.search("test")

        assert results == []

    @pytest.mark.asyncio
    async def test_search_api_error(self, adapter):
        """A non-zero business code in an HTTP 200 reply must yield an empty list."""
        self._prime_token(adapter)

        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "code": 9999,
            "msg": "internal error",
        }
        self._install_post_client(adapter, reply)

        results = await adapter.search("test")

        assert results == []

    @pytest.mark.asyncio
    async def test_search_http_error(self, adapter):
        """An HTTPStatusError from raise_for_status() must yield an empty list."""
        import httpx

        self._prime_token(adapter)

        reply = MagicMock()
        reply.status_code = 500
        reply.text = "Internal Server Error"
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=reply
        )
        self._install_post_client(adapter, reply)

        results = await adapter.search("test")

        assert results == []
|
||
|
||
|
||
class TestFeishuKBAdapterHealthCheck:
    """Health-check behaviour of FeishuKBAdapter."""

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        """A code-0 token reply must make health_check() return True."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")

        token_reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        token_reply.json.return_value = {"code": 0, "tenant_access_token": "t-xxx"}
        http_client = AsyncMock(post=AsyncMock(return_value=token_reply))
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self):
        """An exception raised by the HTTP client must make health_check() return False."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")

        failing_post = AsyncMock(side_effect=Exception("Connection refused"))
        http_client = AsyncMock(post=failing_post)
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is False
|
||
|
||
|
||
class TestFeishuKBAdapterListSources:
    """Source listing behaviour of FeishuKBAdapter."""

    @pytest.mark.asyncio
    async def test_list_sources_success(self):
        """Each space item in the API reply must become one source, in order."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
        adapter._access_token = "t-xxx"
        adapter._token_expiry = time.time() + 7200

        spaces = [
            {"space_id": "space-1", "name": "产品文档"},
            {"space_id": "space-2", "name": "技术文档"},
        ]
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"data": {"items": spaces}}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        listed = await adapter.list_sources()

        assert len(listed) == 2
        assert listed[0].source_name == "产品文档"
        assert listed[1].source_name == "技术文档"

    @pytest.mark.asyncio
    async def test_list_sources_not_authenticated(self):
        """Without a token, a single source with a feishu- prefixed id is expected."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
        adapter._access_token = None
        adapter._get_access_token = AsyncMock(return_value=None)

        listed = await adapter.list_sources()

        assert len(listed) == 1
        assert listed[0].source_id.startswith("feishu-")
|
||
|
||
|
||
class TestFeishuKBAdapterGetDocument:
    """Single-document retrieval behaviour of FeishuKBAdapter."""

    @staticmethod
    def _authenticated_adapter():
        """Return an adapter that already holds a token expiring two hours from now."""
        adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
        adapter._access_token = "t-xxx"
        adapter._token_expiry = time.time() + 7200
        return adapter

    @pytest.mark.asyncio
    async def test_get_document_success(self):
        """A code-0 node reply must be turned into a Document carrying the requested id."""
        adapter = self._authenticated_adapter()

        node = {
            "title": "测试文档",
            "content": "文档内容",
            "space_id": "space-1",
            "obj_type": "doc",
        }
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"code": 0, "data": {"node": node}}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        fetched = await adapter.get_document("wikcnxxx")

        assert fetched is not None
        assert fetched.doc_id == "wikcnxxx"
        assert fetched.title == "测试文档"
        assert fetched.content == "文档内容"

    @pytest.mark.asyncio
    async def test_get_document_not_found(self):
        """A non-zero code in the reply must make get_document() return None."""
        adapter = self._authenticated_adapter()

        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"code": 10004, "msg": "node not found"}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        fetched = await adapter.get_document("nonexistent")

        assert fetched is None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# ConfluenceAdapter tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestConfluenceAdapterInit:
    """Constructor behaviour of ConfluenceAdapter."""

    def test_basic_init(self):
        """Required arguments are stored; the test expects no space keys by default."""
        required = {
            "base_url": "https://test.atlassian.net/wiki",
            "username": "user@test.com",
            "api_token": "token",
        }
        adapter = ConfluenceAdapter(**required)

        assert adapter._base_url == "https://test.atlassian.net/wiki"
        assert adapter._username == "user@test.com"
        assert adapter._api_token == "token"
        assert adapter._space_keys == []
        assert adapter._source_type == "confluence"

    def test_init_with_spaces(self):
        """The base URL is expected without its trailing slash; space keys are kept."""
        options = {
            "base_url": "https://test.atlassian.net/wiki/",
            "username": "user@test.com",
            "api_token": "token",
            "space_keys": ["DEV", "DOC"],
            "timeout": 60,
        }
        adapter = ConfluenceAdapter(**options)

        assert adapter._base_url == "https://test.atlassian.net/wiki"
        assert adapter._space_keys == ["DEV", "DOC"]
|
||
|
||
|
||
class TestConfluenceAdapterAuth:
    """Authentication behaviour of ConfluenceAdapter."""

    @pytest.mark.asyncio
    async def test_authenticate_success(self):
        """An HTTP 200 reply to the GET request must make authenticate() return True."""
        adapter = ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
        )
        reply = MagicMock(status_code=200)
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        outcome = await adapter.authenticate()

        assert outcome is True

    @pytest.mark.asyncio
    async def test_authenticate_failure(self):
        """An HTTP 401 reply to the GET request must make authenticate() return False."""
        adapter = ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="wrong_token",
        )
        reply = MagicMock(status_code=401)
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        outcome = await adapter.authenticate()

        assert outcome is False
|
||
|
||
|
||
class TestConfluenceAdapterSearch:
    """Search behaviour of ConfluenceAdapter."""

    @pytest.fixture
    def adapter(self):
        """Adapter restricted to the DEV space."""
        return ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
            space_keys=["DEV"],
        )

    @pytest.mark.asyncio
    async def test_search_success(self, adapter):
        """One page in the API reply must become one QueryResult with the page's fields."""
        page = {
            "id": "12345",
            "title": "Confluence 页面",
            "type": "page",
            "status": "current",
            "space": {"key": "DEV"},
            "body": {
                "storage": {
                    "value": "<p>这是 Confluence 页面内容</p>"
                }
            },
        }
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"results": [page]}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("Confluence", top_k=5)

        assert len(results) == 1
        hit = results[0]
        assert "Confluence 页面内容" in hit.content
        assert hit.source_id.startswith("confluence-")
        assert hit.source_name == "Confluence"
        assert hit.doc_id == "12345"
        assert hit.title == "Confluence 页面"

    @pytest.mark.asyncio
    async def test_search_with_space_filter(self, adapter):
        """The CQL sent to the API must contain the configured space restriction."""
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"results": []}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        await adapter.search("test")

        # Params may have been passed by keyword or as the second positional argument.
        recorded = http_client.get.call_args
        positional_fallback = recorded.args[1] if len(recorded.args) > 1 else {}
        query_params = recorded.kwargs.get("params", positional_fallback)
        cql_text = query_params.get("cql", "")
        assert 'space = "DEV"' in cql_text

    @pytest.mark.asyncio
    async def test_search_http_error(self, adapter):
        """An HTTPStatusError from raise_for_status() must yield an empty list."""
        import httpx

        reply = MagicMock(status_code=500, text="Internal Server Error")
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=reply
        )
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("test")

        assert results == []
|
||
|
||
|
||
class TestConfluenceAdapterHealthCheck:
    """Health-check behaviour of ConfluenceAdapter."""

    @staticmethod
    def _new_adapter():
        """Return a ConfluenceAdapter built with the test credentials."""
        return ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
        )

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        """An HTTP 200 reply must make health_check() return True."""
        adapter = self._new_adapter()
        reply = MagicMock(status_code=200)
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self):
        """An exception raised by the HTTP client must make health_check() return False."""
        adapter = self._new_adapter()
        failing_get = AsyncMock(side_effect=Exception("Connection refused"))
        http_client = AsyncMock(get=failing_get)
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is False
|
||
|
||
|
||
class TestConfluenceAdapterGetDocument:
    """Single-document retrieval behaviour of ConfluenceAdapter."""

    @pytest.mark.asyncio
    async def test_get_document_success(self):
        """A page reply must become a Document; its version number lands in metadata."""
        adapter = ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
        )

        page = {
            "id": "12345",
            "title": "测试页面",
            "type": "page",
            "space": {"key": "DEV"},
            "body": {
                "storage": {
                    "value": "<p>页面内容</p>"
                }
            },
            "version": {"number": 3},
        }
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = page
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        fetched = await adapter.get_document("12345")

        assert fetched is not None
        assert fetched.doc_id == "12345"
        assert fetched.title == "测试页面"
        assert "页面内容" in fetched.content
        assert fetched.metadata["version"] == 3
|
||
|
||
|
||
class TestConfluenceAdapterListSources:
    """Source listing behaviour of ConfluenceAdapter."""

    @pytest.mark.asyncio
    async def test_list_sources_success(self):
        """Each space in the API reply must become one source, in order."""
        adapter = ConfluenceAdapter(
            base_url="https://test.atlassian.net/wiki",
            username="user@test.com",
            api_token="token",
        )

        spaces = [
            {"key": "DEV", "name": "Development"},
            {"key": "DOC", "name": "Documentation"},
        ]
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"results": spaces}
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        listed = await adapter.list_sources()

        assert len(listed) == 2
        assert listed[0].source_name == "Development"
        assert listed[1].source_name == "Documentation"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GenericHTTPAdapter tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestGenericHTTPAdapterInit:
    """Constructor and client-header behaviour of GenericHTTPAdapter."""

    ENDPOINT = "https://example.com/api/kb"

    def test_basic_init(self):
        """With only an endpoint given, the test expects empty auth config and headers."""
        adapter = GenericHTTPAdapter(endpoint_url=self.ENDPOINT)

        assert adapter._endpoint_url == "https://example.com/api/kb"
        assert adapter._auth_config == {}
        assert adapter._extra_headers == {}
        assert adapter._source_type == "generic_http"

    def test_init_with_auth_bearer(self):
        """The endpoint is expected without its trailing slash; other options are kept."""
        options = {
            "endpoint_url": "https://example.com/api/kb/",
            "auth_config": {"type": "bearer", "token": "sk-test"},
            "headers": {"X-Custom": "value"},
            "source_id": "my-kb",
            "source_name": "My KB",
            "timeout": 60,
        }
        adapter = GenericHTTPAdapter(**options)

        assert adapter._endpoint_url == "https://example.com/api/kb"
        assert adapter._auth_config["type"] == "bearer"
        assert adapter._extra_headers == {"X-Custom": "value"}
        assert adapter._source_id == "my-kb"
        assert adapter._source_name == "My KB"

    def test_client_bearer_auth_header(self):
        """Bearer auth config must yield an Authorization header containing the token."""
        adapter = GenericHTTPAdapter(
            endpoint_url=self.ENDPOINT,
            auth_config={"type": "bearer", "token": "sk-test"},
        )

        http_client = adapter._make_client()

        authorization = str(http_client.headers.get("Authorization", ""))
        assert "Bearer sk-test" in authorization

    def test_client_basic_auth_header(self):
        """Basic auth config must yield an Authorization header starting with 'Basic '."""
        adapter = GenericHTTPAdapter(
            endpoint_url=self.ENDPOINT,
            auth_config={"type": "basic", "username": "user", "password": "pass"},
        )

        http_client = adapter._make_client()

        authorization = str(http_client.headers.get("Authorization", ""))
        assert authorization.startswith("Basic ")

    def test_client_api_key_header(self):
        """API-key auth config must put the key under the configured header name."""
        adapter = GenericHTTPAdapter(
            endpoint_url=self.ENDPOINT,
            auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
        )

        http_client = adapter._make_client()

        assert http_client.headers.get("X-API-Key") == "key123"
|
||
|
||
|
||
class TestGenericHTTPAdapterSearch:
    """Search behaviour of GenericHTTPAdapter."""

    @pytest.fixture
    def adapter(self):
        """Adapter configured with bearer-token authentication."""
        return GenericHTTPAdapter(
            endpoint_url="https://example.com/api/kb",
            auth_config={"type": "bearer", "token": "sk-test"},
        )

    @pytest.mark.asyncio
    async def test_search_standard_response(self, adapter):
        """A {"results": [...]} reply is mapped to results; the request payload is checked."""
        first = {
            "content": "HTTP 知识库内容",
            "score": 0.92,
            "doc_id": "d1",
            "title": "文档1",
            "metadata": {"page": 1},
        }
        second = {
            "content": "另一条结果",
            "score": 0.85,
            "doc_id": "d2",
            "title": "文档2",
        }
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"results": [first, second]}
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("知识查询", top_k=5)

        assert len(results) == 2
        top = results[0]
        assert top.content == "HTTP 知识库内容"
        assert top.score == 0.92
        assert top.doc_id == "d1"
        assert top.title == "文档1"

        # Verify the path and JSON payload sent to the endpoint.
        recorded = http_client.post.call_args
        assert recorded.args[0] == "/search"
        sent = recorded.kwargs["json"]
        assert sent["query"] == "知识查询"
        assert sent["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, adapter):
        """A bare JSON list reply must also be accepted as the result list."""
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = [
            {"content": "直接列表结果", "score": 0.8, "id": "d1"},
        ]
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("test")

        assert len(results) == 1
        assert results[0].content == "直接列表结果"

    @pytest.mark.asyncio
    async def test_search_http_error(self, adapter):
        """An HTTPStatusError from raise_for_status() must yield an empty list."""
        import httpx

        reply = MagicMock(status_code=500, text="Internal Server Error")
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=reply
        )
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("test")

        assert results == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format(self, adapter):
        """A JSON object without a results list must yield an empty list."""
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"error": "unexpected"}
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        results = await adapter.search("test")

        assert results == []
|
||
|
||
|
||
class TestGenericHTTPAdapterIngest:
    """Document ingestion behaviour of GenericHTTPAdapter."""

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """ingest() must POST all documents to /ingest and return the ids from the reply."""
        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = {"ids": ["d1", "d2"]}
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        batch = [
            Document(doc_id="d1", content="内容1", title="文档1"),
            Document(doc_id="d2", content="内容2", title="文档2"),
        ]
        stored_ids = await adapter.ingest(batch)

        assert stored_ids == ["d1", "d2"]
        recorded = http_client.post.call_args
        assert recorded.args[0] == "/ingest"
        sent = recorded.kwargs["json"]
        assert len(sent["documents"]) == 2

    @pytest.mark.asyncio
    async def test_ingest_http_error(self):
        """An HTTPStatusError from raise_for_status() must yield an empty id list."""
        import httpx

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        reply = MagicMock(status_code=500, text="Error")
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=reply
        )
        http_client = AsyncMock(post=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        batch = [Document(doc_id="d1", content="test")]
        stored_ids = await adapter.ingest(batch)

        assert stored_ids == []
|
||
|
||
|
||
class TestGenericHTTPAdapterDeleteById:
    """Delete-by-id behaviour of GenericHTTPAdapter."""

    @pytest.mark.asyncio
    async def test_delete_success(self):
        """An HTTP 200 reply must return True after one DELETE to /documents/<id>."""
        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        reply = MagicMock(status_code=200)
        http_client = AsyncMock(delete=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        deleted = await adapter.delete_by_id("d1")

        assert deleted is True
        http_client.delete.assert_called_once_with("/documents/d1")

    @pytest.mark.asyncio
    async def test_delete_not_found(self):
        """An HTTP 404 reply must make delete_by_id() return False."""
        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        reply = MagicMock(status_code=404)
        http_client = AsyncMock(delete=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        deleted = await adapter.delete_by_id("nonexistent")

        assert deleted is False
|
||
|
||
|
||
class TestGenericHTTPAdapterGetDocument:
    """Single-document retrieval behaviour of GenericHTTPAdapter."""

    @pytest.mark.asyncio
    async def test_get_document_success(self):
        """A JSON document reply must be turned into a Document with matching fields."""
        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        stored = {
            "id": "d1",
            "content": "文档内容",
            "title": "测试文档",
            "metadata": {"page": 1},
        }
        reply = MagicMock(status_code=200, raise_for_status=MagicMock())
        reply.json.return_value = stored
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        fetched = await adapter.get_document("d1")

        assert fetched is not None
        assert fetched.doc_id == "d1"
        assert fetched.content == "文档内容"
        assert fetched.title == "测试文档"

    @pytest.mark.asyncio
    async def test_get_document_not_found(self):
        """An HTTPStatusError (404) from raise_for_status() must make the call return None."""
        import httpx

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

        reply = MagicMock(status_code=404)
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            "404", request=MagicMock(), response=reply
        )
        http_client = AsyncMock(get=AsyncMock(return_value=reply))
        adapter._make_client = MagicMock(return_value=http_client)

        fetched = await adapter.get_document("nonexistent")

        assert fetched is None
|
||
|
||
|
||
class TestGenericHTTPAdapterHealthCheck:
    """GenericHTTPAdapter: health check."""

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        """health_check is True when the mocked GET answers with status 200."""
        healthy_response = MagicMock(status_code=200)

        http_client = AsyncMock()
        http_client.get = AsyncMock(return_value=healthy_response)

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        # Swap the client factory so the adapter talks to the mock instead of the network.
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_fallback_to_root(self):
        """Falls back to the root path when the health endpoint does not exist."""
        import httpx

        # First GET (/health) answers 404 ...
        health_404 = MagicMock(status_code=404, text="Not Found")
        health_404.raise_for_status.side_effect = httpx.HTTPStatusError(
            "404", request=MagicMock(), response=health_404
        )
        # ... second GET (/) answers 200.
        root_200 = MagicMock(status_code=200)

        http_client = AsyncMock()
        # side_effect as a list: successive calls return successive items.
        http_client.get = AsyncMock(side_effect=[health_404, root_200])

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        """health_check is False when every GET raises an exception."""
        http_client = AsyncMock()
        http_client.get = AsyncMock(side_effect=Exception("Connection refused"))

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        adapter._make_client = MagicMock(return_value=http_client)

        healthy = await adapter.health_check()

        assert healthy is False
|
||
|
||
|
||
class TestGenericHTTPAdapterListSources:
    """GenericHTTPAdapter: listing information sources."""

    @pytest.mark.asyncio
    async def test_list_sources_success(self):
        """A JSON list of two source records is returned as two source entries."""
        source_records = [
            {"source_id": "src-1", "source_name": "知识库1", "source_type": "custom"},
            {"source_id": "src-2", "source_name": "知识库2", "document_count": 50},
        ]
        ok_response = MagicMock(status_code=200)
        ok_response.raise_for_status = MagicMock()
        ok_response.json.return_value = source_records

        http_client = AsyncMock()
        http_client.get = AsyncMock(return_value=ok_response)

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        # Swap the client factory so the adapter talks to the mock instead of the network.
        adapter._make_client = MagicMock(return_value=http_client)

        listed = await adapter.list_sources()

        assert len(listed) == 2
        assert listed[0].source_name == "知识库1"
        assert listed[1].document_count == 50

    @pytest.mark.asyncio
    async def test_list_sources_endpoint_not_found(self):
        """Returns a single default source when the sources endpoint is unavailable."""
        http_client = AsyncMock()
        # The missing endpoint is simulated by making every GET raise.
        http_client.get = AsyncMock(side_effect=Exception("404"))

        adapter = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
        adapter._make_client = MagicMock(return_value=http_client)

        listed = await adapter.list_sources()

        assert len(listed) == 1
        assert listed[0].source_type == "generic_http"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Cross-adapter integration tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestCrossAdapterIntegration:
|
||
"""跨适配器集成测试 — 验证统一 KnowledgeBase 接口"""
|
||
|
||
@pytest.mark.asyncio
async def test_all_adapters_implement_knowledge_base_protocol(self):
    """Every adapter implements the KnowledgeBase protocol."""
    feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
    confluence = ConfluenceAdapter(
        base_url="https://test.atlassian.net/wiki",
        username="user@test.com",
        api_token="token",
    )
    generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

    # Attribute names each adapter is expected to expose, checked in this order.
    required_methods = ("ingest", "query", "delete_by_id", "list_sources", "health_check")

    for candidate in (feishu, confluence, generic):
        assert isinstance(candidate, KnowledgeBase)
        for method_name in required_methods:
            assert hasattr(candidate, method_name)
|
||
|
||
@pytest.mark.asyncio
async def test_all_adapters_have_search(self):
    """Every adapter has a search method (described as an alias of query).

    The presence of get_document and authenticate is checked as well.
    """
    feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
    confluence = ConfluenceAdapter(
        base_url="https://test.atlassian.net/wiki",
        username="user@test.com",
        api_token="token",
    )
    generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")

    expected_attributes = ("search", "get_document", "authenticate")

    for candidate in (feishu, confluence, generic):
        for attribute_name in expected_attributes:
            assert hasattr(candidate, attribute_name)
|
||
|
||
@pytest.mark.asyncio
async def test_unified_search_returns_query_result(self):
    """The unified search interface returns QueryResult items for every adapter.

    NOTE(review): each ``all(...)`` assertion is vacuously true for an empty
    result list; consider also asserting that results are non-empty.
    """
    # --- Feishu: the mocked client answers POST requests ---
    feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
    # Pre-seed a token that expires in two hours; presumably this lets search
    # skip a real authentication round-trip (confirm against the adapter).
    feishu._access_token = "t-xxx"
    feishu._token_expiry = time.time() + 7200

    feishu_payload = {
        "code": 0,
        "data": {
            "items": [
                {"wiki_token": "w1", "title": "飞书文档", "content": "内容", "score": 0.9}
            ]
        },
    }
    feishu_response = MagicMock(status_code=200)
    feishu_response.raise_for_status = MagicMock()
    feishu_response.json.return_value = feishu_payload

    feishu_client = AsyncMock()
    feishu_client.post = AsyncMock(return_value=feishu_response)
    feishu._make_client = MagicMock(return_value=feishu_client)

    feishu_hits = await feishu.search("test")
    assert all(isinstance(hit, QueryResult) for hit in feishu_hits)

    # --- Confluence: the mocked client answers GET requests ---
    confluence = ConfluenceAdapter(
        base_url="https://test.atlassian.net/wiki",
        username="user@test.com",
        api_token="token",
    )
    confluence_payload = {
        "results": [
            {
                "id": "123",
                "title": "Confluence 页面",
                "type": "page",
                "body": {"storage": {"value": "<p>内容</p>"}},
                "space": {"key": "DEV"},
            }
        ]
    }
    confluence_response = MagicMock(status_code=200)
    confluence_response.raise_for_status = MagicMock()
    confluence_response.json.return_value = confluence_payload

    confluence_client = AsyncMock()
    confluence_client.get = AsyncMock(return_value=confluence_response)
    confluence._make_client = MagicMock(return_value=confluence_client)

    confluence_hits = await confluence.search("test")
    assert all(isinstance(hit, QueryResult) for hit in confluence_hits)

    # --- Generic HTTP: the mocked client answers POST requests ---
    generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
    generic_payload = {
        "results": [
            {"content": "HTTP 内容", "score": 0.8, "doc_id": "d1", "title": "文档"}
        ]
    }
    generic_response = MagicMock(status_code=200)
    generic_response.raise_for_status = MagicMock()
    generic_response.json.return_value = generic_payload

    generic_client = AsyncMock()
    generic_client.post = AsyncMock(return_value=generic_response)
    generic._make_client = MagicMock(return_value=generic_client)

    generic_hits = await generic.search("test")
    assert all(isinstance(hit, QueryResult) for hit in generic_hits)
|