172 lines
5.4 KiB
Python
172 lines
5.4 KiB
Python
"""Shared test fixtures for fischer-agentkit"""
|
|
|
|
import os
|
|
|
|
# Turn off the WebSocket heartbeat timeout for the test run so idle connections
# do not hang for 120s. This must execute before the portal module is imported,
# because that module reads AGENTKIT_WS_TIMEOUT at module level. A value already
# present in the environment is respected and left untouched.
if "AGENTKIT_WS_TIMEOUT" not in os.environ:
    os.environ["AGENTKIT_WS_TIMEOUT"] = "0"
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
|
|
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
|
|
|
|
|
|
# ── Task/Result Factory Fixtures ──────────────────────────
|
|
|
|
|
|
@pytest.fixture
def make_task():
    """Factory fixture for creating TaskMessage instances.

    Returns a builder whose keyword arguments mirror ``TaskMessage`` fields.
    Every call advances a sequence number private to this fixture instance.
    When ``task_id`` is falsy, that number is used to synthesize ``task-001``,
    ``task-002``, and so on. A falsy ``input_data`` becomes a fresh empty dict.
    ``created_at`` is stamped with the current UTC time.
    """
    seq = 0

    def _make_task(
        task_id: str | None = None,
        agent_name: str = "test_agent",
        task_type: str = "test_task",
        priority: int = 1,
        input_data: dict | None = None,
        callback_url: str | None = None,
        timeout_seconds: int = 300,
        conversation_id: str | None = None,
    ) -> TaskMessage:
        nonlocal seq
        seq += 1
        resolved_id = task_id if task_id else f"task-{seq:03d}"
        payload = input_data if input_data else {}
        stamp = datetime.now(timezone.utc)
        return TaskMessage(
            task_id=resolved_id,
            agent_name=agent_name,
            task_type=task_type,
            priority=priority,
            input_data=payload,
            callback_url=callback_url,
            created_at=stamp,
            timeout_seconds=timeout_seconds,
            conversation_id=conversation_id,
        )

    return _make_task
|
|
|
|
|
|
@pytest.fixture
def make_result():
    """Factory fixture for creating TaskResult instances.

    Returns a builder whose keyword arguments mirror ``TaskResult`` fields.

    - ``task_id``: when falsy, a sequential ``task-NNN`` id is synthesized from
      a counter private to this fixture instance.
    - ``output_data``: defaults to ``{"result": "ok"}`` only when omitted
      (``None``). An explicit value, including ``{}``, is passed through so
      tests can model results with empty output.
    - ``started_at`` / ``completed_at``: both set to the same current UTC
      timestamp.
    """
    counter = [0]

    def _make_result(
        task_id: str | None = None,
        agent_name: str = "test_agent",
        status: str = TaskStatus.COMPLETED,
        output_data: dict | None = None,
        error_message: str | None = None,
        metrics: dict | None = None,
    ) -> TaskResult:
        counter[0] += 1
        now = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task_id or f"task-{counter[0]:03d}",
            agent_name=agent_name,
            status=status,
            # `is None` rather than `or`: an explicit {} must not be replaced
            # by the placeholder payload.
            output_data={"result": "ok"} if output_data is None else output_data,
            error_message=error_message,
            started_at=now,
            completed_at=now,
            metrics=metrics,
        )

    return _make_result
|
|
|
|
|
|
@pytest.fixture
def make_capability():
    """Factory fixture for creating AgentCapability instances.

    Returns a builder whose keyword arguments mirror ``AgentCapability``
    fields, each with a test default.

    - ``supported_tasks``: defaults to ``["test_task"]`` only when omitted
      (``None``). An explicit empty list is passed through so tests can model
      an agent that supports no tasks.
    - ``input_schema`` / ``output_schema``: passed through as given.
    """

    def _make_capability(
        agent_name: str = "test_agent",
        agent_type: str = "test",
        version: str = "1.0.0",
        supported_tasks: list[str] | None = None,
        max_concurrency: int = 1,
        description: str = "Test agent",
        input_schema: dict | None = None,
        output_schema: dict | None = None,
    ) -> AgentCapability:
        return AgentCapability(
            agent_name=agent_name,
            agent_type=agent_type,
            version=version,
            # `is None` rather than `or`: an explicit [] must survive.
            supported_tasks=["test_task"] if supported_tasks is None else supported_tasks,
            max_concurrency=max_concurrency,
            description=description,
            input_schema=input_schema,
            output_schema=output_schema,
        )

    return _make_capability
