327 lines
11 KiB
Python
327 lines
11 KiB
Python
"""Tests for Knowledge Base Management API routes"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.server.app import create_app
|
|
from agentkit.server.routes.kb_management import KnowledgeSourceStore, KnowledgeSource
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def mock_llm_gateway():
    """Provide the LLM gateway handed to ``create_app``.

    Despite the fixture name, this is a default-constructed ``LLMGateway``,
    not a mock object.
    """
    return LLMGateway()
|
|
|
|
|
|
@pytest.fixture
def skill_registry():
    """Provide a default-constructed ``SkillRegistry`` (function-scoped fixture)."""
    registry = SkillRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def tool_registry():
    """Provide a default-constructed ``ToolRegistry`` (function-scoped fixture)."""
    registry = ToolRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the application under test from the gateway and registry fixtures."""
    application = create_app(
        llm_gateway=mock_llm_gateway,
        skill_registry=skill_registry,
        tool_registry=tool_registry,
    )
    return application
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a ``TestClient`` for HTTP-level tests.

    The client is not entered as a context manager. Per the Starlette
    ``TestClient`` docs, the app's lifespan/startup handlers therefore do not run.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KnowledgeSourceStore unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestKnowledgeSourceStore:
    """Direct unit tests for ``KnowledgeSourceStore``, bypassing the HTTP layer."""

    def test_add_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("测试知识库", "local", {})
        assert source.id is not None
        assert source.name == "测试知识库"
        assert source.type == "local"
        # Newly added sources start out active.
        assert source.status == "active"

    def test_get_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("飞书知识库", "feishu", {"app_id": "test"})
        retrieved = store.get_source(source.id)
        assert retrieved is not None
        # The lookup must return the source that was added, not just any source.
        assert retrieved.id == source.id
        assert retrieved.name == "飞书知识库"

    def test_get_source_not_found(self):
        store = KnowledgeSourceStore()
        assert store.get_source("nonexistent") is None

    def test_remove_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("待删除", "local", {})
        assert store.remove_source(source.id) is True
        assert store.get_source(source.id) is None
        # Once removed, the id behaves like any unknown id.
        assert store.remove_source(source.id) is False

    def test_remove_source_not_found(self):
        store = KnowledgeSourceStore()
        assert store.remove_source("nonexistent") is False

    def test_list_sources(self):
        store = KnowledgeSourceStore()
        store.add_source("源1", "local", {})
        store.add_source("源2", "feishu", {})
        sources = store.list_sources()
        assert len(sources) == 2
        # Check which sources came back, not only how many.
        assert {s.name for s in sources} == {"源1", "源2"}

    def test_list_sources_empty(self):
        store = KnowledgeSourceStore()
        assert store.list_sources() == []

    def test_add_document_updates_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("本地文档", "local", {})

        from agentkit.server.routes.kb_management import UploadedDocument

        doc = UploadedDocument(
            document_id="doc-1",
            filename="test.pdf",
            source_id=source.id,
            chunks=5,
            status="indexed",
        )
        store.add_document(doc)

        updated_source = store.get_source(source.id)
        # Guard first so a lost source fails as an assertion, not an AttributeError.
        assert updated_source is not None
        assert updated_source.document_count == 1
        assert updated_source.last_synced is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /kb-management/sources
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListSources:
    """HTTP tests for ``GET /api/v1/kb-management/sources``."""

    def test_list_sources_empty(self, client):
        response = client.get("/api/v1/kb-management/sources")
        assert response.status_code == 200
        data = response.json()
        assert "sources" in data
        # NOTE(review): only the response shape is checked, not emptiness.
        # Presumably this is because the backing store may be shared across
        # tests; confirm, and tighten to ``== []`` if each app gets its own store.
        assert isinstance(data["sources"], list)

    def test_list_sources_after_add(self, client):
        # Add a source first, failing here (not later) if creation is rejected.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "测试源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.get("/api/v1/kb-management/sources")
        assert response.status_code == 200
        data = response.json()
        # Look for this test's source specifically. A bare count check would
        # also pass on sources left behind by other tests.
        listed_ids = {item["id"] for item in data["sources"]}
        assert source_id in listed_ids
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/sources
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAddSource:
    """HTTP tests for ``POST /api/v1/kb-management/sources``, one per source type."""

    def test_add_local_source(self, client):
        response = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "本地文档", "type": "local", "config": {}},
        )
        assert response.status_code == 201
        data = response.json()
        assert data["id"] is not None
        assert data["name"] == "本地文档"
        assert data["type"] == "local"
        assert data["status"] == "active"

    def test_add_feishu_source(self, client):
        response = client.post(
            "/api/v1/kb-management/sources",
            json={
                "name": "飞书知识库",
                "type": "feishu",
                "config": {"app_id": "test", "app_secret": "secret"},
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "飞书知识库"
        assert data["type"] == "feishu"

    def test_add_confluence_source(self, client):
        response = client.post(
            "/api/v1/kb-management/sources",
            json={
                "name": "Confluence",
                "type": "confluence",
                "config": {"base_url": "https://wiki.example.com"},
            },
        )
        assert response.status_code == 201
        # Check the body too, as the local/feishu tests do.
        data = response.json()
        assert data["name"] == "Confluence"
        assert data["type"] == "confluence"

    def test_add_http_source(self, client):
        response = client.post(
            "/api/v1/kb-management/sources",
            json={
                "name": "HTTP API",
                "type": "http",
                "config": {"url": "https://api.example.com/kb"},
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "HTTP API"
        assert data["type"] == "http"

    def test_add_source_invalid_type(self, client):
        # An unknown source type must be rejected with 422 (Unprocessable Entity).
        response = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "无效类型", "type": "invalid", "config": {}},
        )
        assert response.status_code == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DELETE /kb-management/sources/{source_id}
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRemoveSource:
    """HTTP tests for ``DELETE /api/v1/kb-management/sources/{source_id}``."""

    def test_remove_source(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "待删除", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.delete(f"/api/v1/kb-management/sources/{source_id}")
        assert response.status_code == 200
        assert response.json()["status"] == "removed"

        # The source must really be gone: deleting it again finds nothing.
        repeat = client.delete(f"/api/v1/kb-management/sources/{source_id}")
        assert repeat.status_code == 404

    def test_remove_source_not_found(self, client):
        response = client.delete("/api/v1/kb-management/sources/nonexistent")
        assert response.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/documents/upload
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUploadDocument:
    """HTTP tests for multipart ``POST /api/v1/kb-management/documents/upload``."""

    def test_upload_document(self, client):
        file_content = b"This is a test document content for upload."
        response = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("test.txt", io.BytesIO(file_content), "text/plain")},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["filename"] == "test.txt"
        assert data["document_id"] is not None
        assert data["status"] == "indexed"
        assert data["chunks"] >= 1

    def test_upload_document_with_source_id(self, client):
        # Create a source first so the upload can reference it.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "上传源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        file_content = b"Test content with source ID."
        response = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("doc.txt", io.BytesIO(file_content), "text/plain")},
            # source_id travels as a form field alongside the file part.
            data={"source_id": source_id},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["filename"] == "doc.txt"
        assert data["document_id"] is not None
        assert data["status"] == "indexed"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSearchKnowledge:
    """HTTP tests for ``POST /api/v1/kb-management/search``."""

    def test_search_returns_results(self, client):
        response = client.post(
            "/api/v1/kb-management/search",
            json={"query": "测试查询", "top_k": 5},
        )
        assert response.status_code == 200
        data = response.json()
        assert "results" in data
        assert isinstance(data["results"], list)

    def test_search_with_sources_filter(self, client):
        # NOTE(review): "local" here is a source *type*, not a source id.
        # Confirm which of the two the ``sources`` filter expects.
        response = client.post(
            "/api/v1/kb-management/search",
            json={"query": "测试", "sources": ["local"], "top_k": 3},
        )
        assert response.status_code == 200
        # The filtered request should return the same response shape.
        assert isinstance(response.json()["results"], list)

    def test_search_default_top_k(self, client):
        # Omitting top_k must fall back to the server-side default, not fail.
        response = client.post(
            "/api/v1/kb-management/search",
            json={"query": "默认参数"},
        )
        assert response.status_code == 200
        assert isinstance(response.json()["results"], list)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /kb-management/sources/{source_id}/health
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSourceHealth:
    """HTTP tests for ``GET /api/v1/kb-management/sources/{source_id}/health``."""

    def test_health_local_source(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "本地", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert response.status_code == 200
        data = response.json()
        # Local sources report "healthy".
        assert data["status"] == "healthy"
        assert data["source_id"] == source_id

    def test_health_external_source(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "外部", "type": "feishu", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert response.status_code == 200
        data = response.json()
        # External (feishu) sources report "unknown" rather than "healthy".
        assert data["status"] == "unknown"
        assert data["source_id"] == source_id

    def test_health_source_not_found(self, client):
        response = client.get("/api/v1/kb-management/sources/nonexistent/health")
        assert response.status_code == 404
|