674 lines
31 KiB
Python
674 lines
31 KiB
Python
"""FastAPI Application Factory"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
from contextlib import asynccontextmanager
|
||
|
||
from fastapi import FastAPI, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
|
||
from agentkit.core.agent_pool import AgentPool
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||
from agentkit.llm.providers.gemini import GeminiProvider
|
||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||
from agentkit.mcp.manager import MCPManager
|
||
from agentkit.quality.gate import QualityGate
|
||
from agentkit.quality.output import OutputStandardizer
|
||
from agentkit.router.intent import IntentRouter
|
||
from agentkit.skills.base import Skill, SkillConfig
|
||
from agentkit.skills.registry import SkillRegistry
|
||
from agentkit.tools.registry import ToolRegistry
|
||
from agentkit.server.config import ServerConfig
|
||
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows, chat
|
||
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||
from agentkit.server.task_store import create_task_store
|
||
from agentkit.server.runner import BackgroundRunner
|
||
from agentkit.core.logging import setup_structured_logging
|
||
from agentkit.telemetry.setup import setup_telemetry
|
||
|
||
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Allow-list applied when importing variables from a project-local ``.env``
# file into ``os.environ``: a key is accepted when it starts with one of these
# prefixes or matches one of the exact names below.
_ALLOWED_ENV_PREFIXES = (
    'AGENTKIT_', 'DASHSCOPE_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
    'TAVILY_', 'SERPER_', 'DEEPSEEK_',
)
# Immutable so the allow-list cannot be widened by accident at runtime.
_ALLOWED_ENV_EXACT = frozenset({'DATABASE_URL', 'REDIS_URL'})
|
||
|
||
|
||
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
    """Build an LLMGateway from ``config`` and register its providers.

    Providers without an API key are skipped. A provider whose construction or
    registration raises is logged as a warning and skipped, so one bad entry
    does not prevent the remaining providers from being registered.

    Args:
        config: Server configuration; ``config.llm_config.providers`` maps a
            provider name to that provider's settings.

    Returns:
        The gateway with every successfully built provider registered.
    """
    gateway = LLMGateway(config=config.llm_config)

    for name, pconf in config.llm_config.providers.items():
        if not pconf.api_key:
            continue  # Skip providers without API keys

        try:
            # Connection-pool settings are passed identically to every provider.
            pool_kwargs = {
                "max_connections": pconf.max_connections,
                "max_keepalive_connections": pconf.max_keepalive_connections,
                "keepalive_expiry": pconf.keepalive_expiry,
            }

            if pconf.type == "anthropic":
                provider = AnthropicProvider(
                    api_key=pconf.api_key,
                    # First configured model, else a built-in default.
                    model=next(iter(pconf.models)) if pconf.models else "claude-sonnet-4-20250514",
                    max_tokens=pconf.max_tokens,
                    base_url=pconf.base_url or "https://api.anthropic.com",
                    timeout=pconf.timeout,
                    **pool_kwargs,
                )
            elif pconf.type == "gemini":
                provider = GeminiProvider(
                    api_key=pconf.api_key,
                    model=next(iter(pconf.models)) if pconf.models else "gemini-2.0-flash",
                    max_output_tokens=pconf.max_tokens,
                    base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
                    timeout=pconf.timeout,
                    **pool_kwargs,
                )
            else:
                # Any other provider type is treated as OpenAI-compatible.
                provider = OpenAICompatibleProvider(
                    api_key=pconf.api_key,
                    base_url=pconf.base_url,
                    **pool_kwargs,
                )
            gateway.register_provider(name, provider)
        except Exception as e:
            # Best-effort: keep going with the remaining providers.
            logger.warning("Failed to register LLM provider '%s': %s", name, e)

    return gateway
|
||
|
||
|
||
def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
    """Create a SkillRegistry populated from the skill configs of ``config``.

    Each entry returned by ``config.load_skill_configs()`` is wrapped in a
    ``Skill`` and registered, in the order returned.
    """
    populated = SkillRegistry()
    for cfg in config.load_skill_configs():
        populated.register(Skill(config=cfg))
    return populated
|
||
|
||
|
||
async def _create_gui_default_agent(app: FastAPI):
    """Create the default GUI chat agent wired with memory and built-in tools.

    Builds a ``MemoryStore``, publishes it on ``app.state.memory_store`` for
    the chat routes, and creates a ``default`` chat agent whose system prompt
    is produced by ``MemoryStore.build_system_prompt`` from the loaded memory
    snapshot and the base prompt.

    Returns:
        The created agent, or ``None`` if the agent could not be created. A
        failure during creation or tool registration is logged as a warning
        rather than raised.
    """
    from agentkit.core.config_driven import AgentConfig
    from agentkit.memory.profile import MemoryStore
    from agentkit.tools.memory_tool import MemoryTool
    from agentkit.tools.shell import ShellTool
    from agentkit.tools.web_search import WebSearchTool
    from agentkit.tools.web_crawl import WebCrawlTool
    from agentkit.tools.baidu_search import BaiduSearchTool

    # Initialize memory store and build system prompt
    memory_store = MemoryStore()
    memory_store.ensure_defaults()
    memory_snapshot = memory_store.load_all()
    base_prompt = (
        "你是一个有帮助的AI助手。请记住我们对话的上下文,并在后续对话中引用之前的内容。回答要清晰简洁,请使用中文回复。\n\n"
        "重要提示:当你不确定事实信息、时事新闻或任何你不确信的话题时,"
        "你必须先使用搜索工具查找准确和最新的信息,然后再回答。"
        "中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。"
        "在能够搜索到真相的情况下,绝不猜测或编造答案。"
        "始终优先搜索而不是给出可能不正确的信息。"
    )
    effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)

    # Store memory_store on app.state for chat routes to use
    app.state.memory_store = memory_store

    default_config = AgentConfig(
        name="default",
        agent_type="chat",
        task_mode="llm_generate",
        description="Default chat agent for GUI",
        prompt={"system": effective_system_prompt},
    )

    # Stays None when creation fails, so callers can tell "no agent" apart
    # from "agent created but tool registration failed".
    agent = None
    try:
        agent = await app.state.agent_pool.create_agent(default_config)

        # Register tools into the agent's tool registry
        search_api_keys = {
            "tavily_api_key": os.environ.get("TAVILY_API_KEY"),
            "serper_api_key": os.environ.get("SERPER_API_KEY"),
        }
        agent._tool_registry.register(MemoryTool(memory_store=memory_store))
        agent._tool_registry.register(ShellTool())
        agent._tool_registry.register(BaiduSearchTool())
        agent._tool_registry.register(WebSearchTool(**search_api_keys))
        agent._tool_registry.register(WebCrawlTool())

        # Override system prompt with memory-injected version
        agent._system_prompt = effective_system_prompt

        logger.info("GUI mode: created default chat agent with memory + tools")
    except Exception as e:
        logger.warning("GUI mode: failed to create default agent: %s", e)
    return agent


def _load_gui_skills(app: FastAPI, agent) -> None:
    """Share the GUI agent's tools and load skills from the configured paths.

    Best-effort: any failure is logged as a warning and startup continues.

    Args:
        app: Application whose ``state`` holds the shared registries.
        agent: The default GUI agent, or ``None`` if it could not be created.
            With ``None`` no tools are shared, but skills are still loaded.
    """
    try:
        from pathlib import Path
        from agentkit.skills.loader import SkillLoader

        skill_registry = app.state.skill_registry
        tool_registry = app.state.tool_registry

        # Register GUI tools into the shared tool registry so skills can bind them
        if agent is not None:
            for tool in agent._tool_registry.list_tools():
                try:
                    tool_registry.register(tool)
                except Exception as reg_err:
                    # Expected when the tool is already registered.
                    logger.debug("GUI mode: tool not re-registered: %s", reg_err)

        # Load skills from configured paths
        server_config = getattr(app.state, "server_config", None)
        if server_config and server_config.skill_paths:
            loader = SkillLoader(
                skill_registry=skill_registry,
                tool_registry=tool_registry,
            )
            for skill_path in server_config.skill_paths:
                p = Path(skill_path)
                if p.is_dir():
                    loaded = loader.load_from_directory(str(p))
                    logger.info("GUI mode: loaded %s skills from %s", len(loaded), p)
                elif p.is_file() and p.suffix in (".yaml", ".yml"):
                    try:
                        loader.load_from_file(str(p))
                        logger.info("GUI mode: loaded skill from %s", p)
                    except Exception as se:
                        logger.warning("GUI mode: failed to load skill from %s: %s", p, se)

        logger.info("GUI mode: %s skills registered", len(skill_registry.list_skills()))
    except Exception as e:
        logger.warning("GUI mode: failed to load skills: %s", e)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup before ``yield`` and shutdown after it.

    Startup: starts task-store cleanup, enables config hot-reload when the
    server config has a config path, starts MCP servers, and, when the
    ``AGENTKIT_GUI_MODE`` environment variable is set, ensures a default chat
    agent (or at least a memory store) exists.

    Shutdown: stops MCP servers, closes the working-memory Redis client, stops
    the config watcher and stops task-store cleanup. Shutdown runs in a
    ``finally`` block so it also executes when an exception is raised into the
    context manager.
    """
    # Startup
    task_store = app.state.task_store
    await task_store.start_cleanup()

    # Start config watcher if server_config is available
    server_config = getattr(app.state, "server_config", None)
    if server_config is not None and server_config._config_path:
        server_config.on_change = lambda cfg: _on_config_change(app, cfg)
        server_config.watch_config()
        logger.info("Config hot-reload enabled")

    # Start MCP servers if configured
    mcp_manager = getattr(app.state, "mcp_manager", None)
    if mcp_manager is not None:
        await mcp_manager.start_all()

    # In GUI mode, ensure a default chat agent exists with memory + tools
    gui_mode = os.environ.get("AGENTKIT_GUI_MODE")
    if gui_mode and not app.state.agent_pool.list_agents():
        agent = await _create_gui_default_agent(app)
        # Skills are loaded even when the agent could not be created.
        _load_gui_skills(app, agent)
    elif gui_mode:
        # Agent already exists (e.g. from config), still ensure memory store is available
        if getattr(app.state, "memory_store", None) is None:
            from agentkit.memory.profile import MemoryStore

            memory_store = MemoryStore()
            memory_store.ensure_defaults()
            app.state.memory_store = memory_store

    try:
        yield
    finally:
        # Shutdown
        # Stop MCP servers
        if mcp_manager is not None:
            await mcp_manager.stop_all()

        # Close Redis client for working memory
        working_redis = getattr(app.state, "working_redis_client", None)
        if working_redis is not None:
            await working_redis.aclose()

        if server_config is not None:
            server_config.stop_watching()

        await task_store.stop_cleanup()