|
|
|
|
|
|
# ── Redis Fixtures (requires docker) ─────────────────────
|
|
|
|
|
|
@pytest.fixture
async def redis_client():
    """Provide a real Redis client for testing (requires docker-compose.test.yml).

    The connection URL is read from ``REDIS_URL`` and falls back to
    ``redis://localhost:6381/0``. Replies are decoded to ``str``
    (``decode_responses=True``). The client is closed with ``aclose()`` in a
    ``finally`` block once the consumer is done with it.
    """
    import redis.asyncio as aioredis

    redis_url = os.environ.get("REDIS_URL", "redis://localhost:6381/0")
    conn = aioredis.from_url(redis_url, decode_responses=True)
    try:
        yield conn
    finally:
        await conn.aclose()
|
|
|
|
|
|
@pytest.fixture
async def clean_redis(redis_client):
    """Flush the current Redis database around each test.

    Issues ``FLUSHDB`` once before the test body runs and once afterwards,
    so the test starts from an empty keyspace and its keys are removed on
    teardown.
    """
    flush = redis_client.flushdb
    await flush()
    yield
    await flush()
|
|
|
|
|
|
# ── PostgreSQL Fixtures (requires docker) ─────────────────
|
|
|
|
|
|
@pytest.fixture
async def pg_session_factory():
    """Provide an async SQLAlchemy session factory for testing (requires docker-compose.test.yml).

    The database URL is read from ``DATABASE_URL`` and falls back to the
    local asyncpg test instance. Sessions produced by the factory are
    ``AsyncSession`` objects created with ``expire_on_commit=False``. The
    underlying engine is disposed on teardown.
    """
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.orm import sessionmaker

    url = os.environ.get(
        "DATABASE_URL",
        "postgresql+asyncpg://agentkit_test:agentkit_test_pw@localhost:5434/agentkit_test",
    )
    engine = create_async_engine(url, echo=False)
    try:
        yield sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    finally:
        # try/finally (as in redis_client) so pooled connections are released
        # even if the generator is closed or thrown into, not only when it is
        # resumed normally.
        await engine.dispose()
|
|
|
|
|
|
def _quote_pg_ident(name: str) -> str:
    """Return *name* as a double-quoted PostgreSQL identifier (embedded quotes doubled)."""
    return '"' + name.replace('"', '""') + '"'


async def _truncate_public_tables(session_factory) -> None:
    """Truncate every table in the ``public`` schema with CASCADE in one statement."""
    from sqlalchemy import text

    async with session_factory() as session:
        result = await session.execute(
            text("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
        )
        tables = [row[0] for row in result]
        if tables:
            # Quote and schema-qualify every name. Unquoted reserved words
            # (e.g. a table called "user") or mixed-case names would otherwise
            # make the TRUNCATE fail or target the wrong relation.
            qualified = ", ".join(f"public.{_quote_pg_ident(t)}" for t in tables)
            await session.execute(text(f"TRUNCATE TABLE {qualified} CASCADE"))
        await session.commit()


@pytest.fixture
async def clean_db(pg_session_factory):
    """Clean database tables before and after each test.

    Truncates every table in the ``public`` schema (CASCADE) before the test
    body runs and again on teardown, mirroring ``clean_redis``. The test
    therefore starts from empty tables even if an earlier test did not use
    this fixture.
    """
    await _truncate_public_tables(pg_session_factory)
    yield
    await _truncate_public_tables(pg_session_factory)
|
|
|
|
|
|
# ── Pytest Markers ────────────────────────────────────────
|
|
|
|
|
|
def pytest_configure(config):
    """Register this suite's custom markers: ``integration``, ``redis``, ``postgres``.

    Each marker description is appended to the ``markers`` ini value via
    ``config.addinivalue_line``, in the order listed below.
    """
    marker_specs = (
        "integration: mark test as integration test (requires docker)",
        "redis: mark test as requiring Redis",
        "postgres: mark test as requiring PostgreSQL",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
|