"""Real LLM E2E tests — tests against a live server with real LLM providers.

These tests start a real AgentKit server using the project's ``agentkit.yaml``
configuration and make actual LLM API calls to Bailian (DashScope).

Requirements:

- ``DASHSCOPE_API_KEY`` set either in the environment or in the project's
  ``.env`` file
- Network access to ``https://coding.dashscope.aliyuncs.com/v1``

Run with::

    .venv/bin/python -m pytest tests/e2e/test_real_llm_e2e.py -v --timeout=180

All tests are marked with ``@pytest.mark.integration`` so they are excluded
from the default unit-test run (``pytest -m "not integration"``).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Generator
|
|
|
|
import aiosqlite
|
|
import httpx
|
|
import pytest
|
|
|
|
# Disable HTTP proxies for localhost requests (Clash/V2Ray intercepts localhost).
# The exclusion list is written under both the upper- and lower-case variable
# names, and it is applied at import time, i.e. for the whole test process.
os.environ["NO_PROXY"] = "127.0.0.1,localhost"
os.environ["no_proxy"] = "127.0.0.1,localhost"
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Repository root: ``parents[2]`` of this file's resolved path, i.e. two levels
# above the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Address the real-LLM server subprocess listens on, plus the derived HTTP and
# WebSocket URL prefixes used by the tests.
REAL_LLM_HOST = "127.0.0.1"
REAL_LLM_PORT = 18766  # dedicated port to avoid conflict with mock E2E (18765)
REAL_LLM_BASE_URL = f"http://{REAL_LLM_HOST}:{REAL_LLM_PORT}"
REAL_LLM_WS_URL = f"ws://{REAL_LLM_HOST}:{REAL_LLM_PORT}"

# Fixed JWT secret so tokens are deterministic across the session.
TEST_JWT_SECRET = "test-jwt-secret-for-real-llm-e2e-fixed-do-not-use-in-prod"

# Test user credentials (created directly in the auth DB).
TEST_USERNAME = "real_llm_e2e_user"
TEST_PASSWORD = "TestPassword123!@#"
TEST_EMAIL = "real_llm_e2e@example.com"

# Model alias from agentkit.yaml (resolves to bailian-coding/qwen3.7-plus).
# NOTE(review): the resolved model name is taken from the original comment and
# cannot be verified from this file — confirm against agentkit.yaml.
TEST_MODEL = "default"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# .env loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _load_dotenv_vars(dotenv_path: Path) -> dict[str, str]:
    """Parse a ``.env`` file into a dict (does not touch ``os.environ``).

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored.  A shell-style ``export`` keyword in front of the variable name
    is dropped, and exactly one pair of matching surrounding quotes is
    removed from the value.

    Args:
        dotenv_path: Location of the ``.env`` file.

    Returns:
        Mapping of variable name to value.  Empty when *dotenv_path* is not
        an existing regular file.
    """
    env_vars: dict[str, str] = {}
    # ``is_file`` (rather than ``exists``) so a directory named ``.env`` is
    # treated as "no file" instead of crashing in ``open``.
    if not dotenv_path.is_file():
        return env_vars
    # ``utf-8-sig`` transparently drops a BOM, which would otherwise become
    # part of the first variable name.
    with open(dotenv_path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if not sep:
                continue
            key = key.strip()
            # Accept ``export KEY=value`` lines as written for shell sourcing.
            name_parts = key.split(None, 1)
            if len(name_parts) == 2 and name_parts[0] == "export":
                key = name_parts[1].strip()
            value = value.strip()
            # Remove one pair of matching quotes only; quote characters that
            # are part of the value itself (e.g. "'x'") are preserved.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            if key:
                env_vars[key] = value
    return env_vars
|
|
|
|
|
|
def _has_dashscope_key() -> bool:
    """Tell whether a non-empty ``DASHSCOPE_API_KEY`` can be found.

    The process environment is consulted first; only when it holds no
    (or an empty) value is the project's ``.env`` file parsed.
    """
    if not os.environ.get("DASHSCOPE_API_KEY"):
        # Fall back to the .env file in the project root.
        from_file = _load_dotenv_vars(PROJECT_ROOT / ".env").get("DASHSCOPE_API_KEY")
        return bool(from_file)
    return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test user creation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _create_test_user(auth_db_path: Path) -> None:
    """Insert the E2E test user straight into the SQLite auth database.

    The schema is initialised first via the project's ``init_auth_db``; any
    existing row with ``TEST_USERNAME`` is deleted and a fresh admin row is
    written.  The password is hashed with the project's ``hash_password``
    utility so the ``/auth/login`` route can verify it.

    Args:
        auth_db_path: Path of the SQLite auth database file.
    """
    from agentkit.server.auth.models import init_auth_db
    from agentkit.server.auth.password import hash_password

    # Make sure the tables exist before writing to them.
    asyncio.run(init_auth_db(auth_db_path))

    new_id = str(uuid.uuid4())
    hashed = hash_password(TEST_PASSWORD)
    stamp = datetime.now(timezone.utc).isoformat()

    insert_sql = (
        "INSERT INTO users "
        "(id, username, email, password_hash, role, is_active, "
        " is_terminal_authorized, is_server_terminal_authorized, "
        " created_at, updated_at) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    # "admin" role → full access for tests; the three 1s are is_active,
    # is_terminal_authorized and is_server_terminal_authorized.
    row = (new_id, TEST_USERNAME, TEST_EMAIL, hashed, "admin", 1, 1, 1, stamp, stamp)

    async def _replace_user() -> None:
        async with aiosqlite.connect(str(auth_db_path)) as conn:
            # Drop a leftover row from an earlier run before inserting.
            await conn.execute("DELETE FROM users WHERE username = ?", (TEST_USERNAME,))
            await conn.execute(insert_sql, row)
            await conn.commit()

    asyncio.run(_replace_user())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session-scoped server fixture
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _tail_of_log(log_path: Path, limit: int = 2000) -> str:
    """Return at most the last *limit* characters of *log_path* ('' if unreadable)."""
    try:
        return log_path.read_text(encoding="utf-8", errors="replace")[-limit:]
    except OSError:
        return ""


def _stop_server_process(proc: subprocess.Popen) -> None:
    """Terminate *proc*, escalating to ``kill`` if it has not exited after 10s."""
    if proc.poll() is not None:
        return  # already exited
    proc.terminate()
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


@pytest.fixture(scope="session")
def real_llm_server(
    tmp_path_factory: pytest.TempPathFactory,
) -> Generator[tuple[str, Path], None, None]:
    """Start a real AgentKit server with actual LLM providers.

    Yields ``(base_url, auth_db_path)``.  The server uses the project root's
    ``agentkit.yaml`` (Bailian coding plan) — no mock providers.

    The server runs as a subprocess whose stdout and stderr are written to
    log files inside a session temp directory.  The fixture waits up to 60s
    for ``/api/v1/health`` to answer 200, creates the test user, and on
    teardown (or on any setup error) stops the subprocess.

    Skips the entire session if ``DASHSCOPE_API_KEY`` is not available.
    Fails the fixture if the server exits early or never becomes healthy.
    """
    if not _has_dashscope_key():
        pytest.skip("DASHSCOPE_API_KEY not set — skipping real LLM E2E tests")

    tmp_path = tmp_path_factory.mktemp("real_llm_server")
    auth_db_path = tmp_path / "auth.db"

    # Build subprocess environment.
    env = os.environ.copy()

    # Disable HTTP proxies so localhost requests don't go through Clash/V2Ray.
    for proxy_var in ("HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy", "ALL_PROXY", "all_proxy"):
        env.pop(proxy_var, None)
    env["NO_PROXY"] = "127.0.0.1,localhost"
    env["no_proxy"] = "127.0.0.1,localhost"

    # Ensure API keys from .env are available to the subprocess; values
    # already present (and non-empty) in the environment take precedence.
    for key, value in _load_dotenv_vars(PROJECT_ROOT / ".env").items():
        if not env.get(key):
            env[key] = value

    # Auth configuration.
    env["AGENTKIT_JWT_SECRET"] = TEST_JWT_SECRET
    env["AGENTKIT_AUTH_DB"] = str(auth_db_path)

    # GUI mode creates a default chat agent (needed for chat / WebSocket tests).
    env["AGENTKIT_GUI_MODE"] = "1"

    # Explicit config path (also auto-discovered via CWD, but set explicitly).
    env["AGENTKIT_CONFIG_PATH"] = str(PROJECT_ROOT / "agentkit.yaml")

    # Both output streams go to files.  A pipe that nobody drains would block
    # the server once the OS pipe buffer fills, and files let us show the
    # actual server output when startup fails.
    stdout_log = tmp_path / "server_stdout.log"
    stderr_log = tmp_path / "server_stderr.log"
    base_url = REAL_LLM_BASE_URL

    def _diagnostics() -> str:
        return (
            f"stdout: {_tail_of_log(stdout_log)}\n"
            f"stderr: {_tail_of_log(stderr_log)}"
        )

    with open(stdout_log, "w", encoding="utf-8") as stdout_fh:
        with open(stderr_log, "w", encoding="utf-8") as stderr_fh:
            # Start the server via uvicorn directly (agentkit serve has
            # interactive prompts that fail in non-tty subprocess environments).
            proc = subprocess.Popen(
                [
                    sys.executable,
                    "-c",
                    "import uvicorn; uvicorn.run("
                    "'agentkit.server.app:create_app', "
                    f"host='{REAL_LLM_HOST}', port={REAL_LLM_PORT}, factory=True)",
                ],
                env=env,
                stdout=stdout_fh,
                stderr=stderr_fh,
                cwd=str(PROJECT_ROOT),
            )
            # From here on the process must be stopped on every exit path,
            # including failures during startup or user creation.
            try:
                # Wait for the server to become healthy (max 60s — real LLM
                # server initialization is slower than the mock E2E server).
                deadline = time.monotonic() + 60
                ready = False
                while time.monotonic() < deadline:
                    if proc.poll() is not None:
                        pytest.fail("Real LLM server exited early.\n" + _diagnostics())
                    try:
                        resp = httpx.get(f"{base_url}/api/v1/health", timeout=2)
                    except httpx.TransportError:
                        # Not accepting connections yet, or too slow to
                        # answer within 2s while initialising — keep polling.
                        pass
                    else:
                        if resp.status_code == 200:
                            ready = True
                            break
                    time.sleep(0.5)

                if not ready:
                    # Stop first so the logs are complete when we read them.
                    _stop_server_process(proc)
                    pytest.fail(
                        "Real LLM server failed to start within 60s.\n" + _diagnostics()
                    )

                # Create the test user now that the server (and auth DB schema) is up.
                _create_test_user(auth_db_path)

                yield base_url, auth_db_path
            finally:
                # Teardown — terminate the server process.
                _stop_server_process(proc)

    # If the server logged any errors, print them for debugging.
    if stderr_log.exists():
        log_content = stderr_log.read_text(encoding="utf-8", errors="replace")
        if "Error" in log_content or "Traceback" in log_content:
            print(f"\n--- Server stderr log ---\n{log_content[-3000:]}\n--- End server log ---")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Convenience fixtures derived from real_llm_server
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(scope="session")
def base_url(real_llm_server: tuple[str, Path]) -> str:
    """Expose the server's base URL (first item of ``real_llm_server``)."""
    server_info = real_llm_server
    return server_info[0]
