feat(memory): add HttpRAGService for config-driven knowledge base integration

This commit is contained in:
chiguyong 2026-06-06 18:36:05 +08:00
parent 0456429beb
commit cd5b39087e
6 changed files with 711 additions and 1 deletions

View File

@ -313,9 +313,12 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
try: try:
from agentkit.memory.retriever import MemoryRetriever from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.working import WorkingMemory from agentkit.memory.working import WorkingMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
working = None working = None
episodic = None episodic = None
semantic = None
if config.memory.get("working", {}).get("enabled"): if config.memory.get("working", {}).get("enabled"):
import redis.asyncio as aioredis import redis.asyncio as aioredis
@ -328,9 +331,23 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
# Will be initialized externally when DB is available # Will be initialized externally when DB is available
pass pass
if config.memory.get("semantic", {}).get("enabled"):
sem_conf = config.memory["semantic"]
rag_service = HttpRAGService(
base_url=sem_conf["base_url"],
api_key=sem_conf.get("api_key"),
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
timeout=sem_conf.get("timeout", 30),
)
semantic = SemanticMemory(
rag_service=rag_service,
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
)
self._memory_retriever = MemoryRetriever( self._memory_retriever = MemoryRetriever(
working_memory=working, working_memory=working,
episodic_memory=episodic, episodic_memory=episodic,
semantic_memory=semantic,
) )
# Inject into BaseAgent # Inject into BaseAgent

View File

