fix(calendar): code review fixes - 23 issues (2 critical, 15 major, 6 minor)

This commit is contained in:
chiguyong 2026-06-24 11:29:23 +08:00
parent 4ea7801bcf
commit 3fdee65979
12 changed files with 304 additions and 36 deletions

View File

@ -105,7 +105,10 @@ CREATE TABLE IF NOT EXISTS calendar_reminder_deliveries (
channel TEXT NOT NULL DEFAULT 'client',
attempts INTEGER NOT NULL DEFAULT 0,
last_error TEXT,
FOREIGN KEY (reminder_rule_id) REFERENCES calendar_reminder_rules(id),
-- ponytail: ON DELETE CASCADE ensures deliveries are removed when their
-- reminder_rule is cascade-deleted (e.g. event deletion). Existing DBs
-- created before this change need ALTER TABLE or DB recreation.
FOREIGN KEY (reminder_rule_id) REFERENCES calendar_reminder_rules(id) ON DELETE CASCADE,
FOREIGN KEY (event_id) REFERENCES calendar_events(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_deliveries_status
@ -252,6 +255,8 @@ async def insert_event(event: CalendarEvent, db_path: str | Path | None = None)
"""Insert a calendar event."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_events (id, user_id, title, description, "
@ -286,6 +291,9 @@ async def get_event(event_id: str, db_path: str | Path | None = None) -> Calenda
"""Return a single event by id, or None."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM calendar_events WHERE id = ?", (event_id,))
row = await cursor.fetchone()
@ -304,6 +312,9 @@ async def get_event_by_external_id(
"""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_events "
@ -346,6 +357,9 @@ async def list_events(
query += " WHERE " + " AND ".join(conditions) + " ORDER BY e.start_time"
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(query, tuple(params))
rows = await cursor.fetchall()
@ -361,6 +375,9 @@ async def list_all_events_in_time_range(
"""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_events WHERE start_time >= ? AND start_time < ? "
@ -411,6 +428,9 @@ async def update_event(
sql = f"UPDATE calendar_events SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params))
await db.commit()
return cursor.rowcount > 0
@ -420,6 +440,8 @@ async def delete_event(event_id: str, db_path: str | Path | None = None) -> bool
"""Delete an event and its dependent rows. Returns True if deleted."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
# Manual cascade for event_tags (no ON DELETE on junction FK in some SQLite versions)
await db.execute("DELETE FROM calendar_event_tags WHERE event_id = ?", (event_id,))
@ -436,6 +458,9 @@ async def delete_event(event_id: str, db_path: str | Path | None = None) -> bool
async def insert_event_type(et: EventType, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_event_types (id, user_id, name, color, is_default) "
"VALUES (?, ?, ?, ?, ?)",
@ -447,6 +472,9 @@ async def insert_event_type(et: EventType, db_path: str | Path | None = None) ->
async def list_event_types(user_id: str, db_path: str | Path | None = None) -> list[EventType]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_event_types WHERE user_id = ? ORDER BY name",
@ -477,6 +505,9 @@ async def update_event_type(
params.append(type_id)
sql = f"UPDATE calendar_event_types SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params))
await db.commit()
return cursor.rowcount > 0
@ -485,6 +516,9 @@ async def update_event_type(
async def delete_event_type(type_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute("DELETE FROM calendar_event_types WHERE id = ?", (type_id,))
await db.commit()
return cursor.rowcount > 0
@ -498,6 +532,9 @@ async def delete_event_type(type_id: str, db_path: str | Path | None = None) ->
async def insert_tag(tag: Tag, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_tags (id, user_id, name) VALUES (?, ?, ?)",
(tag.id, tag.user_id, tag.name),
@ -508,6 +545,9 @@ async def insert_tag(tag: Tag, db_path: str | Path | None = None) -> None:
async def list_tags(user_id: str, db_path: str | Path | None = None) -> list[Tag]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_tags WHERE user_id = ? ORDER BY name",
@ -520,6 +560,9 @@ async def list_tags(user_id: str, db_path: str | Path | None = None) -> list[Tag
async def delete_tag(tag_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute("DELETE FROM calendar_event_tags WHERE tag_id = ?", (tag_id,))
cursor = await db.execute("DELETE FROM calendar_tags WHERE id = ?", (tag_id,))
await db.commit()
@ -534,6 +577,9 @@ async def delete_tag(tag_id: str, db_path: str | Path | None = None) -> bool:
async def add_tag_to_event(event_id: str, tag_id: str, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT OR IGNORE INTO calendar_event_tags (event_id, tag_id) VALUES (?, ?)",
(event_id, tag_id),
@ -546,6 +592,9 @@ async def remove_tag_from_event(
) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"DELETE FROM calendar_event_tags WHERE event_id = ? AND tag_id = ?",
(event_id, tag_id),
@ -556,6 +605,9 @@ async def remove_tag_from_event(
async def get_event_tags(event_id: str, db_path: str | Path | None = None) -> list[Tag]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT t.* FROM calendar_tags t "
@ -575,6 +627,9 @@ async def get_event_tags(event_id: str, db_path: str | Path | None = None) -> li
async def insert_reminder_rule(rule: ReminderRule, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_reminder_rules "
"(id, event_id, event_type_id, offset_minutes, channels) "
@ -595,6 +650,9 @@ async def list_reminder_rules_for_event(
) -> list[ReminderRule]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_reminder_rules WHERE event_id = ?",
@ -609,6 +667,9 @@ async def list_reminder_rules_for_type(
) -> list[ReminderRule]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_reminder_rules WHERE event_type_id = ?",
@ -621,6 +682,9 @@ async def list_reminder_rules_for_type(
async def delete_reminder_rule(rule_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute("DELETE FROM calendar_reminder_rules WHERE id = ?", (rule_id,))
await db.commit()
return cursor.rowcount > 0
@ -636,6 +700,9 @@ async def insert_reminder_delivery(
) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_reminder_deliveries "
"(id, reminder_rule_id, event_id, scheduled_time, status, channel, "
@ -655,16 +722,27 @@ async def insert_reminder_delivery(
async def get_pending_deliveries(
event_id: str, reminder_rule_id: str, db_path: str | Path | None = None
event_id: str,
reminder_rule_id: str,
db_path: str | Path | None = None,
status: str = "sent",
) -> list[ReminderDelivery]:
"""Check idempotency — return existing deliveries for an event+rule."""
"""Check idempotency — return existing deliveries for an event+rule.
By default only ``sent`` deliveries are returned, so the scheduler's
idempotency check skips rules that already succeeded. Stuck ``pending``
or ``failed`` deliveries are ignored, allowing retry on the next scan.
"""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_reminder_deliveries "
"WHERE event_id = ? AND reminder_rule_id = ?",
(event_id, reminder_rule_id),
"WHERE event_id = ? AND reminder_rule_id = ? AND status = ?",
(event_id, reminder_rule_id, status),
)
rows = await cursor.fetchall()
return [_row_to_reminder_delivery(row) for row in rows]
@ -678,6 +756,9 @@ async def update_delivery_status(
) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(
"UPDATE calendar_reminder_deliveries "
"SET status = ?, attempts = attempts + 1, last_error = ? "
@ -698,6 +779,9 @@ async def insert_external_config(
) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_external_configs "
"(id, user_id, provider, credentials, sync_frequency, sync_scope, "
@ -721,6 +805,9 @@ async def list_external_configs(
) -> list[ExternalCalendarConfig]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_external_configs WHERE user_id = ?",
@ -757,6 +844,9 @@ async def update_external_config(
params.append(config_id)
sql = f"UPDATE calendar_external_configs SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params))
await db.commit()
return cursor.rowcount > 0
@ -765,6 +855,9 @@ async def update_external_config(
async def delete_external_config(config_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(
"DELETE FROM calendar_external_configs WHERE id = ?", (config_id,)
)
@ -780,6 +873,9 @@ async def delete_external_config(config_id: str, db_path: str | Path | None = No
async def insert_invitation(invitation: Invitation, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute(
"INSERT INTO calendar_invitations "
"(id, event_id, inviter_user_id, invitee_email, status, responded_at) "
@ -801,6 +897,9 @@ async def get_invitation(
) -> Invitation | None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_invitations WHERE id = ?", (invitation_id,)
@ -814,6 +913,9 @@ async def list_invitations(
) -> list[Invitation]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_invitations WHERE invitee_email = ? ORDER BY responded_at DESC",
@ -831,6 +933,9 @@ async def update_invitation_status(
) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(
"UPDATE calendar_invitations SET status = ?, responded_at = ? WHERE id = ?",
(status, responded_at, invitation_id),

View File

@ -6,6 +6,7 @@ All times are UTC (see KTD-11).
from __future__ import annotations
import itertools
from datetime import datetime, timezone
from dateutil.rrule import rrulestr
@ -46,7 +47,9 @@ def expand_rrule(
# rrulestr expects the RRULE to have a DTSTART context.
# We prepend DTSTART to ensure the rule starts from the event's start time.
full_rule = f"DTSTART:{start_dt.strftime('%Y%m%dT%H%M%SZ')}\nRRULE:{rrule_str}"
full_rule = (
f"DTSTART:{start_dt.astimezone(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}\nRRULE:{rrule_str}"
)
rule = rrulestr(full_rule)
if range_start is not None and range_end is not None:
@ -61,7 +64,9 @@ def expand_rrule(
re_ = _parse_dt(range_end)
occurrences = [dt for dt in rule if dt < re_]
else:
occurrences = list(rule)
# ponytail: 1000-occurrence ceiling for unbounded rules (FREQ=DAILY
# without COUNT/UNTIL). Upgrade path: accept a max_occurrences param.
occurrences = list(itertools.islice(rule, 1000))
# Convert back to ISO 8601 UTC strings
return [dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00") for dt in occurrences]

View File

@ -7,9 +7,12 @@ patching module imports.
from __future__ import annotations
import ipaddress
import logging
import socket
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from urllib.parse import urlparse
from agentkit.calendar.models import CalendarEvent
@ -21,13 +24,46 @@ class SmtpConfig:
"""SMTP server configuration for the email reminder channel."""
host: str = "localhost"
port: int = 25
# ponytail: STARTTLS on 587 is the modern default; plaintext port 25 is
# only appropriate for local MTA relays. Upgrade: implicit TLS on 465.
port: int = 587
username: str | None = None
password: str | None = None
use_tls: bool = False
use_tls: bool = True
from_email: str = "noreply@agentkit.local"
def _is_safe_webhook_url(url: str) -> bool:
"""Validate webhook URL to prevent SSRF attacks.
ponytail: Basic SSRF guard blocks private/internal IP ranges (RFC 1918),
loopback, link-local (169.254.x includes cloud metadata endpoints), and
non-http(s) schemes. Uses blocking socket.getaddrinfo; upgrade to
asyncio.getaddrinfo if webhook dispatch is on a hot path.
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname == "localhost" or hostname.startswith("169.254."):
return False
try:
infos = socket.getaddrinfo(hostname, None)
for _, _, _, _, sockaddr in infos:
addr = sockaddr[0]
try:
ip = ipaddress.ip_address(addr)
except ValueError:
continue
if ip.is_private or ip.is_loopback or ip.is_link_local:
return False
except socket.gaierror:
pass # DNS resolution failed — let httpx handle the connection error
return True
class ReminderDispatcher:
"""Dispatch reminders via client push, email, and webhook channels.
@ -105,6 +141,9 @@ class ReminderDispatcher:
async def _send_webhook(self, event: CalendarEvent, user_id: str) -> bool:
if self._webhook_url is None:
return False
if not _is_safe_webhook_url(self._webhook_url):
logger.warning("Webhook URL blocked: private/internal address")
return False
import httpx
async with httpx.AsyncClient() as client:

View File

@ -126,7 +126,7 @@ class ReminderScheduler:
Idempotent: if any delivery already exists for this event+rule, skip.
"""
existing = await get_pending_deliveries(event.id, rule.id, self._db_path)
existing = await get_pending_deliveries(event.id, rule.id, self._db_path, status="sent")
if existing:
return 0

View File

@ -20,6 +20,7 @@ from agentkit.calendar.db import (
add_tag_to_event,
delete_event as db_delete_event,
get_event as db_get_event,
get_invitation as db_get_invitation,
insert_event,
insert_event_type,
insert_invitation,
@ -47,6 +48,16 @@ from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH
logger = logging.getLogger(__name__)
def _validate_iso(dt_str: str) -> None:
"""Validate that *dt_str* is a parseable ISO 8601 string."""
if not dt_str:
raise ValueError("start_time must be a valid ISO 8601 string")
try:
datetime.fromisoformat(dt_str)
except ValueError:
raise ValueError(f"Invalid ISO 8601 format: {dt_str}")
def _parse_dt(dt_str: str) -> datetime:
"""Parse ISO 8601 string to timezone-aware datetime (UTC)."""
dt = datetime.fromisoformat(dt_str)
@ -97,6 +108,7 @@ class CalendarService:
tag_ids: list[str] | None = None,
) -> CalendarEvent:
"""Create a calendar event with UUID, timestamps, tags, and cloned reminders."""
_validate_iso(start_time)
now = _now_iso()
event = CalendarEvent(
id=uuid.uuid4().hex,
@ -117,22 +129,35 @@ class CalendarService:
)
await insert_event(event, self.db_path)
# Link tags if provided
# Link tags if provided — skip any that don't belong to the user
if tag_ids:
user_tags = await self.list_tags(user_id)
user_tag_ids = {t.id for t in user_tags}
for tag_id in tag_ids:
if tag_id not in user_tag_ids:
logger.debug("Skipping tag %s — not owned by user %s", tag_id, user_id)
continue
await add_tag_to_event(event.id, tag_id, self.db_path)
# Clone type-level default reminder rules to the event
# Clone type-level default reminder rules to the event — skip if type
# doesn't belong to the user
if event_type_id:
type_rules = await list_reminder_rules_for_type(event_type_id, self.db_path)
for rule in type_rules:
cloned = dataclasses.replace(
rule,
id=uuid.uuid4().hex,
event_id=event.id,
event_type_id=None,
user_types = await self.list_event_types(user_id)
user_type_ids = {t.id for t in user_types}
if event_type_id not in user_type_ids:
logger.debug(
"Skipping event_type %s — not owned by user %s", event_type_id, user_id
)
await insert_reminder_rule(cloned, self.db_path)
else:
type_rules = await list_reminder_rules_for_type(event_type_id, self.db_path)
for rule in type_rules:
cloned = dataclasses.replace(
rule,
id=uuid.uuid4().hex,
event_id=event.id,
event_type_id=None,
)
await insert_reminder_rule(cloned, self.db_path)
logger.info(f"Created event {event.id} ({title}) for user {user_id}")
return event
@ -168,12 +193,23 @@ class CalendarService:
for event in events:
if event.rrule:
# Expand recurring event within [start, end] range
occurrences = expand_rrule(
event.rrule,
event.start_time,
range_start=start,
range_end=end,
)
try:
occurrences = expand_rrule(
event.rrule,
event.start_time,
range_start=start,
range_end=end,
)
except ValueError:
# ponytail: malformed RRULE → fall back to single occurrence
# so one bad event doesn't crash the whole list. Upgrade
# path: surface a validation error at create_event time.
logger.warning(
"Malformed RRULE %r for event %s; falling back to start_time",
event.rrule,
event.id,
)
occurrences = [event.start_time]
for occ_start_str in occurrences:
occ = self._make_occurrence(event, occ_start_str)
result.append(occ)
@ -216,6 +252,25 @@ class CalendarService:
"""List all event types for a user."""
return await db_list_event_types(user_id, self.db_path)
async def get_event_type(self, type_id: str) -> EventType | None:
"""Return a single event type by id, or None."""
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_event_types WHERE id = ?",
(type_id,),
)
row = await cursor.fetchone()
if row is None:
return None
return EventType(
id=row["id"],
user_id=row["user_id"],
name=row["name"],
color=row["color"],
is_default=bool(row["is_default"]),
)
async def create_event_type(
self,
user_id: str,
@ -284,6 +339,10 @@ class CalendarService:
self.db_path,
)
async def get_invitation(self, invitation_id: str) -> Invitation | None:
"""Return a single invitation by id, or None."""
return await db_get_invitation(invitation_id, self.db_path)
async def list_invitations(self, invitee_email: str) -> list[Invitation]:
"""List all invitations for an invitee email."""
return await db_list_invitations(invitee_email, self.db_path)
@ -298,11 +357,13 @@ class CalendarService:
Only ``username`` and ``email`` are returned never user_id or
password fields (G5/A3 least-privilege user search).
"""
pattern = f"%{q}%"
escaped_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
pattern = f"%{escaped_q}%"
async with aiosqlite.connect(str(self.auth_db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT username, email FROM users WHERE username LIKE ? OR email LIKE ? LIMIT 10",
"SELECT username, email FROM users "
"WHERE username LIKE ? ESCAPE '\\' OR email LIKE ? ESCAPE '\\' LIMIT 10",
(pattern, pattern),
)
rows = await cursor.fetchall()

View File

@ -98,6 +98,9 @@ class CalDAVSyncProvider(AbstractSyncProvider):
url=creds.get("url", ""),
username=creds.get("username", ""),
password=creds.get("password", ""),
# ponytail: 30s timeout prevents indefinite hangs on unreachable
# CalDAV servers. Upgrade: make configurable per-config.
timeout=30,
)
def _get_calendar(self, config: ExternalCalendarConfig) -> Any:
@ -242,6 +245,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
if remote_lm > local_lm:
# Remote wins → update local
await self._notify_conflict(local, remote, winner="remote")
fields = {
"title": remote.title,
"description": remote.description,

View File

@ -65,6 +65,11 @@ class ICSProvider:
Returns ``{"imported": N, "skipped": M, "errors": [...]}``.
Raises ``ValueError`` if the ICS content cannot be parsed at all.
"""
# ponytail: hard size/count caps to prevent DoS via crafted ICS.
# Upgrade: stream-parse for unbounded inputs.
if len(ics_bytes) > 10 * 1024 * 1024:
raise ValueError("ICS file too large (max 10MB)")
imported = 0
skipped = 0
errors: list[str] = []
@ -74,7 +79,12 @@ class ICSProvider:
except Exception as e:
raise ValueError(f"Failed to parse ICS: {e}") from e
for component in cal.walk("VEVENT"):
components = list(cal.walk("VEVENT"))
# ponytail: hard size/count caps to prevent DoS via crafted ICS.
if len(components) > 10000:
raise ValueError("Too many events in ICS file (max 10000)")
for component in components:
try:
uid = str(component.get("UID", "") or "") or None
@ -93,6 +103,9 @@ class ICSProvider:
continue
start_str, is_all_day = _extract_dt(component, "DTSTART")
if not start_str:
errors.append("VEVENT missing DTSTART, skipped")
continue
end_str, _ = _extract_dt(component, "DTEND")
if not end_str:
end_str = start_str
@ -150,7 +163,7 @@ class ICSProvider:
def _event_to_vevent(self, event: CalendarEvent) -> Event:
"""Convert a :class:`CalendarEvent` to an icalendar ``Event`` component."""
vevent = Event()
vevent.add("uid", event.id)
vevent.add("uid", event.external_id or event.id)
vevent.add("summary", event.title)
start_dt = _parse_iso(event.start_time)

View File

@ -103,6 +103,11 @@ class SyncManager:
pulled = await provider.pull_changes(config, since=since)
# 2. Push local changes → remote (events modified since last_sync)
# ponytail: reset `since` to now after pull so pulled events (whose
# last_modified was set to the remote timestamp) aren't pushed back to
# remote, creating a sync loop. Upgrade: filter by pulled event IDs for
# sub-second accuracy.
since = datetime.now(timezone.utc).isoformat()
local_events = await self._get_modified_events(config, since)
pushed = await provider.push_changes(config, local_events)

View File

@ -20,7 +20,7 @@ from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
from urllib.parse import parse_qs, quote, urlparse
import httpx
@ -245,6 +245,18 @@ class OutlookSyncProvider(AbstractSyncProvider):
"scope": DEFAULT_SCOPE,
},
)
if resp.status_code in (400, 401):
logger.error("Outlook refresh token expired or revoked (status=%s)", resp.status_code)
if self._notify is not None:
await self._notify(
"calendar_sync_error",
{
"config_id": config.id,
"message": "Outlook authentication expired, please re-authenticate",
},
)
# ponytail: re-auth UI is in U12. Upgrade: trigger OAuth re-flow automatically.
raise RuntimeError("Outlook refresh token expired — re-authentication required")
resp.raise_for_status()
payload = resp.json()
creds["access_token"] = payload["access_token"]
@ -338,7 +350,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
def _build_delta_url(self, config: ExternalCalendarConfig) -> str:
"""Build the delta query URL. Uses sync_token if present (incremental)."""
if config.sync_token:
return f"{GRAPH_BASE}/me/calendarView/delta?$deltaToken={config.sync_token}"
return f"{GRAPH_BASE}/me/calendarView/delta?$deltaToken={quote(config.sync_token, safe='')}"
# Initial sync — use date range to scope the fetch
start = (datetime.now(timezone.utc) - timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%SZ")
end = (datetime.now(timezone.utc) + timedelta(days=90)).strftime("%Y-%m-%dT%H:%M:%SZ")
@ -406,6 +418,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
if remote_lm > local_lm:
# Remote wins → update local
await self._notify_conflict(local, remote, winner="remote")
fields = {
"title": remote.title,
"description": remote.description,

View File

@ -82,8 +82,6 @@ class UpdateEventRequest(BaseModel):
event_type_id: str | None = None
rrule: str | None = None
model_config = {"extra": "allow"}
class CreateEventTypeRequest(BaseModel):
name: str
@ -193,7 +191,7 @@ async def update_event(
if event.user_id != user["user_id"]:
raise HTTPException(status_code=403, detail="Access denied")
# Build fields dict from non-None values (extra fields allowed by model_config)
# Build fields dict from non-None values
fields: dict[str, Any] = {
name: value
for name, value in body.model_dump(exclude_unset=True).items()
@ -273,6 +271,12 @@ async def respond_to_invitation(
)
service = _get_calendar_service(request)
invitation = await service.get_invitation(invitation_id)
if invitation is None:
raise HTTPException(status_code=404, detail="Invitation not found")
email = await service.get_user_email(user["user_id"])
if email is None or invitation.invitee_email != email:
raise HTTPException(status_code=403, detail="Only the invitee can respond")
updated = await service.respond_to_invitation(invitation_id, body.status)
if not updated:
raise HTTPException(status_code=404, detail="Invitation not found")
@ -363,6 +367,11 @@ async def update_event_type(
) -> dict[str, Any]:
"""Update specific fields of an event type."""
service = _get_calendar_service(request)
et = await service.get_event_type(type_id)
if et is None:
raise HTTPException(status_code=404, detail="Event type not found")
if et.user_id != user["user_id"]:
raise HTTPException(status_code=403, detail="Access denied")
fields: dict[str, Any] = {}
for name, value in body.model_dump(exclude_unset=True).items():
if value is not None:

View File

@ -238,6 +238,13 @@ class CalendarTool(Tool):
if not user_id:
return {"success": False, "error": "Missing required field: user_id"}
# Ownership check
event = await self._service.get_event(event_id)
if event is None:
return {"success": False, "error": "Event not found"}
if event.user_id != user_id:
return {"success": False, "error": "Permission denied"}
# Build fields dict from updatable params (only those explicitly provided)
updatable = ["title", "description", "start_time", "end_time", "location", "is_all_day"]
fields: dict[str, Any] = {}
@ -268,6 +275,13 @@ class CalendarTool(Tool):
if not user_id:
return {"success": False, "error": "Missing required field: user_id"}
# Ownership check
event = await self._service.get_event(event_id)
if event is None:
return {"success": False, "error": "Event not found"}
if event.user_id != user_id:
return {"success": False, "error": "Permission denied"}
try:
deleted = await self._service.delete_event(event_id)
if not deleted:

View File

@ -151,7 +151,7 @@ async def test_failed_delivery_retries_up_to_3_times(db_path: Path) -> None:
assert dispatcher.dispatch.call_count == 3 # type: ignore
deliveries = await get_pending_deliveries("evt-1", "rule-1", db_path)
deliveries = await get_pending_deliveries("evt-1", "rule-1", db_path, status="failed")
assert len(deliveries) == 1
assert deliveries[0].attempts == 3
assert deliveries[0].status == "failed"