fischer-agentkit/tests/unit/server/test_kb_management.py

597 lines
21 KiB
Python

"""Tests for Knowledge Base Management API routes"""
from __future__ import annotations
import io
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from agentkit.llm.gateway import LLMGateway
from agentkit.server.app import create_app
from agentkit.server.routes.kb_management import KnowledgeSourceStore, KnowledgeSource
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_llm_gateway():
    """Provide the LLM gateway passed into ``create_app``.

    Note: despite the fixture name, this is a real ``LLMGateway()`` built with
    no arguments, not a mock object.
    """
    return LLMGateway()
@pytest.fixture
def skill_registry():
    """Provide a freshly constructed ``SkillRegistry`` for each test."""
    registry = SkillRegistry()
    return registry
@pytest.fixture
def tool_registry():
    """Provide a freshly constructed ``ToolRegistry`` for each test."""
    registry = ToolRegistry()
    return registry
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the application under test, wired to the per-test gateway and registries."""
    dependencies = {
        "llm_gateway": mock_llm_gateway,
        "skill_registry": skill_registry,
        "tool_registry": tool_registry,
    }
    return create_app(**dependencies)
@pytest.fixture
def client(app):
    """Wrap the application in a synchronous ``TestClient``.

    The client is not entered as a context manager, so application
    startup/shutdown (lifespan) handlers are not triggered by this fixture.
    """
    test_client = TestClient(app)
    return test_client
# ---------------------------------------------------------------------------
# KnowledgeSourceStore unit tests
# ---------------------------------------------------------------------------
class TestKnowledgeSourceStore:
    """Direct unit tests for ``KnowledgeSourceStore``, bypassing the HTTP layer."""

    def test_add_source(self):
        # A newly added source gets an id and starts in the "active" state.
        store = KnowledgeSourceStore()
        source = store.add_source("测试知识库", "local", {})
        assert source.id is not None
        assert source.name == "测试知识库"
        assert source.type == "local"
        assert source.status == "active"

    def test_get_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("飞书知识库", "feishu", {"app_id": "test"})
        retrieved = store.get_source(source.id)
        assert retrieved is not None
        assert retrieved.id == source.id
        assert retrieved.name == "飞书知识库"

    def test_get_source_not_found(self):
        store = KnowledgeSourceStore()
        assert store.get_source("nonexistent") is None

    def test_remove_source(self):
        store = KnowledgeSourceStore()
        source = store.add_source("待删除", "local", {})
        assert store.remove_source(source.id) is True
        assert store.get_source(source.id) is None

    def test_remove_source_not_found(self):
        store = KnowledgeSourceStore()
        assert store.remove_source("nonexistent") is False

    def test_list_sources(self):
        store = KnowledgeSourceStore()
        store.add_source("源1", "local", {})
        store.add_source("源2", "feishu", {})
        sources = store.list_sources()
        assert len(sources) == 2
        # Check *which* sources are listed, not only how many.
        assert {s.name for s in sources} == {"源1", "源2"}

    def test_list_sources_empty(self):
        store = KnowledgeSourceStore()
        assert store.list_sources() == []

    def test_add_document_updates_source(self):
        # Registering a document against a source must be reflected in that
        # source's document_count and last_synced fields.
        store = KnowledgeSourceStore()
        source = store.add_source("本地文档", "local", {})
        from agentkit.server.routes.kb_management import UploadedDocument

        doc = UploadedDocument(
            document_id="doc-1",
            filename="test.pdf",
            source_id=source.id,
            chunks=5,
            status="indexed",
        )
        store.add_document(doc)
        updated_source = store.get_source(source.id)
        # Guard first so a missing source fails with a clear assertion
        # instead of an AttributeError on None.
        assert updated_source is not None
        assert updated_source.document_count == 1
        assert updated_source.last_synced is not None
# ---------------------------------------------------------------------------
# GET /kb-management/sources
# ---------------------------------------------------------------------------
class TestListSources:
    """``GET /api/v1/kb-management/sources``."""

    def test_list_sources_empty(self, client):
        # Only the response shape is checked here.
        # NOTE(review): emptiness is not asserted, presumably because sources
        # may persist across app instances; confirm before tightening.
        response = client.get("/api/v1/kb-management/sources")
        assert response.status_code == 200
        data = response.json()
        assert "sources" in data
        assert isinstance(data["sources"], list)

    def test_list_sources_after_add(self, client):
        # Add a source first and make sure that actually succeeded.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "测试源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.get("/api/v1/kb-management/sources")
        assert response.status_code == 200
        data = response.json()
        # Look for *this* source rather than a bare ">= 1", which could be
        # satisfied by sources left over from other tests.
        listed_ids = [s.get("id") for s in data["sources"]]
        assert source_id in listed_ids
# ---------------------------------------------------------------------------
# POST /kb-management/sources
# ---------------------------------------------------------------------------
class TestAddSource:
    """``POST /api/v1/kb-management/sources`` for each source type."""

    _ENDPOINT = "/api/v1/kb-management/sources"

    def _create(self, client, payload):
        """POST *payload* to the sources endpoint and return the raw response."""
        return client.post(self._ENDPOINT, json=payload)

    def test_add_local_source(self, client):
        resp = self._create(client, {"name": "本地文档", "type": "local", "config": {}})
        assert resp.status_code == 201
        body = resp.json()
        assert body["name"] == "本地文档"
        assert body["type"] == "local"
        assert body["status"] == "active"

    def test_add_feishu_source(self, client):
        payload = {
            "name": "飞书知识库",
            "type": "feishu",
            "config": {"app_id": "test", "app_secret": "secret"},
        }
        resp = self._create(client, payload)
        assert resp.status_code == 201
        assert resp.json()["type"] == "feishu"

    def test_add_confluence_source(self, client):
        payload = {
            "name": "Confluence",
            "type": "confluence",
            "config": {"base_url": "https://wiki.example.com"},
        }
        assert self._create(client, payload).status_code == 201

    def test_add_http_source(self, client):
        payload = {
            "name": "HTTP API",
            "type": "http",
            "config": {"url": "https://api.example.com/kb"},
        }
        assert self._create(client, payload).status_code == 201

    def test_add_source_invalid_type(self, client):
        # An unsupported "type" is rejected by request validation.
        payload = {"name": "无效类型", "type": "invalid", "config": {}}
        assert self._create(client, payload).status_code == 422
# ---------------------------------------------------------------------------
# DELETE /kb-management/sources/{source_id}
# ---------------------------------------------------------------------------
class TestRemoveSource:
    """``DELETE /api/v1/kb-management/sources/{source_id}``."""

    def test_remove_source(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "待删除", "type": "local", "config": {}},
        )
        # Fail fast if creation broke, instead of a KeyError on "id".
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]
        url = f"/api/v1/kb-management/sources/{source_id}"

        response = client.delete(url)
        assert response.status_code == 200
        assert response.json()["status"] == "removed"
        # The source must really be gone: deleting it again is a 404,
        # the same as for an id that never existed.
        assert client.delete(url).status_code == 404

    def test_remove_source_not_found(self, client):
        response = client.delete("/api/v1/kb-management/sources/nonexistent")
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# POST /kb-management/documents/upload
# ---------------------------------------------------------------------------
class TestUploadDocument:
    """``POST /api/v1/kb-management/documents/upload`` (multipart file upload)."""

    def test_upload_document(self, client):
        file_content = b"This is a test document content for upload."
        response = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("test.txt", io.BytesIO(file_content), "text/plain")},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["filename"] == "test.txt"
        assert data["document_id"] is not None
        assert data["status"] == "indexed"
        assert data["chunks"] >= 1

    def test_upload_document_with_source_id(self, client):
        # Create a source first and make sure that actually succeeded.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "上传源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        file_content = b"Test content with source ID."
        response = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("doc.txt", io.BytesIO(file_content), "text/plain")},
            data={"source_id": source_id},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["filename"] == "doc.txt"
        # NOTE(review): the returned document's source_id is not asserted;
        # confirm whether the upload route persists the form's source_id.
# ---------------------------------------------------------------------------
# POST /kb-management/search
# ---------------------------------------------------------------------------
class TestSearchKnowledge:
    """``POST /api/v1/kb-management/search`` with the basic request fields."""

    _SEARCH_URL = "/api/v1/kb-management/search"

    def _search(self, client, **payload):
        """Send the keyword arguments as the JSON body of a search request."""
        return client.post(self._SEARCH_URL, json=payload)

    def test_search_returns_results(self, client):
        resp = self._search(client, query="测试查询", top_k=5)
        assert resp.status_code == 200
        body = resp.json()
        assert "results" in body
        assert isinstance(body["results"], list)

    def test_search_with_sources_filter(self, client):
        resp = self._search(client, query="测试", sources=["local"], top_k=3)
        assert resp.status_code == 200

    def test_search_default_top_k(self, client):
        # top_k omitted: the endpoint must accept the request without it.
        resp = self._search(client, query="默认参数")
        assert resp.status_code == 200
# ---------------------------------------------------------------------------
# GET /kb-management/sources/{source_id}/health
# ---------------------------------------------------------------------------
class TestSourceHealth:
    """``GET /api/v1/kb-management/sources/{source_id}/health``."""

    @staticmethod
    def _create_source(client, name, source_type):
        """Create a source through the API and return its id, failing fast on error."""
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": name, "type": source_type, "config": {}},
        )
        # Check creation before indexing into the body, so a broken POST is
        # reported as such rather than as a KeyError on "id".
        assert add_resp.status_code == 201
        return add_resp.json()["id"]

    def test_health_local_source(self, client):
        # Local sources are expected to report "healthy".
        source_id = self._create_source(client, "本地", "local")
        response = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert data["source_id"] == source_id

    def test_health_external_source(self, client):
        # A feishu source is expected to report "unknown".
        source_id = self._create_source(client, "外部", "feishu")
        response = client.get(f"/api/v1/kb-management/sources/{source_id}/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "unknown"

    def test_health_source_not_found(self, client):
        response = client.get("/api/v1/kb-management/sources/nonexistent/health")
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# GET /kb-management/documents
# ---------------------------------------------------------------------------
class TestListDocuments:
    """``GET /api/v1/kb-management/documents`` (optionally filtered by source_id)."""

    def test_list_documents_empty(self, client):
        response = client.get("/api/v1/kb-management/documents")
        assert response.status_code == 200
        data = response.json()
        assert "documents" in data
        assert isinstance(data["documents"], list)

    def test_list_documents_after_upload(self, client):
        file_content = b"Test document for listing."
        upload_resp = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("list_test.txt", io.BytesIO(file_content), "text/plain")},
        )
        assert upload_resp.status_code == 200
        document_id = upload_resp.json()["document_id"]

        response = client.get("/api/v1/kb-management/documents")
        assert response.status_code == 200
        documents = response.json()["documents"]
        # Locate our upload by id instead of assuming it is documents[0];
        # other tests' documents may also be listed.
        doc = next(
            (d for d in documents if d.get("document_id") == document_id), None
        )
        assert doc is not None
        for key in ("document_id", "filename", "source_id", "chunks", "status", "created_at"):
            assert key in doc
        assert doc["filename"] == "list_test.txt"

    def test_list_documents_filter_by_source(self, client):
        # Create a source.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "过滤源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        # Upload a document to that source.
        file_content = b"Filtered document."
        upload_resp = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("filtered.txt", io.BytesIO(file_content), "text/plain")},
            data={"source_id": source_id},
        )
        assert upload_resp.status_code == 200

        response = client.get(
            "/api/v1/kb-management/documents", params={"source_id": source_id}
        )
        assert response.status_code == 200
        filtered = response.json()["documents"]
        # Whatever happened to the upload, the filter must never return
        # documents that belong to a different source.
        for doc in filtered:
            assert doc["source_id"] == source_id

        # If the upload route did not record the source_id (e.g. it fell back
        # to "local"), the "filter finds our document" check cannot apply.
        # Skip visibly instead of passing silently.
        all_docs = client.get("/api/v1/kb-management/documents").json()["documents"]
        if not any(d.get("source_id") == source_id for d in all_docs):
            pytest.skip(
                "upload did not persist source_id; cannot verify the filter "
                "returns the uploaded document"
            )
        assert len(filtered) >= 1
# ---------------------------------------------------------------------------
# DELETE /kb-management/documents/{document_id}
# ---------------------------------------------------------------------------
class TestDeleteDocument:
    """``DELETE /api/v1/kb-management/documents/{document_id}``."""

    def test_delete_document(self, client):
        # Upload a document first and make sure that actually succeeded.
        file_content = b"Document to delete."
        upload_resp = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("delete_me.txt", io.BytesIO(file_content), "text/plain")},
        )
        assert upload_resp.status_code == 200
        document_id = upload_resp.json()["document_id"]

        response = client.delete(f"/api/v1/kb-management/documents/{document_id}")
        assert response.status_code == 200
        assert response.json()["status"] == "deleted"

    def test_delete_document_not_found(self, client):
        response = client.delete("/api/v1/kb-management/documents/nonexistent")
        assert response.status_code == 404

    def test_delete_document_updates_source_count(self, client):
        # Create a source and upload a document into it.
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "计数源", "type": "local", "config": {}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        file_content = b"Count document."
        upload_resp = client.post(
            "/api/v1/kb-management/documents/upload",
            files={"file": ("count.txt", io.BytesIO(file_content), "text/plain")},
            data={"source_id": source_id},
        )
        assert upload_resp.status_code == 200
        document_id = upload_resp.json()["document_id"]

        # Delete the document.
        delete_resp = client.delete(f"/api/v1/kb-management/documents/{document_id}")
        assert delete_resp.status_code == 200
        assert delete_resp.json()["status"] == "deleted"

        # Verify the document is gone from the list.
        docs_resp = client.get("/api/v1/kb-management/documents")
        doc_ids = [d["document_id"] for d in docs_resp.json()["documents"]]
        assert document_id not in doc_ids
        # TODO(review): despite the test name, the source's document count is
        # not checked over HTTP. Add that assertion once it is confirmed the
        # sources API exposes document_count.
# ---------------------------------------------------------------------------
# POST /kb-management/sources/{source_id}/sync
# ---------------------------------------------------------------------------
class TestSyncSource:
    """``POST /api/v1/kb-management/sources/{source_id}/sync``."""

    def test_sync_source(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "同步源", "type": "feishu", "config": {}},
        )
        # Fail fast if creation broke, instead of a KeyError on "id".
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.post(f"/api/v1/kb-management/sources/{source_id}/sync")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "syncing"
        assert data["source_id"] == source_id

    def test_sync_source_not_found(self, client):
        response = client.post("/api/v1/kb-management/sources/nonexistent/sync")
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# PUT /kb-management/sources/{source_id}
# ---------------------------------------------------------------------------
class TestUpdateSource:
    """``PUT /api/v1/kb-management/sources/{source_id}`` (partial updates)."""

    def test_update_source_name(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "原名", "type": "local", "config": {}},
        )
        # Fail fast if creation broke, instead of a KeyError on "id".
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.put(
            f"/api/v1/kb-management/sources/{source_id}",
            json={"name": "新名称"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "新名称"
        assert data["id"] == source_id

    def test_update_source_config(self, client):
        add_resp = client.post(
            "/api/v1/kb-management/sources",
            json={"name": "配置源", "type": "feishu", "config": {"app_id": "old"}},
        )
        assert add_resp.status_code == 201
        source_id = add_resp.json()["id"]

        response = client.put(
            f"/api/v1/kb-management/sources/{source_id}",
            json={"config": {"app_id": "new", "app_secret": "secret"}},
        )
        assert response.status_code == 200
        data = response.json()
        # NOTE(review): the returned config is not compared, possibly because
        # secrets may be masked in responses; confirm before asserting on it.
        assert data["id"] == source_id

    def test_update_source_not_found(self, client):
        response = client.put(
            "/api/v1/kb-management/sources/nonexistent",
            json={"name": "不存在"},
        )
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# POST /kb-management/search with advanced params
# ---------------------------------------------------------------------------
class TestSearchAdvanced:
    """``POST /api/v1/kb-management/search`` with source_ids and strategy options."""

    _SEARCH_URL = "/api/v1/kb-management/search"

    def _search(self, client, **payload):
        """Send the keyword arguments as the JSON body of a search request."""
        return client.post(self._SEARCH_URL, json=payload)

    def test_search_with_source_ids(self, client):
        resp = self._search(
            client, query="高级查询", source_ids=["source-1", "source-2"], top_k=10
        )
        assert resp.status_code == 200
        assert "results" in resp.json()

    def test_search_with_strategy_vector(self, client):
        assert self._search(client, query="向量检索", strategy="vector").status_code == 200

    def test_search_with_strategy_hybrid(self, client):
        assert self._search(client, query="混合检索", strategy="hybrid").status_code == 200

    def test_search_default_strategy(self, client):
        # strategy omitted: the endpoint must accept the request without it.
        assert self._search(client, query="默认策略").status_code == 200
# ---------------------------------------------------------------------------
# KnowledgeSourceStore: delete_document and update_source unit tests
# ---------------------------------------------------------------------------
class TestKnowledgeSourceStoreAdvanced:
    """Unit tests for ``KnowledgeSourceStore.delete_document`` and ``update_source``."""

    def test_delete_document(self):
        store = KnowledgeSourceStore()
        source = store.add_source("本地文档", "local", {})
        from agentkit.server.routes.kb_management import UploadedDocument

        doc = UploadedDocument(
            document_id="doc-del-1",
            filename="delete.pdf",
            source_id=source.id,
            chunks=3,
            status="indexed",
        )
        store.add_document(doc)
        # Re-fetch from the store rather than relying on the local `source`
        # object being the same instance the store mutates.
        assert store.get_source(source.id).document_count == 1

        result = store.delete_document("doc-del-1")
        assert result is True
        assert store.get_source(source.id).document_count == 0
        assert len(store.list_documents()) == 0

    def test_delete_document_not_found(self):
        store = KnowledgeSourceStore()
        result = store.delete_document("nonexistent")
        assert result is False

    def test_update_source_name(self):
        store = KnowledgeSourceStore()
        source = store.add_source("原名", "local", {})
        updated = store.update_source(source.id, {"name": "新名"})
        assert updated is not None
        assert updated.name == "新名"
        # The change must be visible on later lookups, not just in the return value.
        assert store.get_source(source.id).name == "新名"

    def test_update_source_config(self):
        store = KnowledgeSourceStore()
        source = store.add_source("配置源", "feishu", {"app_id": "old"})
        updated = store.update_source(source.id, {"config": {"app_id": "new"}})
        assert updated is not None
        assert updated.config == {"app_id": "new"}
        assert store.get_source(source.id).config == {"app_id": "new"}

    def test_update_source_not_found(self):
        store = KnowledgeSourceStore()
        result = store.update_source("nonexistent", {"name": "不存在"})
        assert result is None