|
||
|
||
|
||
def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
    """Handle a config change by reloading the affected components.

    Each reload pass increments ``app.state.config_version``, rebuilds the LLM
    gateway and the skill registry from the newest config, re-points the agent
    pool and intent router at them, and stamps the new version on agents that
    carry a ``_config_version`` attribute. A failure in the gateway or skill
    step is logged as an error and does not stop the other step.

    Calls that arrive while a reload is running are coalesced: they set a
    pending flag and the running reload performs another pass, always using
    the most recently supplied config.

    Args:
        app: Application whose ``state`` holds the components to refresh.
        config: The newly loaded server configuration.
    """
    import threading

    lock: asyncio.Lock = app.state._config_reload_lock

    if not hasattr(app.state, "_config_reload_event"):
        app.state._config_reload_event = threading.Event()
    pending: threading.Event = app.state._config_reload_event

    # Publish the newest config BEFORE raising the flag, so a reload that is
    # already running picks it up on its next pass instead of reusing the
    # config captured by the call that started it.
    app.state._config_reload_pending = config
    pending.set()

    async def _reload() -> None:
        if lock.locked():
            return  # Another reload running; it will check pending flag
        async with lock:
            while pending.is_set():
                pending.clear()
                latest = getattr(app.state, "_config_reload_pending", config)

                # Increment config version for audit
                current_version = getattr(app.state, "config_version", 0) + 1
                app.state.config_version = current_version
                logger.info("Config change detected (v%s), reloading...", current_version)

                # Rebuild LLMGateway and re-point its consumers
                try:
                    new_gateway = _build_llm_gateway(latest)
                    app.state.llm_gateway = new_gateway
                    agent_pool = getattr(app.state, "agent_pool", None)
                    if agent_pool is not None:
                        agent_pool._llm_gateway = new_gateway
                    intent_router = getattr(app.state, "intent_router", None)
                    if intent_router is not None:
                        intent_router._llm_gateway = new_gateway
                    logger.info("LLM Gateway reloaded (config v%s)", current_version)
                except Exception as e:
                    logger.error("Failed to reload LLM Gateway: %s", e)

                # Rebuild the skill registry
                try:
                    new_skill_registry = _build_skill_registry(latest)
                    # Re-bind tools from the shared tool_registry so skills don't lose their bindings
                    tool_registry = getattr(app.state, "tool_registry", None)
                    if tool_registry:
                        from pathlib import Path
                        from agentkit.skills.loader import SkillLoader

                        loader = SkillLoader(
                            skill_registry=new_skill_registry,
                            tool_registry=tool_registry,
                        )
                        for skill_path in (latest.skill_paths or []):
                            p = Path(skill_path)
                            if p.is_dir():
                                loader.load_from_directory(str(p))
                            elif p.is_file() and p.suffix in (".yaml", ".yml"):
                                try:
                                    loader.load_from_file(str(p))
                                except Exception as skill_err:
                                    # One bad file must not abort the reload.
                                    logger.warning(
                                        "Failed to reload skill from %s: %s", p, skill_err
                                    )
                    app.state.skill_registry = new_skill_registry
                    agent_pool = getattr(app.state, "agent_pool", None)
                    if agent_pool is not None:
                        agent_pool._skill_registry = new_skill_registry
                    logger.info("Skills reloaded (config v%s)", current_version)
                except Exception as e:
                    logger.error("Failed to reload skills: %s", e)

                # Update config version on all agents
                agent_pool = getattr(app.state, "agent_pool", None)
                if agent_pool is not None:
                    for agent in agent_pool._agents.values():
                        if hasattr(agent, "_config_version"):
                            agent._config_version = current_version

                logger.info("Config reload complete (v%s)", current_version)

    # Schedule the reload as a task so the caller is not blocked.
    # NOTE(review): get_running_loop() only succeeds on the event-loop thread.
    # If the config watcher invokes this callback from another thread, the
    # reload is merely logged as deferred — confirm how watch_config()
    # dispatches its callback.
    try:
        loop = asyncio.get_running_loop()
        task = loop.create_task(_reload())
    except RuntimeError:
        logger.warning("No running event loop, config reload deferred")
        return

    # The event loop keeps only weak references to tasks; hold a strong one
    # until the task finishes so it cannot be garbage-collected mid-run.
    reload_tasks = getattr(app.state, "_config_reload_tasks", None)
    if reload_tasks is None:
        reload_tasks = set()
        app.state._config_reload_tasks = reload_tasks
    reload_tasks.add(task)
    task.add_done_callback(reload_tasks.discard)