|
|
|
|
|
|
@pytest.fixture(scope="session")
def auth_db_path(real_llm_server: tuple[str, Path]) -> Path:
    """Expose the auth database path (second item of ``real_llm_server``)."""
    server_info = real_llm_server
    return server_info[1]
|
|
|
|
|
|
def _login_with_retry(
    base_url: str, max_retries: int = 3, delay: float = 1.0
) -> httpx.Response:
    """POST the test credentials to the login route, retrying on HTTP 500.

    A 500 answer is retried (after sleeping *delay* seconds) as long as
    attempts remain; the original author attributes such 500s to transient
    SQLite write-lock contention.  Any other status — and the 500 of the
    final attempt — is returned to the caller as-is.

    Args:
        base_url: Server base URL.
        max_retries: Total number of attempts; must be positive.
        delay: Seconds to sleep between attempts.

    Returns:
        The last response received.

    Raises:
        ValueError: If *max_retries* is not positive.
    """
    if max_retries <= 0:
        raise ValueError("max_retries must be > 0")
    last_index = max_retries - 1
    with httpx.Client(base_url=base_url, timeout=30) as http:
        for attempt_no in range(max_retries):
            response = http.post(
                "/api/v1/auth/login",
                json={"username": TEST_USERNAME, "password": TEST_PASSWORD},
            )
            should_retry = response.status_code == 500 and attempt_no < last_index
            if not should_retry:
                return response
            time.sleep(delay)
    raise RuntimeError("unreachable: loop should have returned")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def auth_token(base_url: str) -> str:
    """Authenticate once for the session and hand out the access token."""
    login = _login_with_retry(base_url)
    assert login.status_code == 200, (
        f"Login failed: {login.status_code} {login.text[:1000]}"
    )
    payload = login.json()
    assert "access_token" in payload
    token = payload["access_token"]
    return token
