192 lines
7.3 KiB
Python
192 lines
7.3 KiB
Python
"""FastAPI Application Factory"""
|
|
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from agentkit.core.agent_pool import AgentPool
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
|
from agentkit.quality.gate import QualityGate
|
|
from agentkit.quality.output import OutputStandardizer
|
|
from agentkit.router.intent import IntentRouter
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
from agentkit.server.config import ServerConfig
|
|
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics
|
|
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
|
from agentkit.server.task_store import create_task_store
|
|
from agentkit.server.runner import BackgroundRunner
|
|
from agentkit.core.logging import setup_structured_logging
|
|
|
|
|
|
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
    """Create an LLMGateway and attach every configured provider to it.

    Providers whose ``api_key`` is falsy are ignored. A provider that fails
    to construct or register is reported as a warning and skipped, so one
    bad entry does not prevent the gateway from being built.
    """
    llm_config = config.llm_config
    gateway = LLMGateway(config=llm_config)

    for name, settings in llm_config.providers.items():
        if settings.api_key:
            try:
                gateway.register_provider(
                    name,
                    OpenAICompatibleProvider(
                        api_key=settings.api_key,
                        base_url=settings.base_url,
                    ),
                )
            except Exception as e:
                # Best effort: keep going with the remaining providers.
                import logging

                log = logging.getLogger(__name__)
                log.warning(f"Failed to register LLM provider '{name}': {e}")

    return gateway