|
||
|
||
|
||
def _load_allowed_dotenv(dotenv_path) -> None:
    """Copy allow-listed ``KEY=VALUE`` pairs from a ``.env`` file into os.environ.

    Blank lines, ``#`` comments and lines without ``=`` are ignored. Variables
    already present in the environment are never overwritten. A key that is
    neither prefixed by an entry of ``_ALLOWED_ENV_PREFIXES`` nor listed in
    ``_ALLOWED_ENV_EXACT`` is skipped with a warning.
    """
    with open(dotenv_path, encoding="utf-8") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            key = key.strip()
            # Drop surrounding whitespace, then any surrounding quote characters.
            value = value.strip().strip("\"'")
            if not key or key in os.environ:
                continue
            allowed = any(key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or key in _ALLOWED_ENV_EXACT
            if not allowed:
                logger.warning("Skipping .env variable '%s' (not in allowed prefixes)", key)
                continue
            os.environ[key] = value


def _config_section(server_config, name: str) -> dict:
    """Return the optional config section ``name``, or ``{}`` when absent/empty."""
    if not server_config:
        return {}
    return getattr(server_config, name, None) or {}


def _build_episodic_memory(epi_conf):
    """Build an EpisodicMemory from its config section.

    Returns:
        The EpisodicMemory instance, or ``None`` if anything fails (the failure
        is logged as a warning).
    """
    try:
        from agentkit.memory.episodic import EpisodicMemory
        from agentkit.memory.embedder import OpenAIEmbedder, EmbeddingCache
        from agentkit.memory.models import EpisodeModel, create_episodic_session_factory

        embedder = None
        if epi_conf.get("embedder_api_key") or os.environ.get("OPENAI_API_KEY"):
            cache = EmbeddingCache(
                max_size=epi_conf.get("cache_max_size", 1000),
                ttl=epi_conf.get("cache_ttl", 3600),
            )
            embedder = OpenAIEmbedder(
                api_key=epi_conf.get("embedder_api_key"),
                model=epi_conf.get("embedder_model", "text-embedding-3-small"),
                base_url=epi_conf.get("embedder_base_url"),
                cache=cache,
            )

        # Resolve session_factory and model from database_url if configured
        epi_session_factory = None
        epi_model = None
        database_url = epi_conf.get("database_url") or os.environ.get("DATABASE_URL")
        if database_url:
            try:
                epi_session_factory = create_episodic_session_factory(database_url)
                epi_model = EpisodeModel
            except Exception as db_err:
                # Fall through with no session factory.
                logger.warning("Failed to create episodic DB session: %s", db_err)

        return EpisodicMemory(
            session_factory=epi_session_factory,
            episodic_model=epi_model,
            embedder=embedder,
            decay_rate=epi_conf.get("decay_rate", 0.01),
            alpha=epi_conf.get("alpha", 0.7),
            retrieve_limit=epi_conf.get("retrieve_limit", 200),
            pgvector_enabled=epi_conf.get("pgvector_enabled", True),
            table_name=epi_conf.get("table_name", "episodic_memories"),
        )
    except Exception as e:
        logger.warning("Failed to initialize episodic memory: %s", e)
        return None


def _init_memory_components(app: FastAPI, memory_conf) -> None:
    """Build the enabled memory layers and publish them on ``app.state``.

    Sets ``app.state.memory_retriever`` and, when available,
    ``app.state.working_redis_client`` and ``app.state.retrieve_knowledge_tool``.
    On any failure a warning is logged and ``app.state.memory_retriever`` is
    set to ``None``.
    """
    try:
        from agentkit.memory.retriever import MemoryRetriever
        from agentkit.memory.working import WorkingMemory
        from agentkit.memory.semantic import SemanticMemory
        from agentkit.memory.http_rag import HttpRAGService

        working = None
        episodic = None
        semantic = None

        # ``or {}`` tolerates a section that is present but set to null.
        working_conf = memory_conf.get("working") or {}
        if working_conf.get("enabled"):
            import redis.asyncio as aioredis

            redis_url = working_conf.get("redis_url", "redis://localhost:6379")
            redis_client = aioredis.from_url(redis_url, decode_responses=True)
            working = WorkingMemory(redis=redis_client)
            # Kept on app.state so the lifespan shutdown can close it.
            app.state.working_redis_client = redis_client

        sem_conf = memory_conf.get("semantic") or {}
        if sem_conf.get("enabled"):
            rag_service = HttpRAGService(
                base_url=sem_conf["base_url"],
                api_key=sem_conf.get("api_key"),
                knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                timeout=sem_conf.get("timeout", 30),
            )
            semantic = SemanticMemory(
                rag_service=rag_service,
                knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                search_mode=sem_conf.get("search_mode", "standard"),
                use_rerank=sem_conf.get("use_rerank", True),
                use_compression=sem_conf.get("use_compression", False),
                kb_weights=sem_conf.get("kb_weights"),
            )

        epi_conf = memory_conf.get("episodic") or {}
        if epi_conf.get("enabled"):
            episodic = _build_episodic_memory(epi_conf)

        memory_retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=episodic,
            semantic_memory=semantic,
        )
        app.state.memory_retriever = memory_retriever

        # Expose the retrieve_knowledge tool when the retriever provides one
        if memory_retriever:
            retrieve_tool = memory_retriever.create_retrieve_tool()
            if retrieve_tool:
                app.state.retrieve_knowledge_tool = retrieve_tool
    except Exception as e:
        logger.warning("Failed to initialize memory components: %s", e)
        app.state.memory_retriever = None


