527 lines
21 KiB
Python
527 lines
21 KiB
Python
"""U14 — Skill/专家团队 MCP 发布的单元测试。
|
||
|
||
覆盖场景:
|
||
- 管理员发布 Skill 成功 / 非管理员拒绝 / 无认证拒绝 / Skill 不存在
|
||
- 危险 Skill 默认拒绝 / 显式 opt-in 放行 / 重复发布 409
|
||
- 已发布 Skill 在 /tools/list 可见 / 外部系统调用已发布 Skill
|
||
- DELETE 取消发布 / DELETE 不存在 404 / GET 列出已发布
|
||
- 团队发布 / 团队在 /tools/list 可见
|
||
- PublisherRegistry 单元 CRUD
|
||
- SkillMCPAdapter.execute 无 executor / executor 异常
|
||
- TeamMCPAdapter 名称格式
|
||
- check_dangerous_publish 单元测试
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import httpx
|
||
import jwt
|
||
import pytest
|
||
from fastapi import FastAPI
|
||
|
||
from agentkit.mcp.publisher import (
|
||
PublisherRegistry,
|
||
SkillMCPAdapter,
|
||
TeamMCPAdapter,
|
||
_DANGEROUS_TOOL_NAMES,
|
||
check_dangerous_publish,
|
||
)
|
||
from agentkit.mcp.server import create_mcp_router
|
||
from agentkit.server.auth.middleware import AuthMiddleware
|
||
from agentkit.server.routes import mcp_publish as mcp_publish_routes
|
||
|
||
# Fixed credentials for these unit tests only — never use them in production.
JWT_SECRET = "u14-test-jwt-secret-xxxxxxxxxxxxx"  # HS256 key shared by _make_jwt and AuthMiddleware
API_KEY = "u14-test-api-key-zzz"  # API key handed to AuthMiddleware in _make_app
|
||
|
||
|
||
# ── 测试辅助函数 ──────────────────────────────────────────
|
||
|
||
|
||
def _make_jwt(role: str = "member", user_id: str = "u1", username: str = "alice") -> str:
    """Return a signed HS256 access JWT for the tests.

    The token is signed with ``JWT_SECRET``. It carries a fixed ``iat`` and a
    far-future ``exp``, so it stays valid for any test run.

    Args:
        role: Value of the ``role`` claim (e.g. ``"admin"`` or ``"member"``).
        user_id: Value of the ``sub`` claim.
        username: Value of the ``username`` claim.

    Returns:
        The encoded token as ``str``.
    """
    claims = dict(
        sub=user_id,
        username=username,
        role=role,
        type="access",
        iat=1700000000,
        exp=9999999999,
    )
    encoded = jwt.encode(claims, JWT_SECRET, algorithm="HS256")
    # Older PyJWT releases return bytes; newer ones already return str.
    if isinstance(encoded, bytes):
        return encoded.decode("utf-8")
    return encoded
|
||
|
||
|
||
def _make_mock_skill(
    name: str = "my_skill",
    description: str = "Test skill",
    tool_names: list[str] | None = None,
) -> MagicMock:
    """Build a MagicMock that stands in for a Skill.

    The mock exposes ``name``, ``config.name``, ``config.description`` and a
    ``tools`` list of mocks, each carrying one of ``tool_names`` as ``name``.

    Args:
        name: Skill name, mirrored to both ``name`` and ``config.name``.
        description: Value of ``config.description``.
        tool_names: Names for the mock tools; ``None`` or empty means no tools.

    Returns:
        The configured mock skill.
    """

    def _named_tool(tool_name: str) -> MagicMock:
        # ``name`` is a MagicMock constructor argument, so it has to be
        # assigned after construction to become a plain attribute.
        tool = MagicMock()
        tool.name = tool_name
        return tool

    skill = MagicMock()
    skill.name = name
    skill.config.name = name
    skill.config.description = description
    skill.tools = [_named_tool(tn) for tn in (tool_names or [])]
    return skill
|
||
|
||
|
||
def _make_mock_skill_registry(skills: list) -> MagicMock:
    """Build a MagicMock that stands in for a SkillRegistry.

    ``get(name)`` returns the first entry of ``skills`` whose ``name`` matches,
    and raises ``KeyError(name)`` when there is none. ``list_skills()``
    returns the ``skills`` list object itself.

    Args:
        skills: The skills the fake registry should contain.

    Returns:
        The configured mock registry.
    """
    fake_registry = MagicMock()
    absent = object()

    def _lookup(name: str):
        found = next((candidate for candidate in skills if candidate.name == name), absent)
        if found is absent:
            raise KeyError(name)
        return found

    # A plain function, not a Mock, so a missing skill really raises KeyError.
    fake_registry.get = _lookup
    fake_registry.list_skills.return_value = skills
    return fake_registry
|
||
|
||
|
||
def _make_app(
    skill_registry: Any = None,
    agent_pool: Any = None,
) -> FastAPI:
    """Build the FastAPI app used by the endpoint tests.

    The app does the following:

    - holds a fresh ``PublisherRegistry`` on ``app.state``;
    - is wrapped in ``AuthMiddleware`` with the test JWT secret and API key;
    - mounts the MCP router under ``/api/v1/mcp``;
    - mounts the publish routes under ``/api/v1``.

    Args:
        skill_registry: Stored as ``app.state.skill_registry``.
        agent_pool: Stored as ``app.state.agent_pool``.

    Returns:
        The configured application.
    """
    app = FastAPI()
    initial_state = {
        "tool_registry": None,
        "skill_registry": skill_registry,
        "agent_pool": agent_pool,
        "mcp_publisher_registry": PublisherRegistry(),
    }
    for attr, value in initial_state.items():
        setattr(app.state, attr, value)

    app.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)

    def _published_tools():
        # Read through app.state on every call, so the MCP endpoints always
        # see whatever registry is currently attached to the app.
        return app.state.mcp_publisher_registry.list_published()

    app.include_router(
        create_mcp_router(tool_registry=None, published_tools_getter=_published_tools),
        prefix="/api/v1/mcp",
    )
    app.include_router(mcp_publish_routes.router, prefix="/api/v1")
    return app
|
||
|
||
|
||
def _admin_headers() -> dict[str, str]:
    """Return request headers with a bearer JWT for the ``admin`` role."""
    token = _make_jwt(role="admin")
    return {"Authorization": "Bearer " + token}
|
||
|
||
|
||
def _member_headers() -> dict[str, str]:
    """Return request headers with a bearer JWT for the ``member`` role."""
    token = _make_jwt(role="member")
    return {"Authorization": "Bearer " + token}
|
||
|
||
|
||
# ── 1-3: 认证与权限 ──────────────────────────────────────
|
||
|
||
|
||
class TestPublishAuth:
    """Authentication and permission checks on the Skill publish endpoint."""

    @staticmethod
    async def _publish_my_skill(headers=None):
        """POST ``/api/v1/mcp/publish/skill/my_skill`` against a fresh app.

        The app's registry holds a single mock Skill named ``my_skill``.
        Passing ``headers=None`` sends the request with no credentials.
        """
        registry = _make_mock_skill_registry([_make_mock_skill(name="my_skill")])
        app = _make_app(skill_registry=registry)
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            return await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=headers,
            )

    async def test_admin_publish_skill_success(self):
        """Scenario 1: an admin publishes a Skill -> 200 with the published name."""
        response = await self._publish_my_skill(_admin_headers())
        assert response.status_code == 200
        assert response.json() == {"published": "skill_my_skill", "type": "skill"}

    async def test_member_publish_skill_rejected_403(self):
        """Scenario 2: a ``member`` publishes -> 403 (SYSTEM_CONFIG permission required)."""
        response = await self._publish_my_skill(_member_headers())
        assert response.status_code == 403

    async def test_no_auth_publish_rejected_401(self):
        """Scenario 3: publishing without any credentials -> 401."""
        response = await self._publish_my_skill()
        assert response.status_code == 401
|
||
|
||
|
||
# ── 4-7: Skill 发布业务逻辑 ──────────────────────────────
|
||
|
||
|
||
class TestPublishSkillLogic:
    """Business rules for publishing a Skill: 404, dangerous tools, duplicate publish."""

    async def test_skill_not_found_404(self):
        """Scenario 4: publishing a Skill that does not exist -> 404."""
        registry = _make_mock_skill_registry([])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/publish/skill/nonexistent",
                json={},
                headers=_admin_headers(),
            )
        assert resp.status_code == 404

    async def test_dangerous_skill_default_rejected_403(self):
        """Scenario 5: a Skill with a ``terminal`` tool is rejected (403) by default.

        The request is sent twice: first with ``allow_dangerous`` omitted,
        which is the default a client actually gets, then with an explicit
        ``allow_dangerous=False``. Only the omitted form exercises the
        server-side default.
        """
        skill = _make_mock_skill(name="dangerous_skill", tool_names=["terminal"])
        registry = _make_mock_skill_registry([skill])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            for payload in ({}, {"allow_dangerous": False}):
                resp = await client.post(
                    "/api/v1/mcp/publish/skill/dangerous_skill",
                    json=payload,
                    headers=_admin_headers(),
                )
                # Asserting inside the loop stops at the first non-403. A
                # wrongly accepted publish therefore cannot turn the next
                # request into a 409 and hide the real failure.
                assert resp.status_code == 403, f"payload={payload!r}"

    async def test_dangerous_skill_explicit_optin_success(self):
        """Scenario 6: the same dangerous Skill with explicit allow_dangerous=true -> 200."""
        skill = _make_mock_skill(name="dangerous_skill", tool_names=["terminal"])
        registry = _make_mock_skill_registry([skill])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/publish/skill/dangerous_skill",
                json={"allow_dangerous": True},
                headers=_admin_headers(),
            )
        assert resp.status_code == 200
        assert resp.json() == {"published": "skill_dangerous_skill", "type": "skill"}

    async def test_duplicate_publish_returns_409(self):
        """Scenario 7: publishing the same Skill twice -> the second call gets 409."""
        skill = _make_mock_skill(name="my_skill")
        registry = _make_mock_skill_registry([skill])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            first = await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=_admin_headers(),
            )
            second = await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=_admin_headers(),
            )
        assert first.status_code == 200
        assert second.status_code == 409
|
||
|
||
|
||
# ── 8-9: 已发布工具的可见性与调用 ─────────────────────────
|
||
|
||
|
||
class TestPublishedToolVisibility:
    """Visibility and invocation of published tools through the MCP endpoints."""

    async def test_published_skill_visible_in_tools_list(self):
        """Scenario 8: after publishing, GET /tools/list (as member) lists skill_my_skill."""
        skill = _make_mock_skill(name="my_skill", description="desc")
        registry = _make_mock_skill_registry([skill])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            publish = await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=_admin_headers(),
            )
            # If the setup publish is refused, fail here on the real cause
            # rather than on a confusing "not in names" assertion below.
            assert publish.status_code == 200, publish.text
            resp = await client.get(
                "/api/v1/mcp/tools/list",
                headers=_member_headers(),
            )
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert "skill_my_skill" in names

    async def test_call_published_skill(self):
        """Scenario 9: an external caller invokes the published Skill.

        Expects 200 and the executor's result in the first content item.
        """
        skill = _make_mock_skill(name="my_skill")
        registry = _make_mock_skill_registry([skill])
        # agent_pool.run_skill is mocked to return a fixed result.
        pool = MagicMock()
        pool.run_skill = AsyncMock(return_value={"result": "processed: hello"})
        app = _make_app(skill_registry=registry, agent_pool=pool)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            publish = await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=_admin_headers(),
            )
            assert publish.status_code == 200, publish.text
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "skill_my_skill", "arguments": {"input": "hello"}},
                headers=_member_headers(),
            )
        assert resp.status_code == 200
        body = resp.json()
        assert "processed: hello" in body["content"][0]["text"]
        pool.run_skill.assert_awaited_once_with("my_skill", "hello")
|
||
|
||
|
||
# ── 10-12: 取消发布与列表 ────────────────────────────────
|
||
|
||
|
||
class TestUnpublishAndList:
    """DELETE to unpublish, and GET to list the published tools."""

    async def test_unpublish_success(self):
        """Scenario 10: DELETE after publishing -> 200, and /tools/list no longer lists it."""
        skill = _make_mock_skill(name="my_skill")
        registry = _make_mock_skill_registry([skill])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            publish = await client.post(
                "/api/v1/mcp/publish/skill/my_skill",
                json={},
                headers=_admin_headers(),
            )
            assert publish.status_code == 200, publish.text
            resp = await client.delete(
                "/api/v1/mcp/publish/skill_my_skill",
                headers=_admin_headers(),
            )
            assert resp.status_code == 200
            assert resp.json() == {"unpublished": "skill_my_skill"}
            # The tool must be gone from /tools/list.
            list_resp = await client.get(
                "/api/v1/mcp/tools/list",
                headers=_member_headers(),
            )
        # The negative check below means nothing unless the list call succeeded.
        assert list_resp.status_code == 200
        names = {t["name"] for t in list_resp.json()["tools"]}
        assert "skill_my_skill" not in names

    async def test_unpublish_unknown_404(self):
        """Scenario 11: DELETE of a name that was never published -> 404."""
        app = _make_app(skill_registry=_make_mock_skill_registry([]))
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.delete(
                "/api/v1/mcp/publish/skill_unknown",
                headers=_admin_headers(),
            )
        assert resp.status_code == 404

    async def test_list_published(self):
        """Scenario 12: after publishing two Skills, GET /mcp/publish -> 200 with both entries."""
        skill_names = ("skill_a", "skill_b")
        registry = _make_mock_skill_registry([_make_mock_skill(name=n) for n in skill_names])
        app = _make_app(skill_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            for skill_name in skill_names:
                publish = await client.post(
                    f"/api/v1/mcp/publish/skill/{skill_name}",
                    json={},
                    headers=_admin_headers(),
                )
                assert publish.status_code == 200, publish.text
            resp = await client.get(
                "/api/v1/mcp/publish",
                headers=_admin_headers(),
            )
        assert resp.status_code == 200
        names = {item["name"] for item in resp.json()["published"]}
        assert names == {"skill_skill_a", "skill_skill_b"}
|
||
|
||
|
||
# ── 13-14: 团队发布 ──────────────────────────────────────
|
||
|
||
|
||
class TestPublishTeam:
    """Publishing an expert team as an MCP tool."""

    async def test_publish_team_success(self):
        """Scenario 13: an admin publishes a team -> 200 with team_dev_team."""
        app = _make_app(skill_registry=_make_mock_skill_registry([]))
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/publish/team/dev_team",
                json={},
                headers=_admin_headers(),
            )
        assert resp.status_code == 200
        assert resp.json() == {"published": "team_dev_team", "type": "team"}

    async def test_team_visible_in_tools_list(self):
        """Scenario 14: after publishing a team, GET /tools/list contains team_dev_team."""
        app = _make_app(skill_registry=_make_mock_skill_registry([]))
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            publish = await client.post(
                "/api/v1/mcp/publish/team/dev_team",
                json={},
                headers=_admin_headers(),
            )
            # Surface a refused setup publish directly instead of as a missing name.
            assert publish.status_code == 200, publish.text
            resp = await client.get(
                "/api/v1/mcp/tools/list",
                headers=_member_headers(),
            )
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert "team_dev_team" in names
|
||
|
||
|
||
# ── 15-19: PublisherRegistry 与 Adapter 单元测试 ─────────
|
||
|
||
|
||
class TestPublisherRegistry:
    """Basic CRUD behaviour of PublisherRegistry."""

    def test_register_unregister_list_get(self):
        """Scenario 15: full register / unregister / list_published / get round trip."""
        registry = PublisherRegistry()
        adapter = _make_mock_skill_adapter("a")
        expected_name = "skill_a"

        # The adapter name carries the ``skill_`` prefix.
        assert adapter.name == expected_name

        # A fresh registry is empty.
        assert registry.list_published() == []
        assert registry.get(expected_name) is None

        registry.register(adapter)
        assert registry.get(expected_name) is adapter
        assert [tool.name for tool in registry.list_published()] == [expected_name]

        assert registry.unregister(expected_name) is True
        assert registry.get(expected_name) is None
        assert registry.list_published() == []
        # Unregistering the same name again reports False.
        assert registry.unregister(expected_name) is False

    def test_register_duplicate_raises(self):
        """Registering a second tool under an existing name raises ValueError."""
        registry = PublisherRegistry()
        registry.register(_make_mock_skill_adapter("a"))
        with pytest.raises(ValueError):
            # Derives the same name ("skill_a") as the adapter already registered.
            registry.register(_make_mock_skill_adapter("a"))
|
||
|
||
|
||
class TestSkillMCPAdapter:
    """Behaviour of SkillMCPAdapter.execute and its MCP metadata."""

    async def test_execute_without_executor_returns_error(self):
        """Scenario 16: with ``executor=None``, execute returns an error dict."""
        adapter = SkillMCPAdapter(_make_mock_skill(name="s1"), executor=None)
        outcome = await adapter.execute(input="hi")
        assert "error" in outcome
        assert "executor not configured" in outcome["error"]

    async def test_execute_with_raising_executor_returns_error(self):
        """Scenario 17: an executor that raises is caught and reported as an error dict."""

        async def exploding_executor(_name: str, _input: str) -> dict[str, Any]:
            raise RuntimeError("boom")

        adapter = SkillMCPAdapter(_make_mock_skill(name="s1"), executor=exploding_executor)
        outcome = await adapter.execute(input="hi")
        assert "error" in outcome
        message = outcome["error"]
        assert "skill execution failed" in message
        assert "boom" in message

    async def test_execute_with_executor_returns_result(self):
        """An executor that returns normally has its result passed straight through."""

        async def echoing_executor(_name: str, _input: str) -> dict[str, Any]:
            return {"output": "done:" + _input}

        adapter = SkillMCPAdapter(_make_mock_skill(name="s1"), executor=echoing_executor)
        outcome = await adapter.execute(input="hi")
        assert outcome == {"output": "done:hi"}

    def test_skill_adapter_name_format(self):
        """The adapter is named ``skill_{name}`` and requires an ``input`` argument."""
        adapter = SkillMCPAdapter(_make_mock_skill(name="my_skill", description="d"))
        assert (adapter.name, adapter.description) == ("skill_my_skill", "d")
        assert adapter.input_schema["required"] == ["input"]
|
||
|
||
|
||
class TestTeamMCPAdapter:
    """Naming and execution of TeamMCPAdapter."""

    def test_team_adapter_name_format(self):
        """Scenario 18: the team adapter is named ``team_{name}`` and tagged ``team``."""
        team_adapter = TeamMCPAdapter("dev_team")
        assert team_adapter.name == "team_dev_team"
        assert "team" in team_adapter.tags

    async def test_team_execute_without_executor(self):
        """With ``executor=None``, a team execute returns an error dict."""
        outcome = await TeamMCPAdapter("dev_team", executor=None).execute(input="x")
        assert "error" in outcome
        assert "executor not configured" in outcome["error"]
|
||
|
||
|
||
class TestCheckDangerousPublish:
    """Validation performed by check_dangerous_publish."""

    def test_safe_skill_no_raise(self):
        """Scenario 19: a Skill without dangerous tools passes without raising."""
        safe_skill = _make_mock_skill(name="s", tool_names=["web_search", "rag"])
        # Must return normally.
        check_dangerous_publish(safe_skill, allow_dangerous=False)

    def test_dangerous_skill_without_optin_raises(self):
        """A ``shell`` tool with ``allow_dangerous=False`` raises ValueError."""
        risky_skill = _make_mock_skill(name="s", tool_names=["shell"])
        with pytest.raises(ValueError):
            check_dangerous_publish(risky_skill, allow_dangerous=False)

    def test_dangerous_skill_with_optin_no_raise(self):
        """Dangerous tools with ``allow_dangerous=True`` pass without raising."""
        risky_skill = _make_mock_skill(name="s", tool_names=["shell", "file_write"])
        check_dangerous_publish(risky_skill, allow_dangerous=True)

    def test_dangerous_tool_names_contains_expected(self):
        """_DANGEROUS_TOOL_NAMES contains every expected high-risk tool name."""
        expected = ("terminal", "shell", "file_write", "file_read", "file_delete")
        missing = [tool for tool in expected if tool not in _DANGEROUS_TOOL_NAMES]
        assert not missing, missing
|
||
|
||
|
||
# ── 辅助:构造 SkillMCPAdapter 用于 registry 测试 ─────────
|
||
|
||
|
||
def _make_mock_skill_adapter(name: str) -> SkillMCPAdapter:
    """Wrap a mock Skill named ``name`` in a SkillMCPAdapter with no executor.

    Used by the PublisherRegistry tests.
    """
    return SkillMCPAdapter(_make_mock_skill(name=name), executor=None)
|