|
|
|
|
|
|
def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
    """Return a SkillRegistry holding one Skill per loaded skill config.

    The skill configs come from ``config.load_skill_configs()`` and are
    registered in the order that call yields them.
    """
    skills = SkillRegistry()
    for entry in config.load_skill_configs():
        skills.register(Skill(config=entry))
    return skills
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the task store's cleanup for the lifetime of the application.

    On startup this reads ``app.state.task_store`` and awaits its
    ``start_cleanup()``; on shutdown it awaits ``stop_cleanup()`` on the
    same object.
    """
    # Startup
    task_store = app.state.task_store
    await task_store.start_cleanup()
    try:
        yield
    finally:
        # Shutdown: runs even when an exception or cancellation is thrown
        # into the generator at the yield, so cleanup is always stopped.
        await task_store.stop_cleanup()
|
|
|
|
|
|
def _resolve_task_store_settings(server_config: ServerConfig | None) -> dict:
    """Return task-store settings: config values overlaid with AGENTKIT_TASK_STORE."""
    import json
    import logging

    configured = server_config.task_store if server_config else None
    # Copy so the override never mutates the config object's own mapping;
    # `or {}` also tolerates a null task_store entry.
    settings = dict(configured or {})

    raw_override = os.environ.get("AGENTKIT_TASK_STORE")
    if not raw_override:
        return settings

    log = logging.getLogger(__name__)
    try:
        override = json.loads(raw_override)
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        # The raw value is not logged: it may embed a redis URL with credentials.
        log.warning("Ignoring AGENTKIT_TASK_STORE: not valid JSON (%s)", exc)
        return settings

    if not isinstance(override, dict):
        log.warning(
            "Ignoring AGENTKIT_TASK_STORE: expected a JSON object, got %s",
            type(override).__name__,
        )
        return settings

    settings.update(override)
    return settings


def _build_memory_retriever(server_config: ServerConfig | None):
    """Return a MemoryRetriever when memory is configured and usable, else None."""
    import logging

    memory_config = getattr(server_config, "memory", None)
    if not memory_config:
        return None

    try:
        # Imported lazily so the memory packages are only loaded when
        # memory is actually configured.
        from agentkit.memory.retriever import MemoryRetriever
        from agentkit.memory.working import WorkingMemory

        working = None
        working_config = memory_config.get("working") or {}
        if working_config.get("enabled"):
            import redis.asyncio as aioredis

            redis_url = working_config.get("redis_url", "redis://localhost:6379")
            redis_client = aioredis.from_url(redis_url, decode_responses=True)
            working = WorkingMemory(redis=redis_client)

        return MemoryRetriever(working_memory=working)
    except Exception as exc:
        # Deliberate best effort: the app still starts when memory setup fails.
        logging.getLogger(__name__).warning(
            "Failed to initialize memory components: %s", exc
        )
        return None


def create_app(
    llm_gateway: LLMGateway | None = None,
    skill_registry: SkillRegistry | None = None,
    tool_registry: ToolRegistry | None = None,
    api_key: str | None = None,
    rate_limit: int | None = None,
    server_config: ServerConfig | None = None,
) -> FastAPI:
    """Create and configure the FastAPI application.

    When called by uvicorn (factory=True), automatically loads ServerConfig
    from AGENTKIT_CONFIG_PATH env var if server_config is not provided.

    Args:
        llm_gateway: Gateway to use. When None, one is built from
            ``server_config`` if available, otherwise a default LLMGateway.
        skill_registry: Registry to use. When None, one is built from
            ``server_config`` if available, otherwise an empty SkillRegistry.
        tool_registry: Registry to use. When None, an empty ToolRegistry.
        api_key: Key handed to APIKeyAuthMiddleware. When None, falls back
            to ``server_config.api_key``.
        rate_limit: When None, falls back to ``server_config.rate_limit``.
            A resolved non-None value is exported through the
            AGENTKIT_RATE_LIMIT_PER_MINUTE environment variable.
        server_config: Server configuration; see above for the env fallback.

    Returns:
        The configured FastAPI app. Shared components are stored on
        ``app.state``; ``app.state.memory_retriever`` is always set and is
        None when memory is not configured or failed to initialize.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Auto-load config from env var if not provided (uvicorn factory mode)
    missing_config_path = None
    if server_config is None:
        config_path = os.environ.get("AGENTKIT_CONFIG_PATH")
        if config_path:
            if os.path.exists(config_path):
                server_config = ServerConfig.from_yaml(config_path)
            else:
                missing_config_path = config_path

    app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)

    # Initialize structured logging
    setup_structured_logging()

    if missing_config_path is not None:
        # Reported after logging setup. Without this, a mistyped path
        # silently starts the server with no configuration at all.
        logger.warning(
            "AGENTKIT_CONFIG_PATH is set to '%s' but no such file exists; "
            "starting without a server config.",
            missing_config_path,
        )

    # Resolve effective API key and rate limit (explicit arguments win)
    effective_api_key = api_key
    effective_rate_limit = rate_limit
    if server_config:
        if effective_api_key is None:
            effective_api_key = server_config.api_key
        if effective_rate_limit is None:
            effective_rate_limit = server_config.rate_limit

    # CORS configuration
    cors_origins = server_config.cors_origins if server_config else ["*"]
    if cors_origins == ["*"]:
        logger.warning(
            "CORS allows all origins (allow_origins=['*']). "
            "Set server.cors_origins in agentkit.yaml for production."
        )
    # NOTE(review): Starlette runs the most recently added middleware first,
    # so CORS (added first) is innermost. Responses short-circuited by the
    # auth or rate-limit middleware would then lack CORS headers - confirm
    # this ordering is intended.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Auth middleware
    app.add_middleware(APIKeyAuthMiddleware, api_key=effective_api_key)

    # Rate limiting middleware
    if effective_rate_limit is not None:
        # The limit is exported via the environment; presumably
        # RateLimitMiddleware reads it from there - verify.
        os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(effective_rate_limit)
    # NOTE(review): the source's indentation was lost; the middleware is
    # assumed to be installed unconditionally, like the auth middleware -
    # confirm against the original file.
    app.add_middleware(RateLimitMiddleware)

    # Resolve shared components. `is None` checks (not truthiness) ensure a
    # caller-supplied object is never replaced just because it is falsy.
    if llm_gateway is None:
        llm_gateway = (
            _build_llm_gateway(server_config) if server_config else LLMGateway()
        )
    if skill_registry is None:
        skill_registry = (
            _build_skill_registry(server_config)
            if server_config
            else SkillRegistry()
        )
    if tool_registry is None:
        tool_registry = ToolRegistry()

    # Initialize shared state
    app.state.llm_gateway = llm_gateway
    app.state.skill_registry = skill_registry
    app.state.tool_registry = tool_registry
    app.state.agent_pool = AgentPool(
        llm_gateway=app.state.llm_gateway,
        skill_registry=app.state.skill_registry,
        tool_registry=app.state.tool_registry,
    )
    app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
    app.state.quality_gate = QualityGate()
    app.state.output_standardizer = OutputStandardizer()

    # Initialize task store from config, merged with CLI overrides from the
    # AGENTKIT_TASK_STORE env var
    ts_config = _resolve_task_store_settings(server_config)
    task_store = create_task_store(
        backend=ts_config.get("backend", "memory"),
        redis_url=ts_config.get("redis_url", "redis://localhost:6379/0"),
        ttl_seconds=ts_config.get("ttl_seconds", 3600),
        max_records=ts_config.get("max_records", 10000),
    )
    app.state.task_store = task_store
    app.state.runner = BackgroundRunner(task_store=app.state.task_store)
    app.state.server_config = server_config

    # Initialize memory components if configured. Always assigned so
    # consumers can read the attribute without guarding for its absence.
    app.state.memory_retriever = _build_memory_retriever(server_config)

    # Include routes
    for route_module in (agents, tasks, skills, llm, health, metrics):
        app.include_router(route_module.router, prefix="/api/v1")

    return app
|