def create_app(
    llm_gateway: LLMGateway | None = None,
    skill_registry: SkillRegistry | None = None,
    tool_registry: ToolRegistry | None = None,
    api_key: str | None = None,
    rate_limit: int | None = None,
    server_config: ServerConfig | None = None,
) -> FastAPI:
    """Create and configure the FastAPI application.

    When called by uvicorn (factory=True), automatically loads ServerConfig
    from the AGENTKIT_CONFIG_PATH env var if ``server_config`` is not provided.
    A ``.env`` file next to that config is read first (allow-listed keys only)
    so that environment substitutions in the config can resolve.

    Args:
        llm_gateway: Gateway to use; built from ``server_config`` (or created
            empty) when ``None``.
        skill_registry: Skill registry to use; built from ``server_config``
            (or created empty) when ``None``.
        tool_registry: Tool registry to use; created empty when ``None``.
        api_key: API key for the auth middleware; falls back to
            ``server_config.api_key`` when ``None``.
        rate_limit: Requests-per-minute limit; falls back to
            ``server_config.rate_limit`` when ``None``.
        server_config: Parsed server configuration, if already loaded.

    Returns:
        The configured application, with shared components on ``app.state``
        and all API routers mounted under ``/api/v1``.
    """
    # Auto-load config from env var if not provided (uvicorn factory mode)
    if server_config is None:
        config_path = os.environ.get("AGENTKIT_CONFIG_PATH")
        if config_path and os.path.exists(config_path):
            # Load .env before parsing config (so ${ENV_VAR} substitutions work)
            from pathlib import Path as _P

            _dotenv = _P(config_path).parent / ".env"
            if _dotenv.exists():
                _load_allowed_dotenv(_dotenv)
            server_config = ServerConfig.from_yaml(config_path)

    app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)

    # Initialize structured logging
    setup_structured_logging()

    # Initialize OpenTelemetry (no-op if not installed or not configured)
    if server_config:
        setup_telemetry(app, server_config.telemetry)

    # Resolve effective API key and rate limit (explicit arguments win)
    effective_api_key = api_key
    effective_rate_limit = rate_limit
    if server_config:
        if effective_api_key is None:
            effective_api_key = server_config.api_key
        if effective_rate_limit is None:
            effective_rate_limit = server_config.rate_limit

    # CORS configuration
    cors_origins = ["*"]
    if server_config:
        cors_origins = server_config.cors_origins
    if cors_origins == ["*"]:
        logger.warning(
            "CORS allows all origins (allow_origins=['*']). "
            "Set server.cors_origins in agentkit.yaml for production."
        )
    # NOTE(review): middleware added later wraps middleware added earlier, so
    # auth and rate limiting see a request before CORS does — confirm that
    # browser preflight (OPTIONS) requests are handled as intended.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Auth middleware
    app.add_middleware(APIKeyAuthMiddleware, api_key=effective_api_key)

    # Rate limiting middleware; the limit is handed over via the environment.
    if effective_rate_limit is not None:
        os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(effective_rate_limit)
    app.add_middleware(RateLimitMiddleware)

    # Build LLM Gateway from config if not provided
    if llm_gateway is None and server_config:
        llm_gateway = _build_llm_gateway(server_config)

    # Build Skill Registry from config if not provided
    if skill_registry is None and server_config:
        skill_registry = _build_skill_registry(server_config)

    # Initialize shared state. ``is not None`` (rather than ``or``) keeps a
    # caller-supplied object even if it evaluates as falsy, e.g. when empty.
    app.state.llm_gateway = llm_gateway if llm_gateway is not None else LLMGateway()
    app.state.skill_registry = skill_registry if skill_registry is not None else SkillRegistry()
    app.state.tool_registry = tool_registry if tool_registry is not None else ToolRegistry()

    # Initialize MCPManager if MCP servers are configured
    if server_config and server_config.mcp_servers:
        app.state.mcp_manager = MCPManager(
            configs=server_config.mcp_servers,
            tool_registry=app.state.tool_registry,
        )
    else:
        app.state.mcp_manager = None

    # Initialize compressor if compression is configured
    from agentkit.core.compressor import create_compressor

    compressor = create_compressor(server_config.compression) if server_config else None
    app.state.compressor = compressor

    # Register headroom_retrieve tool if HeadroomCompressor is active
    if compressor is not None:
        try:
            from agentkit.core.headroom_compressor import HeadroomCompressor

            if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
                from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool

                retrieve_tool = HeadroomRetrieveTool(compressor=compressor)
                app.state.tool_registry.register(retrieve_tool)
                logger.info("HeadroomRetrieveTool registered (CCR retrieval enabled)")
        except ImportError as import_err:
            # The headroom modules could not be imported; continue without the tool.
            logger.debug("Headroom retrieval tool not registered: %s", import_err)

    # Initialize MessageBus for inter-agent communication
    from agentkit.bus.redis_bus import create_message_bus

    bus_config = _config_section(server_config, "bus")
    message_bus = create_message_bus(
        backend=bus_config.get("backend", "memory"),
        redis_url=bus_config.get("redis_url", "redis://localhost:6379/0"),
    )
    app.state.message_bus = message_bus

    app.state.agent_pool = AgentPool(
        llm_gateway=app.state.llm_gateway,
        skill_registry=app.state.skill_registry,
        tool_registry=app.state.tool_registry,
        compressor=compressor,
        message_bus=message_bus,
    )
    app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
    app.state.quality_gate = QualityGate()
    app.state.output_standardizer = OutputStandardizer()

    # Initialize OrganizationContext from AgentPool + SkillRegistry
    from agentkit.org.context import OrganizationContext

    org_context = OrganizationContext.from_agent_pool(
        agent_pool=app.state.agent_pool,
        skill_registry=app.state.skill_registry,
    )
    app.state.org_context = org_context

    # Initialize AlignmentGuard from config
    from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig

    alignment_config = AlignmentConfig.from_dict(_config_section(server_config, "alignment"))
    app.state.alignment_guard = AlignmentGuard(
        config=alignment_config, llm_gateway=app.state.llm_gateway
    )

    # Initialize CostAwareRouter
    from agentkit.chat.skill_routing import CostAwareRouter

    auction_enabled = _config_section(server_config, "marketplace").get("auction_enabled", False)
    app.state.cost_aware_router = CostAwareRouter(
        llm_gateway=app.state.llm_gateway,
        org_context=org_context,
        auction_enabled=auction_enabled,
        classifier=_config_section(server_config, "router").get("classifier", "heuristic"),
    )

    # Initialize task store from config
    ts_config = _config_section(server_config, "task_store")
    # Merge CLI overrides from AGENTKIT_TASK_STORE env var (a JSON object)
    ts_env = os.environ.get("AGENTKIT_TASK_STORE")
    if ts_env:
        import json as _json

        try:
            ts_config = {**ts_config, **_json.loads(ts_env)}
        except Exception as ts_err:
            # Keep the configured values, but make the bad override visible.
            logger.warning("Ignoring invalid AGENTKIT_TASK_STORE override: %s", ts_err)
    task_store = create_task_store(
        backend=ts_config.get("backend", "memory"),
        redis_url=ts_config.get("redis_url", "redis://localhost:6379/0"),
        ttl_seconds=ts_config.get("ttl_seconds", 3600),
        max_records=ts_config.get("max_records", 10000),
    )
    app.state.task_store = task_store
    app.state.runner = BackgroundRunner(task_store=app.state.task_store)
    app.state.server_config = server_config
    app.state.api_key = effective_api_key
    app.state._config_reload_lock = asyncio.Lock()

    # Initialize session manager for Chat mode
    from agentkit.session.manager import SessionManager
    from agentkit.session.store import create_session_store

    session_config = _config_section(server_config, "session")
    # GUI mode defaults to file-backed sessions for persistence
    session_backend = session_config.get(
        "backend", "file" if os.environ.get("AGENTKIT_GUI_MODE") else "memory"
    )
    session_store = create_session_store(
        backend=session_backend,
        redis_url=session_config.get("redis_url", "redis://localhost:6379/0"),
        ttl_seconds=session_config.get("ttl_seconds", 86400),
    )
    app.state.session_manager = SessionManager(store=session_store)

    # Initialize evolution store if configured (None when unset or on failure)
    app.state.evolution_store = None
    evo_conf = _config_section(server_config, "evolution")
    if evo_conf:
        try:
            from agentkit.evolution.evolution_store import create_evolution_store

            app.state.evolution_store = create_evolution_store(
                backend=evo_conf.get("backend", "memory"),
                db_path=evo_conf.get("db_path", "~/.agentkit/evolution.db"),
            )
        except Exception as e:
            logger.warning("Failed to initialize evolution store: %s", e)
            app.state.evolution_store = None

    # Initialize memory components if configured. The attribute always exists,
    # matching evolution_store: None means "memory retrieval unavailable".
    app.state.memory_retriever = None
    memory_conf = _config_section(server_config, "memory")
    if memory_conf:
        _init_memory_components(app, memory_conf)

    # Include routes
    route_modules = (
        agents, tasks, skills, llm, health, metrics, ws, evolution, memory,
        portal, evolution_dashboard, kb_management, skill_management,
        workflows, chat,
    )
    for route_module in route_modules:
        app.include_router(route_module.router, prefix="/api/v1")

    # Serve GUI when in GUI mode
    gui_mode = os.environ.get("AGENTKIT_GUI_MODE")
    if gui_mode:
        from pathlib import Path as _Path
        from fastapi.responses import HTMLResponse, FileResponse

        _static_dir = _Path(__file__).parent / "static"

        @app.get("/", response_class=HTMLResponse, include_in_schema=False)
        async def gui_index():
            """Serve the GUI index page."""
            index_path = _static_dir / "index.html"
            if index_path.exists():
                return FileResponse(str(index_path))
            return HTMLResponse("<h1>AgentKit GUI not found</h1>", status_code=404)

    return app
|