611 lines
21 KiB
Python
611 lines
21 KiB
Python
"""Tests for Knowledge Base Management API routes"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import Request
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.server.app import create_app
|
|
from agentkit.server.routes.kb_management import KnowledgeSourceStore, KnowledgeSource
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def mock_llm_gateway():
    """Return a new ``LLMGateway`` constructed with no arguments.

    Note: despite the fixture name, the returned object is a real
    ``LLMGateway`` instance, not a mock.
    """
    return LLMGateway()
|
|
|
|
|
|
@pytest.fixture
def skill_registry():
    """Return a new, default-constructed ``SkillRegistry``."""
    registry = SkillRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def tool_registry():
    """Return a new, default-constructed ``ToolRegistry``."""
    registry = ToolRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the application under test with a dev-admin user on every request.

    The app is created via ``create_app`` from the gateway and registry
    fixtures, then an HTTP middleware is registered that populates
    ``request.state.current_user`` before the request is handled.
    """
    fastapi_app = create_app(
        llm_gateway=mock_llm_gateway,
        skill_registry=skill_registry,
        tool_registry=tool_registry,
    )

    async def _set_dev_admin_user(request: Request, call_next):
        # Inject a dev-mode admin user so the KB_WRITE / KB_QUERY permission
        # checks pass. In production, AuthMiddleware sets this from JWT.
        request.state.current_user = {
            "user_id": "dev-admin",
            "username": "dev-admin",
            "role": "admin",
            "dev_mode": True,
        }
        response = await call_next(request)
        return response

    # Same registration as the ``@app.middleware("http")`` decorator form.
    fastapi_app.middleware("http")(_set_dev_admin_user)
    return fastapi_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a ``TestClient`` bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KnowledgeSourceStore unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestKnowledgeSourceStore:
    """Unit tests for ``KnowledgeSourceStore``: add/get/remove/list sources
    and the bookkeeping performed when a document is added."""

    def test_add_source(self):
        """A newly added source has an id, echoes name/type, and is "active"."""
        created = KnowledgeSourceStore().add_source("测试知识库", "local", {})
        assert created.id is not None
        assert created.name == "测试知识库"
        assert created.type == "local"
        assert created.status == "active"

    def test_get_source(self):
        """A stored source can be fetched back by its id."""
        kb_store = KnowledgeSourceStore()
        created = kb_store.add_source("飞书知识库", "feishu", {"app_id": "test"})
        fetched = kb_store.get_source(created.id)
        assert fetched is not None
        assert fetched.name == "飞书知识库"

    def test_get_source_not_found(self):
        """Looking up an unknown id yields ``None``."""
        missing = KnowledgeSourceStore().get_source("nonexistent")
        assert missing is None

    def test_remove_source(self):
        """Removing an existing source returns ``True`` and makes it unfetchable."""
        kb_store = KnowledgeSourceStore()
        doomed = kb_store.add_source("待删除", "local", {})
        was_removed = kb_store.remove_source(doomed.id)
        assert was_removed is True
        assert kb_store.get_source(doomed.id) is None

    def test_remove_source_not_found(self):
        """Removing an unknown id returns ``False``."""
        was_removed = KnowledgeSourceStore().remove_source("nonexistent")
        assert was_removed is False

    def test_list_sources(self):
        """Two added sources produce a listing of length two."""
        kb_store = KnowledgeSourceStore()
        for source_name, source_type in (("源1", "local"), ("源2", "feishu")):
            kb_store.add_source(source_name, source_type, {})
        listed = kb_store.list_sources()
        assert len(listed) == 2

    def test_list_sources_empty(self):
        """A fresh store lists no sources."""
        listed = KnowledgeSourceStore().list_sources()
        assert listed == []

    def test_add_document_updates_source(self):
        """Adding a document sets the owning source's count to 1 and its
        ``last_synced`` to a non-``None`` value."""
        from agentkit.server.routes.kb_management import UploadedDocument

        kb_store = KnowledgeSourceStore()
        owner = kb_store.add_source("本地文档", "local", {})
        kb_store.add_document(
            UploadedDocument(
                document_id="doc-1",
                filename="test.pdf",
                source_id=owner.id,
                chunks=5,
                status="indexed",
            )
        )
        refreshed = kb_store.get_source(owner.id)
        assert refreshed.document_count == 1
        assert refreshed.last_synced is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /kb-management/sources
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListSources:
    """Tests for ``GET /api/v1/kb-management/sources``."""

    _SOURCES_URL = "/api/v1/kb-management/sources"

    def test_list_sources_empty(self, client):
        """The listing responds 200 with a ``sources`` list.

        Note: only the response shape is checked; the list is not asserted
        to be empty.
        """
        resp = client.get(self._SOURCES_URL)
        assert resp.status_code == 200
        body = resp.json()
        assert "sources" in body
        assert isinstance(body["sources"], list)

    def test_list_sources_after_add(self, client):
        """After posting one source, the listing contains at least one entry."""
        # Add a source first
        new_source = {"name": "测试源", "type": "local", "config": {}}
        client.post(self._SOURCES_URL, json=new_source)

        resp = client.get(self._SOURCES_URL)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body["sources"]) >= 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/sources
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAddSource:
    """Tests for ``POST /api/v1/kb-management/sources``."""

    @staticmethod
    def _post_source(client, payload):
        """POST ``payload`` as JSON to the sources endpoint; return the response."""
        return client.post("/api/v1/kb-management/sources", json=payload)

    def test_add_local_source(self, client):
        """A "local" source is created (201) and echoed back as active."""
        resp = self._post_source(
            client, {"name": "本地文档", "type": "local", "config": {}}
        )
        assert resp.status_code == 201
        body = resp.json()
        assert body["name"] == "本地文档"
        assert body["type"] == "local"
        assert body["status"] == "active"

    def test_add_feishu_source(self, client):
        """A "feishu" source with app credentials is created (201)."""
        payload = {
            "name": "飞书知识库",
            "type": "feishu",
            "config": {"app_id": "test", "app_secret": "secret"},
        }
        resp = self._post_source(client, payload)
        assert resp.status_code == 201
        body = resp.json()
        assert body["type"] == "feishu"

    def test_add_confluence_source(self, client):
        """A "confluence" source with a base URL is created (201)."""
        payload = {
            "name": "Confluence",
            "type": "confluence",
            "config": {"base_url": "https://wiki.example.com"},
        }
        resp = self._post_source(client, payload)
        assert resp.status_code == 201

    def test_add_http_source(self, client):
        """An "http" source with a URL is created (201)."""
        payload = {
            "name": "HTTP API",
            "type": "http",
            "config": {"url": "https://api.example.com/kb"},
        }
        resp = self._post_source(client, payload)
        assert resp.status_code == 201

    def test_add_source_invalid_type(self, client):
        """An unrecognised source type is rejected with 422."""
        resp = self._post_source(
            client, {"name": "无效类型", "type": "invalid", "config": {}}
        )
        assert resp.status_code == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DELETE /kb-management/sources/{source_id}
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRemoveSource:
    """Tests for ``DELETE /api/v1/kb-management/sources/{source_id}``."""

    _SOURCES_URL = "/api/v1/kb-management/sources"

    def test_remove_source(self, client):
        """Deleting a just-created source responds 200 with status "removed"."""
        created = client.post(
            self._SOURCES_URL,
            json={"name": "待删除", "type": "local", "config": {}},
        )
        source_id = created.json()["id"]

        resp = client.delete(f"{self._SOURCES_URL}/{source_id}")
        assert resp.status_code == 200
        assert resp.json()["status"] == "removed"

    def test_remove_source_not_found(self, client):
        """Deleting an unknown source id responds 404."""
        resp = client.delete(f"{self._SOURCES_URL}/nonexistent")
        assert resp.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/documents/upload
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUploadDocument:
    """Tests for ``POST /api/v1/kb-management/documents/upload``."""

    _UPLOAD_URL = "/api/v1/kb-management/documents/upload"

    def test_upload_document(self, client):
        """Uploading a text file responds 200 with id, filename, status and chunks."""
        raw = b"This is a test document content for upload."
        multipart = {"file": ("test.txt", io.BytesIO(raw), "text/plain")}

        resp = client.post(self._UPLOAD_URL, files=multipart)
        assert resp.status_code == 200
        body = resp.json()
        assert body["filename"] == "test.txt"
        assert body["document_id"] is not None
        assert body["status"] == "segmenting"
        assert body["chunks"] >= 1

    def test_upload_document_with_source_id(self, client):
        """Uploading with a ``source_id`` form field responds 200 and echoes the filename."""
        # Create a source first
        created = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "上传源", "type": "local", "config": {}},
        )
        source_id = created.json()["id"]

        raw = b"Test content with source ID."
        multipart = {"file": ("doc.txt", io.BytesIO(raw), "text/plain")}
        resp = client.post(
            self._UPLOAD_URL,
            files=multipart,
            data={"source_id": source_id},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["filename"] == "doc.txt"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSearchKnowledge:
    """Tests for ``POST /api/v1/kb-management/search`` with basic parameters."""

    _SEARCH_URL = "/api/v1/kb-management/search"

    def test_search_returns_results(self, client):
        """A query with ``top_k`` responds 200 with a ``results`` list."""
        resp = client.post(self._SEARCH_URL, json={"query": "测试查询", "top_k": 5})
        assert resp.status_code == 200
        body = resp.json()
        assert "results" in body
        assert isinstance(body["results"], list)

    def test_search_with_sources_filter(self, client):
        """A query carrying a ``sources`` filter responds 200."""
        request_body = {"query": "测试", "sources": ["local"], "top_k": 3}
        resp = client.post(self._SEARCH_URL, json=request_body)
        assert resp.status_code == 200

    def test_search_default_top_k(self, client):
        """A query without ``top_k`` responds 200."""
        resp = client.post(self._SEARCH_URL, json={"query": "默认参数"})
        assert resp.status_code == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /kb-management/sources/{source_id}/health
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSourceHealth:
    """Tests for ``GET /api/v1/kb-management/sources/{source_id}/health``."""

    @staticmethod
    def _create_source(client, name, source_type):
        """Create a source with an empty config and return its id."""
        created = client.post(
            "/api/v1/kb-management/sources",
            json={"name": name, "type": source_type, "config": {}},
        )
        return created.json()["id"]

    def test_health_local_source(self, client):
        """A "local" source reports "healthy" and echoes its id."""
        source_id = self._create_source(client, "本地", "local")

        resp = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "healthy"
        assert body["source_id"] == source_id

    def test_health_external_source(self, client):
        """A "feishu" source reports status "unknown"."""
        source_id = self._create_source(client, "外部", "feishu")

        resp = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "unknown"

    def test_health_source_not_found(self, client):
        """Health for an unknown source id responds 404."""
        resp = client.get("/api/v1/kb-management/sources/nonexistent/health")
        assert resp.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /kb-management/documents
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListDocuments:
    """Tests for ``GET /api/v1/kb-management/documents``."""

    def test_list_documents_empty(self, client):
        """The listing responds 200 with a ``documents`` list.

        Note: only the response shape is checked; the list is not asserted
        to be empty.
        """
        response = client.get("/api/v1/kb-management/documents")
        assert response.status_code == 200
        data = response.json()
        assert "documents" in data
        assert isinstance(data["documents"], list)

    def test_list_documents_after_upload(self, client):
        """After one upload, the listing is non-empty and its first entry
        exposes the expected document fields."""
        file_content = b"Test document for listing."
        client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("list_test.txt", io.BytesIO(file_content), "text/plain")},
        )

        response = client.get("/api/v1/kb-management/documents")
        assert response.status_code == 200
        data = response.json()
        assert len(data["documents"]) >= 1

        doc = data["documents"][0]
        for field in (
            "document_id",
            "filename",
            "source_id",
            "chunks",
            "status",
            "created_at",
        ):
            assert field in doc

    def test_list_documents_filter_by_source(self, client):
        """Filtering by ``source_id`` returns only that source's documents.

        The filtered endpoint is always exercised. Previously the whole check
        sat behind ``if matching_docs:`` and the test passed without calling
        the endpoint whenever the upload did not keep the source id.
        """
        # Create a source
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "过滤源", "type": "local", "config": {}},
        )
        source_id = add_resp.json()["id"]

        # Upload a document to that source
        file_content = b"Filtered document."
        upload_resp = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("filtered.txt", io.BytesIO(file_content), "text/plain")},
            data={"source_id": source_id},
        )
        assert upload_resp.status_code == 200

        # Check the uploaded document's source_id
        all_docs_resp = client.get("/api/v1/kb-management/documents")
        all_docs = all_docs_resp.json()["documents"]
        matching_docs = [d for d in all_docs if d["source_id"] == source_id]

        # Even if source_id was not properly set on upload (falls back to
        # "local"), the filter endpoint should still work for this source_id,
        # so it is queried unconditionally.
        response = client.get(f"/api/v1/kb-management/documents?source_id={source_id}")
        assert response.status_code == 200
        data = response.json()

        # Only require a hit when the unfiltered listing shows the document
        # was actually attached to this source.
        if matching_docs:
            assert len(data["documents"]) >= 1

        # Whatever comes back must belong to the requested source.
        for doc in data["documents"]:
            assert doc["source_id"] == source_id
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DELETE /kb-management/documents/{document_id}
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDeleteDocument:
    """Tests for ``DELETE /api/v1/kb-management/documents/{document_id}``."""

    _UPLOAD_URL = "/api/v1/kb-management/documents/upload"
    _DOCUMENTS_URL = "/api/v1/kb-management/documents"

    def test_delete_document(self, client):
        """Deleting a just-uploaded document responds 200 with status "deleted"."""
        # Upload a document first
        raw = b"Document to delete."
        uploaded = client.post(
            self._UPLOAD_URL,
            files={"file": ("delete_me.txt", io.BytesIO(raw), "text/plain")},
        )
        document_id = uploaded.json()["document_id"]

        resp = client.delete(f"{self._DOCUMENTS_URL}/{document_id}")
        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"

    def test_delete_document_not_found(self, client):
        """Deleting an unknown document id responds 404."""
        resp = client.delete(f"{self._DOCUMENTS_URL}/nonexistent")
        assert resp.status_code == 404

    def test_delete_document_updates_source_count(self, client):
        """A document uploaded to a source is gone from the listing after deletion.

        Note: despite the test name, the source's document count is not
        inspected here; only the document listing is checked.
        """
        # Create a source and upload a document
        created = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "计数源", "type": "local", "config": {}},
        )
        source_id = created.json()["id"]

        raw = b"Count document."
        uploaded = client.post(
            self._UPLOAD_URL,
            files={"file": ("count.txt", io.BytesIO(raw), "text/plain")},
            data={"source_id": source_id},
        )
        document_id = uploaded.json()["document_id"]

        # Delete the document
        deleted = client.delete(f"{self._DOCUMENTS_URL}/{document_id}")
        assert deleted.status_code == 200
        assert deleted.json()["status"] == "deleted"

        # Verify document is gone from the list
        listing = client.get(self._DOCUMENTS_URL)
        remaining_ids = [entry["document_id"] for entry in listing.json()["documents"]]
        assert document_id not in remaining_ids
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/sources/{source_id}/sync
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSyncSource:
    """Tests for ``POST /api/v1/kb-management/sources/{source_id}/sync``."""

    _SOURCES_URL = "/api/v1/kb-management/sources"

    def test_sync_source(self, client):
        """Syncing an existing source responds 200 with status "syncing" and its id."""
        created = client.post(
            self._SOURCES_URL,
            json={"name": "同步源", "type": "feishu", "config": {}},
        )
        source_id = created.json()["id"]

        resp = client.post(f"{self._SOURCES_URL}/{source_id}/sync")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "syncing"
        assert body["source_id"] == source_id

    def test_sync_source_not_found(self, client):
        """Syncing an unknown source id responds 404."""
        resp = client.post(f"{self._SOURCES_URL}/nonexistent/sync")
        assert resp.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PUT /kb-management/sources/{source_id}
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUpdateSource:
    """Tests for ``PUT /api/v1/kb-management/sources/{source_id}``."""

    _SOURCES_URL = "/api/v1/kb-management/sources"

    def test_update_source_name(self, client):
        """Renaming a source responds 200 with the new name and the same id."""
        created = client.post(
            self._SOURCES_URL,
            json={"name": "原名", "type": "local", "config": {}},
        )
        source_id = created.json()["id"]

        resp = client.put(
            f"{self._SOURCES_URL}/{source_id}",
            json={"name": "新名称"},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "新名称"
        assert body["id"] == source_id

    def test_update_source_config(self, client):
        """Replacing a source's config responds 200 and keeps the same id.

        Note: the returned config itself is not asserted.
        """
        created = client.post(
            self._SOURCES_URL,
            json={"name": "配置源", "type": "feishu", "config": {"app_id": "old"}},
        )
        source_id = created.json()["id"]

        resp = client.put(
            f"{self._SOURCES_URL}/{source_id}",
            json={"config": {"app_id": "new", "app_secret": "secret"}},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == source_id

    def test_update_source_not_found(self, client):
        """Updating an unknown source id responds 404."""
        resp = client.put(
            f"{self._SOURCES_URL}/nonexistent",
            json={"name": "不存在"},
        )
        assert resp.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /kb-management/search with advanced params
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSearchAdvanced:
    """Tests for ``POST /api/v1/kb-management/search`` with ``source_ids``
    and ``strategy`` parameters."""

    @staticmethod
    def _search(client, request_body):
        """POST ``request_body`` as JSON to the search endpoint; return the response."""
        return client.post("/api/v1/kb-management/search", json=request_body)

    def test_search_with_source_ids(self, client):
        """A query restricted by ``source_ids`` responds 200 with a ``results`` key."""
        resp = self._search(
            client,
            {"query": "高级查询", "source_ids": ["source-1", "source-2"], "top_k": 10},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert "results" in body

    def test_search_with_strategy_vector(self, client):
        """A query with ``strategy="vector"`` responds 200."""
        resp = self._search(client, {"query": "向量检索", "strategy": "vector"})
        assert resp.status_code == 200

    def test_search_with_strategy_hybrid(self, client):
        """A query with ``strategy="hybrid"`` responds 200."""
        resp = self._search(client, {"query": "混合检索", "strategy": "hybrid"})
        assert resp.status_code == 200

    def test_search_default_strategy(self, client):
        """A query without an explicit strategy responds 200."""
        resp = self._search(client, {"query": "默认策略"})
        assert resp.status_code == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KnowledgeSourceStore: delete_document and update_source unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestKnowledgeSourceStoreAdvanced:
    """Unit tests for ``KnowledgeSourceStore.delete_document`` and
    ``KnowledgeSourceStore.update_source``."""

    def test_delete_document(self):
        """Deleting a stored document returns ``True``, drops the source's
        count back to 0, and empties the document listing."""
        from agentkit.server.routes.kb_management import UploadedDocument

        kb_store = KnowledgeSourceStore()
        owner = kb_store.add_source("本地文档", "local", {})
        kb_store.add_document(
            UploadedDocument(
                document_id="doc-del-1",
                filename="delete.pdf",
                source_id=owner.id,
                chunks=3,
                status="indexed",
            )
        )
        # The count is read from the object returned by add_source.
        assert owner.document_count == 1

        was_deleted = kb_store.delete_document("doc-del-1")
        assert was_deleted is True
        assert owner.document_count == 0
        assert len(kb_store.list_documents()) == 0

    def test_delete_document_not_found(self):
        """Deleting an unknown document id returns ``False``."""
        was_deleted = KnowledgeSourceStore().delete_document("nonexistent")
        assert was_deleted is False

    def test_update_source_name(self):
        """Updating ``name`` returns the source carrying the new name."""
        kb_store = KnowledgeSourceStore()
        original = kb_store.add_source("原名", "local", {})
        changed = kb_store.update_source(original.id, {"name": "新名"})
        assert changed is not None
        assert changed.name == "新名"

    def test_update_source_config(self):
        """Updating ``config`` returns the source carrying the new config."""
        kb_store = KnowledgeSourceStore()
        original = kb_store.add_source("配置源", "feishu", {"app_id": "old"})
        changed = kb_store.update_source(original.id, {"config": {"app_id": "new"}})
        assert changed is not None
        assert changed.config == {"app_id": "new"}

    def test_update_source_not_found(self):
        """Updating an unknown source id returns ``None``."""
        outcome = KnowledgeSourceStore().update_source("nonexistent", {"name": "不存在"})
        assert outcome is None
|