@ -4,6 +4,7 @@ from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.working import WorkingMemory from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.retriever import MemoryRetriever from agentkit.memory.retriever import MemoryRetriever
__all__ = [ __all__ = [
@ -12,5 +13,6 @@ __all__ = [
"WorkingMemory", "WorkingMemory",
"EpisodicMemory", "EpisodicMemory",
"SemanticMemory", "SemanticMemory",
"HttpRAGService",
"MemoryRetriever", "MemoryRetriever",
] ]

View File

@ -0,0 +1,193 @@
"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
"""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class HttpRAGService:
"""HTTP 客户端,调用业务系统的知识库检索 API
适配任意提供以下接口的知识库服务
- POST {base_url}/search 语义检索
- POST {base_url}/ingest 文档写入可选
典型配置agentkit.yaml::
memory:
semantic:
enabled: true
base_url: "http://localhost:8000/api/knowledge"
api_key: "${GEO_API_KEY}"
knowledge_base_ids:
- "industry-kb-id"
- "enterprise-kb-id"
timeout: 30
"""
def __init__(
self,
base_url: str,
api_key: str | None = None,
knowledge_base_ids: list[str] | None = None,
timeout: int = 30,
):
"""
Args:
base_url: 知识库 API 基础地址 http://localhost:8000/api/knowledge
api_key: 认证 API Key放在 Authorization: Bearer
knowledge_base_ids: 默认检索的知识库 ID 列表
timeout: HTTP 请求超时秒数
"""
self._base_url = base_url.rstrip("/")
self._api_key = api_key
self._knowledge_base_ids = knowledge_base_ids or []
self._timeout = timeout
self._client: httpx.AsyncClient | None = None
def _get_client(self) -> httpx.AsyncClient:
"""懒初始化 httpx 客户端"""
if self._client is None or self._client.is_closed:
headers: dict[str, str] = {"Content-Type": "application/json"}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers=headers,
timeout=self._timeout,
)
return self._client
async def search(
self,
query: str,
knowledge_base_ids: list[str] | None = None,
top_k: int = 5,
) -> list[dict[str, Any]]:
"""语义检索知识库
Args:
query: 检索查询
knowledge_base_ids: 知识库 ID 列表默认使用配置值
top_k: 返回结果数量
Returns:
检索结果列表每项包含 content/score/document_id 等字段
"""
kb_ids = knowledge_base_ids or self._knowledge_base_ids
payload = {
"query": query,
"knowledge_base_ids": kb_ids,
"top_k": top_k,
}
client = self._get_client()
try:
resp = await client.post("/search", json=payload)
resp.raise_for_status()
data = resp.json()
# 兼容两种响应格式:
# 1. {"results": [...]} — GEO 标准 SearchResponse
# 2. [...] — 直接返回列表
if isinstance(data, dict) and "results" in data:
results = data["results"]
elif isinstance(data, list):
results = data
else:
logger.warning(f"Unexpected search response format: {type(data)}")
return []
# 标准化为 SemanticMemory 期望的格式
normalized = []
for r in results:
if isinstance(r, dict):
normalized.append({
"id": r.get("chunk_id", r.get("id", "")),
"content": r.get("content", ""),
"score": float(r.get("score", 0.0)),
"source": r.get("source", "rag"),
"document_id": r.get("document_id", ""),
"document_title": r.get("document_title", ""),
"metadata": r.get("metadata", {}),
})
return normalized
except httpx.HTTPStatusError as e:
logger.error(f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}")
return []
except httpx.RequestError as e:
logger.error(f"RAG search request error: {e}")
return []
except Exception as e:
logger.error(f"RAG search unexpected error: {e}")
return []
async def ingest(
self,
key: str,
value: Any,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
"""写入文档到知识库(可选操作)
Args:
key: 文档标题或标识
value: 文档内容
metadata: 额外元数据
Returns:
写入结果 None 表示写入不可用
"""
kb_ids = self._knowledge_base_ids
if not kb_ids:
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
return None
payload = {
"title": key,
"content": str(value),
"source_type": "text",
"metadata": metadata or {},
}
client = self._get_client()
try:
# 写入到第一个配置的知识库
kb_id = kb_ids[0]
resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
logger.error(f"RAG ingest HTTP error: {e.response.status_code}")
return None
except Exception as e:
logger.error(f"RAG ingest error: {e}")
return None
async def health_check(self) -> bool:
"""检查知识库服务是否可用"""
client = self._get_client()
try:
resp = await client.get("/bases")
return resp.status_code in (200, 401) # 401 = 服务在但需认证
except Exception:
return False
async def close(self) -> None:
"""关闭 HTTP 客户端"""
if self._client and not self._client.is_closed:
await self._client.aclose()
self._client = None
async def __aenter__(self) -> "HttpRAGService":
return self
async def __aexit__(self, *args: Any) -> None:
await self.close()

View File

@ -165,15 +165,37 @@ def create_app(
try: try:
from agentkit.memory.retriever import MemoryRetriever from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.working import WorkingMemory from agentkit.memory.working import WorkingMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
working = None working = None
episodic = None
semantic = None
if server_config.memory.get("working", {}).get("enabled"): if server_config.memory.get("working", {}).get("enabled"):
import redis.asyncio as aioredis import redis.asyncio as aioredis
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379") redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
redis_client = aioredis.from_url(redis_url, decode_responses=True) redis_client = aioredis.from_url(redis_url, decode_responses=True)
working = WorkingMemory(redis=redis_client) working = WorkingMemory(redis=redis_client)
memory_retriever = MemoryRetriever(working_memory=working) if server_config.memory.get("semantic", {}).get("enabled"):
sem_conf = server_config.memory["semantic"]
rag_service = HttpRAGService(
base_url=sem_conf["base_url"],
api_key=sem_conf.get("api_key"),
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
timeout=sem_conf.get("timeout", 30),
)
semantic = SemanticMemory(
rag_service=rag_service,
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
)
memory_retriever = MemoryRetriever(
working_memory=working,
episodic_memory=episodic,
semantic_memory=semantic,
)
app.state.memory_retriever = memory_retriever app.state.memory_retriever = memory_retriever
except Exception as e: except Exception as e:
import logging import logging

View File

@ -62,6 +62,7 @@ class ServerConfig:
log_format: str = "text", log_format: str = "text",
task_store: dict[str, Any] | None = None, task_store: dict[str, Any] | None = None,
cors_origins: list[str] | None = None, cors_origins: list[str] | None = None,
memory: dict[str, Any] | None = None,
): ):
self.host = host self.host = host
self.port = port self.port = port
@ -75,6 +76,7 @@ class ServerConfig:
self.log_format = log_format self.log_format = log_format
self.task_store = task_store or {} self.task_store = task_store or {}
self.cors_origins = cors_origins or ["*"] self.cors_origins = cors_origins or ["*"]
self.memory = memory or {}
@classmethod @classmethod
def from_yaml(cls, path: str) -> "ServerConfig": def from_yaml(cls, path: str) -> "ServerConfig":
@ -95,6 +97,7 @@ class ServerConfig:
skills_data = data.get("skills", {}) skills_data = data.get("skills", {})
logging_data = data.get("logging", {}) logging_data = data.get("logging", {})
task_store_data = data.get("task_store", {}) task_store_data = data.get("task_store", {})
memory_data = data.get("memory", {})
# Build LLMConfig # Build LLMConfig
llm_config = cls._build_llm_config(llm_data) llm_config = cls._build_llm_config(llm_data)
@ -116,6 +119,7 @@ class ServerConfig:
log_format=logging_data.get("format", "text"), log_format=logging_data.get("format", "text"),
task_store=task_store_data, task_store=task_store_data,
cors_origins=server.get("cors_origins"), cors_origins=server.get("cors_origins"),
memory=memory_data,
) )
@staticmethod @staticmethod

View File

@ -0,0 +1,472 @@
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever
# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceInit:
"""HttpRAGService 初始化"""
def test_basic_init(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
assert svc._base_url == "http://localhost:8000/api/knowledge"
assert svc._api_key is None
assert svc._knowledge_base_ids == []
assert svc._timeout == 30
def test_init_with_all_params(self):
svc = HttpRAGService(
base_url="http://geo:8000/api/knowledge/",
api_key="sk-test",
knowledge_base_ids=["kb-1", "kb-2"],
timeout=60,
)
assert svc._base_url == "http://geo:8000/api/knowledge" # trailing slash stripped
assert svc._api_key == "sk-test"
assert svc._knowledge_base_ids == ["kb-1", "kb-2"]
assert svc._timeout == 60
def test_trailing_slash_stripped(self):
svc = HttpRAGService(base_url="http://host/api/")
assert svc._base_url == "http://host/api"
class TestHttpRAGServiceSearch:
"""HttpRAGService.search — 语义检索"""
@pytest.fixture
def svc(self):
return HttpRAGService(
base_url="http://localhost:8000/api/knowledge",
api_key="test-key",
knowledge_base_ids=["kb-industry", "kb-enterprise"],
)
@pytest.mark.asyncio
async def test_search_standard_response(self, svc):
"""标准 SearchResponse 格式: {"results": [...]}"""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{
"chunk_id": "c1",
"content": "AI 行业趋势分析",
"score": 0.92,
"document_id": "d1",
"document_title": "行业报告",
"metadata": {"page": 1},
},
{
"chunk_id": "c2",
"content": "企业数字化转型",
"score": 0.85,
"document_id": "d2",
"document_title": "企业案例",
"metadata": {},
},
],
"total": 2,
"latency_ms": 50,
}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
results = await svc.search("AI 趋势", top_k=5)
assert len(results) == 2
assert results[0]["id"] == "c1"
assert results[0]["content"] == "AI 行业趋势分析"
assert results[0]["score"] == 0.92
assert results[0]["document_id"] == "d1"
assert results[1]["content"] == "企业数字化转型"
# Verify payload
call_args = mock_client.post.call_args
assert call_args[0][0] == "/search"
payload = call_args[1]["json"]
assert payload["query"] == "AI 趋势"
assert payload["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
assert payload["top_k"] == 5
@pytest.mark.asyncio
async def test_search_list_response(self, svc):
"""直接返回列表格式: [...]"""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = [
{"chunk_id": "c1", "content": "test", "score": 0.8},
]
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
results = await svc.search("test")
assert len(results) == 1
assert results[0]["content"] == "test"
@pytest.mark.asyncio
async def test_search_custom_kb_ids(self, svc):
"""传入自定义 knowledge_base_ids 覆盖默认值"""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {"results": []}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
await svc.search("test", knowledge_base_ids=["custom-kb"])
payload = mock_client.post.call_args[1]["json"]
assert payload["knowledge_base_ids"] == ["custom-kb"]
@pytest.mark.asyncio
async def test_search_http_error_returns_empty(self, svc):
"""HTTP 错误返回空列表"""
import httpx
mock_resp = MagicMock()
mock_resp.status_code = 500
mock_resp.text = "Internal Server Error"
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
"500", request=MagicMock(), response=mock_resp
)
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
results = await svc.search("test")
assert results == []
@pytest.mark.asyncio
async def test_search_connection_error_returns_empty(self, svc):
"""连接错误返回空列表"""
import httpx
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
svc._get_client = MagicMock(return_value=mock_client)
results = await svc.search("test")
assert results == []
@pytest.mark.asyncio
async def test_search_unexpected_format_returns_empty(self, svc):
"""非预期响应格式返回空列表"""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {"error": "something"}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
results = await svc.search("test")
assert results == []
class TestHttpRAGServiceIngest:
"""HttpRAGService.ingest — 文档写入"""
@pytest.mark.asyncio
async def test_ingest_success(self):
svc = HttpRAGService(
base_url="http://localhost:8000/api/knowledge",
knowledge_base_ids=["kb-1"],
)
mock_resp = MagicMock()
mock_resp.status_code = 201
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {"id": "doc-1", "status": "processing"}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
result = await svc.ingest("测试文档", "文档内容")
assert result["id"] == "doc-1"
# Verify endpoint and payload
call_args = mock_client.post.call_args
assert call_args[0][0] == "/bases/kb-1/documents"
payload = call_args[1]["json"]
assert payload["title"] == "测试文档"
assert payload["content"] == "文档内容"
@pytest.mark.asyncio
async def test_ingest_no_kb_ids_returns_none(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
result = await svc.ingest("test", "content")
assert result is None
@pytest.mark.asyncio
async def test_ingest_http_error_returns_none(self):
import httpx
svc = HttpRAGService(
base_url="http://localhost:8000/api/knowledge",
knowledge_base_ids=["kb-1"],
)
mock_resp = MagicMock()
mock_resp.status_code = 500
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
"500", request=MagicMock(), response=mock_resp
)
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
result = await svc.ingest("test", "content")
assert result is None
class TestHttpRAGServiceHealthCheck:
"""HttpRAGService.health_check"""
@pytest.mark.asyncio
async def test_health_check_ok(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
assert await svc.health_check() is True
@pytest.mark.asyncio
async def test_health_check_401_still_healthy(self):
"""401 表示服务在运行,只是需要认证"""
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
mock_resp = MagicMock()
mock_resp.status_code = 401
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
svc._get_client = MagicMock(return_value=mock_client)
assert await svc.health_check() is True
@pytest.mark.asyncio
async def test_health_check_connection_error(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=Exception("Connection refused"))
svc._get_client = MagicMock(return_value=mock_client)
assert await svc.health_check() is False
class TestHttpRAGServiceClient:
"""HttpRAGService HTTP 客户端管理"""
def test_client_lazy_init(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
assert svc._client is None
client = svc._get_client()
assert client is not None
assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
def test_client_reuse(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
c1 = svc._get_client()
c2 = svc._get_client()
assert c1 is c2
@pytest.mark.asyncio
async def test_close(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
svc._get_client() # init client
await svc.close()
assert svc._client is None
@pytest.mark.asyncio
async def test_context_manager(self):
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
async with svc as s:
s._get_client()
assert s._client is not None
assert svc._client is None
# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------
class TestSemanticMemoryWithHttpRAG:
"""SemanticMemory 通过 HttpRAGService 检索知识库"""
@pytest.mark.asyncio
async def test_search_delegates_to_rag_service(self):
"""SemanticMemory.search 委托给 HttpRAGService.search"""
rag = HttpRAGService(
base_url="http://localhost:8000/api/knowledge",
knowledge_base_ids=["kb-1"],
)
# Mock the search
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
]
}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
rag._get_client = MagicMock(return_value=mock_client)
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
items = await semantic.search("行业趋势", top_k=3)
assert len(items) == 1
assert items[0].key == "c1"
assert items[0].value == "行业知识"
assert items[0].score == 0.9
assert items[0].metadata["source"] == "rag"
@pytest.mark.asyncio
async def test_search_no_rag_service_returns_empty(self):
"""无 RAG 服务时返回空列表"""
semantic = SemanticMemory()
items = await semantic.search("test")
assert items == []
class TestMemoryRetrieverWithSemantic:
"""MemoryRetriever 集成 SemanticMemory + HttpRAGService"""
@pytest.mark.asyncio
async def test_retriever_queries_semantic_layer(self):
"""MemoryRetriever 查询 Semantic 层并融合结果"""
rag = HttpRAGService(
base_url="http://localhost:8000/api/knowledge",
knowledge_base_ids=["kb-1"],
)
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
]
}
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_resp)
rag._get_client = MagicMock(return_value=mock_client)
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
retriever = MemoryRetriever(semantic_memory=semantic)
items = await retriever.retrieve("知识查询", top_k=3)
assert len(items) >= 1
# Semantic weight is 0.4 by default
assert items[0].score == pytest.approx(0.95 * 0.4, abs=0.01)
# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------
class TestServerConfigMemorySemantic:
"""ServerConfig 解析 memory.semantic 配置"""
def test_from_dict_with_semantic(self):
from agentkit.server.config import ServerConfig
data = {
"memory": {
"semantic": {
"enabled": True,
"base_url": "http://geo:8000/api/knowledge",
"api_key": "sk-test",
"knowledge_base_ids": ["kb-1", "kb-2"],
"timeout": 60,
},
},
}
config = ServerConfig.from_dict(data)
assert config.memory["semantic"]["enabled"] is True
assert config.memory["semantic"]["base_url"] == "http://geo:8000/api/knowledge"
assert config.memory["semantic"]["api_key"] == "sk-test"
assert config.memory["semantic"]["knowledge_base_ids"] == ["kb-1", "kb-2"]
assert config.memory["semantic"]["timeout"] == 60
def test_from_dict_without_memory(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.memory == {}
def test_from_yaml_with_env_var_resolution(self):
"""验证 from_yaml 路径的 ${VAR:-default} 环境变量解析"""
from agentkit.server.config import ServerConfig
import os
import tempfile
os.environ["TEST_GEO_API_KEY"] = "sk-from-env"
yaml_content = """
memory:
semantic:
enabled: true
base_url: http://geo:8000/api/knowledge
api_key: "${TEST_GEO_API_KEY}"
"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
f.flush()
config = ServerConfig.from_yaml(f.name)
assert config.memory["semantic"]["api_key"] == "sk-from-env"
del os.environ["TEST_GEO_API_KEY"]
def test_from_yaml_with_default_env_var(self):
"""验证 from_yaml 路径的 ${VAR:-default} 带默认值"""
from agentkit.server.config import ServerConfig
import tempfile
yaml_content = """
memory:
semantic:
enabled: true
base_url: http://geo:8000/api/knowledge
api_key: "${NONEXISTENT_KEY:-sk-default}"
"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
f.flush()
config = ServerConfig.from_yaml(f.name)
assert config.memory["semantic"]["api_key"] == "sk-default"