# fischer-agentkit/src/agentkit/tools/bitable_tool.py

"""BitableTool — Agent tool for bitable data ingestion and CRUD via HTTP.
Implements KTD5 (REST API boundary even when co-deployed) and KTD11
(internal service token auth). The tool uses ``httpx.AsyncClient`` to call
the bitable REST API; it never imports BitableService directly.
Actions: create_table, import_excel, import_database, collect_api,
upsert_records, query_records.
Batch chunking: upsert and import operations send at most ``BATCH_SIZE``
records per HTTP request. On partial failure, the result includes
``successful_count`` and ``resume_from`` for breakpoint continuation.
"""
from __future__ import annotations
import asyncio
import logging
import httpx
from agentkit.bitable.ingestion.excel import ParsedSheet, parse_excel, parse_excel_url
from agentkit.bitable.ingestion.database import import_table as import_db_table
from agentkit.bitable.ingestion.api_collector import transform_records
from agentkit.tools.base import Tool
# Module-level logger for this tool.
logger = logging.getLogger(__name__)
# Maximum number of records sent in a single HTTP request by the batch
# create/upsert helpers; larger inputs are split into chunks of this size.
BATCH_SIZE = 500
class BitableTool(Tool):
    """Agent tool for bitable operations via REST API.

    Every operation is performed through HTTP calls against the bitable REST
    API (KTD5); the tool never calls the bitable service in-process.

    Args:
        base_url: Bitable API base URL (e.g. ``http://localhost:8001/api/v1/bitable``).
            A trailing slash is stripped.
        internal_token: Service token for KTD11 auth, sent as the
            ``X-Internal-Token`` header. If ``None``, requests go
            unauthenticated (will fail if the server requires auth).
    """

    def __init__(self, base_url: str, internal_token: str | None = None) -> None:
        super().__init__(
            name="bitable",
            description=(
                "Create and manage bitable (multi-dimensional spreadsheet) tables, "
                "ingest data from Excel files, databases, or API responses, and "
                "query records. Actions: create_table, import_excel, "
                "import_database, collect_api, upsert_records, query_records."
            ),
            input_schema={
                "type": "object",
                "properties": {
                    "action": {
                        "type": "string",
                        "enum": [
                            "create_table",
                            "import_excel",
                            "import_database",
                            "collect_api",
                            "upsert_records",
                            "query_records",
                        ],
                        "description": "Bitable operation to perform.",
                    },
                    "table_name": {
                        "type": "string",
                        "description": "Name for the new bitable table (create_table, import_excel, import_database).",
                    },
                    "description": {
                        "type": "string",
                        "description": "Table description (create_table).",
                    },
                    "file_path": {
                        "type": "string",
                        "description": "Path to .xlsx file (import_excel).",
                    },
                    "file_url": {
                        "type": "string",
                        "description": "URL to download .xlsx file (import_excel).",
                    },
                    "connection_string": {
                        "type": "string",
                        "description": "Database connection string (import_database).",
                    },
                    "table_names": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Source table names to import (import_database).",
                    },
                    "table_id": {
                        "type": "string",
                        "description": "Target bitable table ID (collect_api, upsert_records, query_records).",
                    },
                    "records": {
                        "type": "array",
                        "description": "Records to write (collect_api, upsert_records).",
                    },
                    "field_mapping": {
                        "type": "object",
                        "description": "Mapping {source_key: bitable_field_id} (collect_api).",
                    },
                    "primary_key_field_id": {
                        "type": "string",
                        "description": "Field ID of the primary key (upsert_records, collect_api).",
                    },
                    "resume_from": {
                        "type": "integer",
                        "description": "Skip this many records before resuming a failed batch (upsert_records, collect_api).",
                    },
                    "cursor": {
                        "type": "string",
                        "description": "Pagination cursor (query_records).",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Max records to return (query_records).",
                    },
                },
                "required": ["action"],
            },
        )
        self._base_url = base_url.rstrip("/")
        self._internal_token = internal_token
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            headers: dict[str, str] = {}
            if self._internal_token:
                headers["X-Internal-Token"] = self._internal_token
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                headers=headers,
                timeout=60.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client if one is open."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def execute(self, **kwargs) -> dict[str, object]:
        """Dispatch ``kwargs["action"]`` to the matching handler.

        Failures are returned as ``{"success": False, "error": ...}`` rather
        than raised, so the calling agent can read and react to them.
        """
        action = kwargs.get("action")
        handlers = {
            "create_table": self._create_table,
            "import_excel": self._import_excel,
            "import_database": self._import_database,
            "collect_api": self._collect_api,
            "upsert_records": self._upsert_records,
            "query_records": self._query_records,
        }
        # Guard the lookup: an unhashable action (e.g. a list) would raise here.
        handler = handlers.get(action) if isinstance(action, str) else None
        if handler is None:
            return {"success": False, "error": f"Unknown action: {action!r}"}
        try:
            return await handler(**kwargs)
        except httpx.HTTPStatusError as e:
            return {
                "success": False,
                "error": f"Bitable API error {e.response.status_code}: {e.response.text[:500]}",
            }
        except httpx.ConnectError as e:
            return {"success": False, "error": f"Cannot connect to bitable API: {e}"}
        except Exception as e:
            # Tool boundary: report to the agent, but keep the traceback in logs.
            logger.exception("Bitable action %r failed", action)
            return {"success": False, "error": f"{action} failed: {e}"}

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    async def _create_bare_table(self, name: object) -> str:
        """Create an empty bitable table named *name* and return its ID."""
        client = await self._get_client()
        resp = await client.post("/tables", json={"name": name})
        resp.raise_for_status()
        return resp.json()["table"]["id"]

    async def _create_field(self, table_id: str, name: object, field_type: object) -> str:
        """Create an agent-owned field on *table_id* and return the new field ID."""
        client = await self._get_client()
        resp = await client.post(
            f"/tables/{table_id}/fields",
            json={"name": name, "field_type": field_type, "owner": "agent"},
        )
        resp.raise_for_status()
        return resp.json()["field"]["id"]

    @staticmethod
    def _map_records(
        records: list[dict[str, object]], field_name_to_id: dict[str, str]
    ) -> list[dict[str, object]]:
        """Re-key records from source column names to bitable field IDs.

        Keys with no matching field are dropped.
        """
        return [
            {field_name_to_id[k]: v for k, v in rec.items() if k in field_name_to_id}
            for rec in records
        ]

    @staticmethod
    def _describe_batch_error(
        batch_start: int, batch_size: int, exc: httpx.HTTPError
    ) -> dict[str, object]:
        """Build the ``errors`` entry describing one failed batch."""
        info: dict[str, object] = {"batch_start": batch_start, "batch_size": batch_size}
        if isinstance(exc, httpx.HTTPStatusError):
            info["status"] = exc.response.status_code
            info["error"] = exc.response.text[:300]
        else:
            # Transport-level failure (connect error, timeout, ...): no response.
            info["error"] = f"{type(exc).__name__}: {exc}"
        return info

    async def _post_in_batches(
        self,
        path: str,
        records: list[dict[str, object]],
        extra_payload: dict[str, object] | None = None,
        *,
        count_from_response: bool = False,
        offset: int = 0,
    ) -> dict[str, object]:
        """POST *records* to *path* in chunks of ``BATCH_SIZE``, stopping at the first failure.

        Args:
            path: API path relative to the base URL.
            records: Records to send, in order.
            extra_payload: Extra JSON keys merged into every request body.
            count_from_response: If true, count written records from the
                response's ``inserted`` + ``updated`` fields (upsert endpoint);
                otherwise count every record of each accepted batch.
            offset: Index of ``records[0]`` in the caller's original list, so
                ``resume_from`` and ``batch_start`` are absolute indices.

        Returns:
            ``successful_count`` (records written by this call), ``total``
            (``len(records)``), ``resume_from`` (absolute index of the first
            record not confirmed written; pass it back as ``resume_from`` to
            continue) and, on failure, ``errors``.
        """
        client = await self._get_client()
        extra = extra_payload or {}
        total = len(records)
        successful = 0
        # Everything counts as sent unless a batch fails below.
        resume_from = offset + total
        errors: list[dict[str, object]] = []
        for start in range(0, total, BATCH_SIZE):
            batch = records[start : start + BATCH_SIZE]
            try:
                resp = await client.post(path, json={"records": batch, **extra})
                resp.raise_for_status()
            except httpx.HTTPError as exc:
                # Batches go out in order, so all records before this batch were
                # accepted. On a timeout the server may still have applied this
                # batch: resuming re-sends it (harmless for upsert, may duplicate
                # rows for plain create).
                resume_from = offset + start
                errors.append(self._describe_batch_error(resume_from, len(batch), exc))
                logger.warning(
                    "Bitable batch POST %s failed at record %d: %s", path, resume_from, exc
                )
                break  # stop on first failure
            if count_from_response:
                data = resp.json()
                successful += (data.get("inserted") or 0) + (data.get("updated") or 0)
            else:
                successful += len(batch)
        return {
            "successful_count": successful,
            "total": total,
            "resume_from": resume_from,
            **({"errors": errors} if errors else {}),
        }

    # ------------------------------------------------------------------
    # create_table
    # ------------------------------------------------------------------

    async def _create_table(self, **kwargs) -> dict[str, object]:
        """Create an empty table from ``table_name`` and optional ``description``."""
        table_name = kwargs.get("table_name")
        if not table_name:
            return {"success": False, "error": "Missing required field: table_name"}
        client = await self._get_client()
        resp = await client.post(
            "/tables",
            # An explicit None description is sent as "" rather than null.
            json={"name": table_name, "description": kwargs.get("description") or ""},
        )
        resp.raise_for_status()
        data = resp.json()
        return {"success": True, "table": data["table"]}

    # ------------------------------------------------------------------
    # import_excel
    # ------------------------------------------------------------------

    async def _import_excel(self, **kwargs) -> dict[str, object]:
        """Import every non-empty sheet of an .xlsx file as a new bitable table.

        ``file_path`` takes precedence over ``file_url`` when both are given.
        Each sheet result carries its own ``success`` flag and batch details.
        """
        file_path = kwargs.get("file_path")
        file_url = kwargs.get("file_url")
        if not file_path and not file_url:
            return {"success": False, "error": "Either file_path or file_url is required"}
        # Parse Excel — offload sync I/O to thread pool (P2 #21-23).
        if file_path:
            sheets = await asyncio.to_thread(parse_excel, file_path)
        else:
            sheets = await asyncio.to_thread(parse_excel_url, file_url)
        if not sheets:
            return {"success": False, "error": "Excel file has no sheets with data"}
        results: list[dict[str, object]] = []
        for sheet in sheets:
            results.append(await self._import_sheet(sheet))
        return {"success": True, "sheets": results}

    async def _import_sheet(self, sheet: ParsedSheet) -> dict[str, object]:
        """Create a bitable table from a parsed sheet and insert all rows."""
        table_id = await self._create_bare_table(sheet.name)
        field_name_to_id: dict[str, str] = {}
        for col_name, field_type in zip(sheet.columns, sheet.field_types):
            field_name_to_id[col_name] = await self._create_field(
                table_id, col_name, field_type
            )
        summary: dict[str, object] = {
            "table_id": table_id,
            "table_name": sheet.name,
            "field_count": len(field_name_to_id),
        }
        mapped_records = self._map_records(sheet.records, field_name_to_id)
        if not mapped_records:
            return {**summary, "record_count": 0, "success": True}
        # A sheet carries no primary-key information, so upsert is impossible;
        # rows are appended with plain create-records calls instead.
        create_result = await self._batch_create_records(table_id, mapped_records)
        return {
            **summary,
            "record_count": create_result["successful_count"],
            "success": "errors" not in create_result,
            **create_result,
        }

    async def _batch_create_records(
        self, table_id: str, records: list[dict[str, object]]
    ) -> dict[str, object]:
        """Create records in batches via POST /tables/{id}/records."""
        return await self._post_in_batches(f"/tables/{table_id}/records", records)

    # ------------------------------------------------------------------
    # import_database
    # ------------------------------------------------------------------

    async def _import_database(self, **kwargs) -> dict[str, object]:
        """Import each table in ``table_names`` from ``connection_string``.

        A failure on one table is recorded in its result entry and the loop
        continues. A connection failure (source DB or bitable API) aborts and
        returns the tables imported so far under ``imported``.
        """
        conn_str = kwargs.get("connection_string")
        table_names = kwargs.get("table_names")
        if not conn_str:
            return {"success": False, "error": "Missing required field: connection_string"}
        if not table_names:
            return {"success": False, "error": "Missing required field: table_names"}
        if isinstance(table_names, str):
            # Accept a single name; iterating a str would yield characters.
            table_names = [table_names]
        results: list[dict[str, object]] = []
        for src_table in table_names:
            try:
                # Offload sync DB reflection to thread pool (P2 #21-23).
                reflected = await asyncio.to_thread(import_db_table, conn_str, src_table)
                result = await self._import_reflected_table(reflected)
            except (ConnectionError, httpx.ConnectError) as e:
                # ConnectionError presumably comes from the DB reflection helper;
                # httpx.ConnectError means the bitable API is unreachable.
                # Either way every later table would fail too.
                return {"success": False, "error": str(e), "imported": results}
            except Exception as e:
                results.append({"table_name": src_table, "success": False, "error": str(e)})
            else:
                results.append(result)
        return {"success": True, "tables": results}

    async def _import_reflected_table(self, reflected: dict[str, object]) -> dict[str, object]:
        """Create a bitable table from reflected DB data and write its rows.

        Rows are upserted on the reflected primary key when there is one,
        otherwise created. If several fields are flagged ``is_primary_key``,
        the last one wins.
        """
        table_name = reflected["table_name"]
        table_id = await self._create_bare_table(table_name)
        field_name_to_id: dict[str, str] = {}
        pk_field_id: str | None = None
        for fdef in reflected["fields"]:
            fid = await self._create_field(table_id, fdef["name"], fdef["field_type"])
            field_name_to_id[fdef["name"]] = fid
            if fdef.get("is_primary_key"):
                pk_field_id = fid
        if pk_field_id:
            client = await self._get_client()
            resp = await client.patch(
                f"/tables/{table_id}", json={"primary_key_field_id": pk_field_id}
            )
            # Fail loudly: without a PK set, the upsert below would be rejected.
            resp.raise_for_status()
        mapped = self._map_records(reflected["records"], field_name_to_id)
        if not mapped:
            return {
                "table_id": table_id,
                "table_name": table_name,
                "record_count": 0,
                "success": True,
            }
        if pk_field_id:
            written = await self._batch_upsert(table_id, mapped, pk_field_id)
        else:
            written = await self._batch_create_records(table_id, mapped)
        return {
            "table_id": table_id,
            "table_name": table_name,
            "record_count": written["successful_count"],
            "success": "errors" not in written,
            **written,
        }

    # ------------------------------------------------------------------
    # collect_api
    # ------------------------------------------------------------------

    async def _collect_api(self, **kwargs) -> dict[str, object]:
        """Transform raw API records via ``field_mapping`` and upsert them.

        ``resume_from`` skips that many transformed records. The returned
        ``resume_from`` is an absolute index that can be passed back as-is.
        """
        table_id = kwargs.get("table_id")
        records = kwargs.get("records")
        field_mapping = kwargs.get("field_mapping")
        pk_field_id = kwargs.get("primary_key_field_id")
        # `or 0` also covers an explicit None from the agent.
        resume_from = kwargs.get("resume_from") or 0
        if not table_id:
            return {"success": False, "error": "Missing required field: table_id"}
        if not records:
            return {"success": False, "error": "Missing required field: records"}
        if not field_mapping:
            return {"success": False, "error": "Missing required field: field_mapping"}
        if not pk_field_id:
            return {"success": False, "error": "Missing required field: primary_key_field_id"}
        transformed = transform_records(records, field_mapping)
        skip = resume_from if resume_from > 0 else 0  # negative values are ignored
        result = await self._batch_upsert(table_id, transformed[skip:], pk_field_id, offset=skip)
        return {"success": "errors" not in result, **result}

    # ------------------------------------------------------------------
    # upsert_records
    # ------------------------------------------------------------------

    async def _upsert_records(self, **kwargs) -> dict[str, object]:
        """Upsert ``records`` (already keyed by field ID) into ``table_id``.

        ``resume_from`` skips that many records. The returned ``resume_from``
        is an absolute index that can be passed back as-is.
        """
        table_id = kwargs.get("table_id")
        records = kwargs.get("records")
        pk_field_id = kwargs.get("primary_key_field_id")
        # `or 0` also covers an explicit None from the agent.
        resume_from = kwargs.get("resume_from") or 0
        if not table_id:
            return {"success": False, "error": "Missing required field: table_id"}
        if not records:
            return {"success": False, "error": "Missing required field: records"}
        if not pk_field_id:
            return {"success": False, "error": "Missing required field: primary_key_field_id"}
        skip = resume_from if resume_from > 0 else 0  # negative values are ignored
        result = await self._batch_upsert(table_id, records[skip:], pk_field_id, offset=skip)
        return {"success": "errors" not in result, **result}

    async def _batch_upsert(
        self,
        table_id: str,
        records: list[dict[str, object]],
        pk_field_id: str,
        offset: int = 0,
    ) -> dict[str, object]:
        """Upsert records in batches of BATCH_SIZE via POST /tables/{id}/upsert.

        *offset* is the index of ``records[0]`` in the caller's full list. It
        makes the returned ``resume_from`` absolute.
        """
        return await self._post_in_batches(
            f"/tables/{table_id}/upsert",
            records,
            {"primary_key_field_id": pk_field_id},
            count_from_response=True,
            offset=offset,
        )

    # ------------------------------------------------------------------
    # query_records
    # ------------------------------------------------------------------

    async def _query_records(self, **kwargs) -> dict[str, object]:
        """Fetch one page of records, forwarding ``cursor``/``limit`` when truthy."""
        table_id = kwargs.get("table_id")
        if not table_id:
            return {"success": False, "error": "Missing required field: table_id"}
        client = await self._get_client()
        params: dict[str, object] = {}
        if kwargs.get("cursor"):
            params["cursor"] = kwargs["cursor"]
        if kwargs.get("limit"):
            params["limit"] = kwargs["limit"]
        resp = await client.get(f"/tables/{table_id}/records", params=params)
        resp.raise_for_status()
        data = resp.json()
        return {
            "success": True,
            "records": data["records"],
            "next_cursor": data.get("next_cursor"),
        }