|
|
|
|
|
|
@pytest.fixture(scope="session")
def refresh_token(base_url: str) -> str:
    """Log in once per session and return the refresh token.

    Performs its own login (independent of ``auth_token``) and fails with a
    descriptive assertion if the login is rejected or the response carries
    no ``refresh_token`` field.
    """
    resp = _login_with_retry(base_url)
    assert resp.status_code == 200, (
        f"Login failed: {resp.status_code} {resp.text[:1000]}"
    )
    data = resp.json()
    # Mirror the access-token fixture: report a missing field as a clear
    # assertion failure rather than a bare KeyError.
    assert "refresh_token" in data, (
        f"Login response has no refresh_token; keys: {sorted(data)}"
    )
    return data["refresh_token"]
|
|
|
|
|
|
@pytest.fixture(scope="session")
def auth_headers(auth_token: str) -> dict[str, str]:
    """Build the standard request headers: Bearer JWT plus JSON content type."""
    headers: dict[str, str] = {}
    headers["Authorization"] = f"Bearer {auth_token}"
    headers["Content-Type"] = "application/json"
    return headers
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 1. Authentication Flow Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.timeout(30)
class TestAuthFlow:
    """Exercise login, profile lookup and token refresh on the live server."""

    def test_login_success(self, base_url: str):
        """Valid credentials on POST /auth/login yield a bearer JWT pair."""
        credentials = {"username": TEST_USERNAME, "password": TEST_PASSWORD}
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.post("/api/v1/auth/login", json=credentials)
            assert response.status_code == 200
            body = response.json()
            for token_field in ("access_token", "refresh_token"):
                assert token_field in body
            assert data_matches(body["token_type"], "bearer") if False else body["token_type"] == "bearer"
            user = body["user"]
            assert user["username"] == TEST_USERNAME
            assert user["role"] == "admin"

    def test_login_wrong_password(self, base_url: str):
        """A wrong password on POST /auth/login is rejected with 401."""
        bad_credentials = {"username": TEST_USERNAME, "password": "definitely-wrong"}
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.post("/api/v1/auth/login", json=bad_credentials)
            assert response.status_code == 401

    def test_me_with_valid_token(self, base_url: str, auth_headers: dict[str, str]):
        """GET /auth/me with a valid JWT returns the test user's profile."""
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.get("/api/v1/auth/me", headers=auth_headers)
            assert response.status_code == 200
            profile = response.json()
            assert profile["username"] == TEST_USERNAME
            assert profile["email"] == TEST_EMAIL
            assert profile["role"] == "admin"
            assert profile["is_active"] is True

    def test_me_without_token_returns_401(self, base_url: str):
        """GET /auth/me with no Authorization header is rejected with 401."""
        with httpx.Client(base_url=base_url, timeout=10) as http:
            response = http.get("/api/v1/auth/me")
            assert response.status_code == 401

    def test_refresh_token(self, base_url: str, refresh_token: str):
        """POST /auth/refresh trades a refresh token for a new access token."""
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.post(
                "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
            )
            assert response.status_code == 200
            body = response.json()
            assert "access_token" in body
            assert body["user"]["username"] == TEST_USERNAME
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 2. LLM Gateway Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.timeout(120)
class TestLLMGateway:
    """Check that the LLM gateway proxy relays genuine model output."""

    @staticmethod
    def _contains_cjk(text: str) -> bool:
        """True when *text* has a character in U+4E00..U+9FFF."""
        return any("\u4e00" <= ch <= "\u9fff" for ch in text)

    def test_chat_non_streaming(self, base_url: str, auth_headers: dict[str, str]):
        """POST /llm/chat answers with non-empty, Chinese-language content."""
        request_body = {
            "messages": [{"role": "user", "content": "你好,请用一句话介绍自己"}],
            "model": TEST_MODEL,
            "temperature": 0.7,
            "max_tokens": 200,
        }
        with httpx.Client(base_url=base_url, timeout=90) as http:
            response = http.post(
                "/api/v1/llm/chat", headers=auth_headers, json=request_body
            )
            assert response.status_code == 200
            body = response.json()
            assert "content" in body
            answer: str = body["content"]
            assert len(answer) > 0
            # The prompt is Chinese, so a real reply is expected to be too.
            assert self._contains_cjk(answer)
            assert "model" in body
            assert "usage" in body

    def test_chat_streaming_sse(self, base_url: str, auth_headers: dict[str, str]):
        """POST /llm/chat/stream delivers SSE ``data:`` chunks with content."""
        data_prefix = "data: "
        request_body = {
            "messages": [{"role": "user", "content": "用一句话说明什么是人工智能"}],
            "model": TEST_MODEL,
            "temperature": 0.7,
            "max_tokens": 200,
        }
        events: list[dict[str, Any]] = []
        with httpx.Client(base_url=base_url, timeout=90) as http:
            with http.stream(
                "POST",
                "/api/v1/llm/chat/stream",
                headers=auth_headers,
                json=request_body,
            ) as response:
                assert response.status_code == 200
                for sse_line in response.iter_lines():
                    if sse_line.startswith(data_prefix):
                        data = sse_line[len(data_prefix):]
                        # "[DONE]" is the end-of-stream sentinel.
                        if data == "[DONE]":
                            break
                        events.append(json.loads(data))

        assert len(events) > 0
        pieces = [event.get("content", "") for event in events]
        streamed_text = "".join(pieces)
        assert len(streamed_text) > 0
        assert self._contains_cjk(streamed_text)

    def test_chat_invalid_model_returns_error(self, base_url: str, auth_headers: dict[str, str]):
        """POST /llm/chat with an unknown model answers 404 or 502."""
        request_body = {
            "messages": [{"role": "user", "content": "test"}],
            "model": "nonexistent-model-xyz-12345",
        }
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.post(
                "/api/v1/llm/chat", headers=auth_headers, json=request_body
            )
            assert response.status_code in (404, 502)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 3. Chat REST API Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(scope="class")
def chat_session_id(base_url: str, auth_headers: dict[str, str]) -> str:
    """Open a chat session on the ``default`` agent and return its ID.

    The ``default`` agent is the one the server creates in GUI mode.
    """
    with httpx.Client(base_url=base_url, timeout=30) as http:
        created = http.post(
            "/api/v1/chat/sessions",
            headers=auth_headers,
            json={"agent_name": "default"},
        )
        assert created.status_code in (200, 201), f"Failed to create chat session: {created.text}"
        session = created.json()
        return session["session_id"]
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.timeout(120)
class TestChatAPI:
    """Check that the chat REST endpoints return genuine model replies."""

    def test_create_session(self, chat_session_id: str):
        """The session fixture produced a non-empty session ID."""
        assert chat_session_id
        assert len(chat_session_id) > 0

    def test_send_message_and_get_real_response(
        self, base_url: str, auth_headers: dict[str, str], chat_session_id: str
    ):
        """POST /chat/sessions/{id}/messages answers with an assistant reply."""
        messages_path = f"/api/v1/chat/sessions/{chat_session_id}/messages"
        with httpx.Client(base_url=base_url, timeout=90) as http:
            response = http.post(
                messages_path,
                headers=auth_headers,
                json={"content": "你好,请用一句话介绍自己"},
            )
            assert response.status_code == 200
            reply = response.json()
            assert reply["role"] == "assistant"
            text: str = reply["content"]
            assert len(text) > 0
            # A reply mentioning "mock" would indicate a mock provider.
            assert "mock" not in text.lower()
            # The prompt is Chinese, so a real reply is expected to be too.
            assert any("\u4e00" <= ch <= "\u9fff" for ch in text)

    def test_message_history_after_conversation(
        self, base_url: str, auth_headers: dict[str, str], chat_session_id: str
    ):
        """GET /chat/sessions/{id}/messages lists user and assistant turns."""
        messages_path = f"/api/v1/chat/sessions/{chat_session_id}/messages"
        with httpx.Client(base_url=base_url, timeout=30) as http:
            response = http.get(messages_path, headers=auth_headers)
            assert response.status_code == 200
            history = response.json()
            assert isinstance(history, list)
            assert len(history) >= 2  # at least one user + one assistant
            seen_roles = [entry["role"] for entry in history]
            assert "user" in seen_roles
            assert "assistant" in seen_roles
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 4. WebSocket Chat Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.timeout(120)
class TestWebSocketChat:
    """Verify the WebSocket chat protocol with real LLM streaming."""

    @staticmethod
    def _create_session(base_url: str, auth_token: str) -> str:
        """Create a chat session on the ``default`` agent via REST; return its ID."""
        with httpx.Client(base_url=base_url, timeout=30) as client:
            resp = client.post(
                "/api/v1/chat/sessions",
                headers={
                    "Authorization": f"Bearer {auth_token}",
                    "Content-Type": "application/json",
                },
                json={"agent_name": "default"},
            )
            assert resp.status_code in (200, 201), (
                f"Failed to create chat session: {resp.text}"
            )
            return resp.json()["session_id"]

    @pytest.mark.asyncio
    async def test_websocket_full_chat_flow(self, base_url: str, auth_token: str):
        """Connect → send message → receive final_answer with real LLM content."""
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets package not installed")

        session_id = self._create_session(base_url, auth_token)

        # Connect to the WebSocket (JWT passed via ?token= query param).
        ws_url = f"{REAL_LLM_WS_URL}/api/v1/chat/ws/{session_id}?token={auth_token}"
        received: list[dict[str, Any]] = []

        async with websockets.connect(ws_url) as ws:  # type: ignore[name-defined]
            # 1. Expect a connected event.
            raw = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(raw)
            received.append(data)
            assert data["type"] == "connected"

            # 2. Send a user message.
            await ws.send(json.dumps({"type": "message", "content": "你好,请用一句话介绍自己"}))

            # 3. Collect events until final_answer / error / timeout.  Each
            #    recv waits only for the time left until the overall deadline,
            #    so the whole loop is bounded by 90s and stays inside the
            #    class-level 120s pytest timeout.
            deadline = time.monotonic() + 90
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    received.append({"type": "timeout"})
                    break
                try:
                    raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
                except asyncio.TimeoutError:
                    received.append({"type": "timeout"})
                    break
                msg = json.loads(raw)
                received.append(msg)
                if msg.get("type") in ("final_answer", "error"):
                    break

        # 4. Assert we got a final_answer (not an error).
        types = [m.get("type") for m in received]
        assert "connected" in types
        final_msgs = [m for m in received if m.get("type") == "final_answer"]
        assert final_msgs, f"Expected final_answer, got event types: {types}"

        final_content: str = final_msgs[0].get("content", "")
        assert len(final_content) > 0
        # Must not be a mock response.
        assert "mock" not in final_content.lower()
        # Real LLM response should contain Chinese characters.
        assert any("\u4e00" <= ch <= "\u9fff" for ch in final_content)

    @pytest.mark.asyncio
    async def test_websocket_ping_pong(self, base_url: str, auth_token: str):
        """WebSocket ping/pong heartbeat works alongside the chat session."""
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets package not installed")

        session_id = self._create_session(base_url, auth_token)

        ws_url = f"{REAL_LLM_WS_URL}/api/v1/chat/ws/{session_id}?token={auth_token}"
        async with websockets.connect(ws_url) as ws:  # type: ignore[name-defined]
            # Wait for connected.
            await asyncio.wait_for(ws.recv(), timeout=10)

            # Send ping → expect pong.
            await ws.send(json.dumps({"type": "ping"}))
            raw = await asyncio.wait_for(ws.recv(), timeout=10)
            msg = json.loads(raw)
            assert msg["type"] == "pong"
|