fix(calendar): code review fixes - 23 issues (2 critical, 15 major, 6 minor)

This commit is contained in:
chiguyong 2026-06-24 11:29:23 +08:00
parent 4ea7801bcf
commit 3fdee65979
12 changed files with 304 additions and 36 deletions

View File

@ -105,7 +105,10 @@ CREATE TABLE IF NOT EXISTS calendar_reminder_deliveries (
channel TEXT NOT NULL DEFAULT 'client', channel TEXT NOT NULL DEFAULT 'client',
attempts INTEGER NOT NULL DEFAULT 0, attempts INTEGER NOT NULL DEFAULT 0,
last_error TEXT, last_error TEXT,
FOREIGN KEY (reminder_rule_id) REFERENCES calendar_reminder_rules(id), -- ponytail: ON DELETE CASCADE ensures deliveries are removed when their
-- reminder_rule is cascade-deleted (e.g. event deletion). Existing DBs
-- created before this change need ALTER TABLE or DB recreation.
FOREIGN KEY (reminder_rule_id) REFERENCES calendar_reminder_rules(id) ON DELETE CASCADE,
FOREIGN KEY (event_id) REFERENCES calendar_events(id) ON DELETE CASCADE FOREIGN KEY (event_id) REFERENCES calendar_events(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_deliveries_status CREATE INDEX IF NOT EXISTS idx_deliveries_status
@ -252,6 +255,8 @@ async def insert_event(event: CalendarEvent, db_path: str | Path | None = None)
"""Insert a calendar event.""" """Insert a calendar event."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON") await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_events (id, user_id, title, description, " "INSERT INTO calendar_events (id, user_id, title, description, "
@ -286,6 +291,9 @@ async def get_event(event_id: str, db_path: str | Path | None = None) -> Calenda
"""Return a single event by id, or None.""" """Return a single event by id, or None."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM calendar_events WHERE id = ?", (event_id,)) cursor = await db.execute("SELECT * FROM calendar_events WHERE id = ?", (event_id,))
row = await cursor.fetchone() row = await cursor.fetchone()
@ -304,6 +312,9 @@ async def get_event_by_external_id(
""" """
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_events " "SELECT * FROM calendar_events "
@ -346,6 +357,9 @@ async def list_events(
query += " WHERE " + " AND ".join(conditions) + " ORDER BY e.start_time" query += " WHERE " + " AND ".join(conditions) + " ORDER BY e.start_time"
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute(query, tuple(params)) cursor = await db.execute(query, tuple(params))
rows = await cursor.fetchall() rows = await cursor.fetchall()
@ -361,6 +375,9 @@ async def list_all_events_in_time_range(
""" """
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_events WHERE start_time >= ? AND start_time < ? " "SELECT * FROM calendar_events WHERE start_time >= ? AND start_time < ? "
@ -411,6 +428,9 @@ async def update_event(
sql = f"UPDATE calendar_events SET {', '.join(set_clauses)} WHERE id = ?" sql = f"UPDATE calendar_events SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params)) cursor = await db.execute(sql, tuple(params))
await db.commit() await db.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@ -420,6 +440,8 @@ async def delete_event(event_id: str, db_path: str | Path | None = None) -> bool
"""Delete an event and its dependent rows. Returns True if deleted.""" """Delete an event and its dependent rows. Returns True if deleted."""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON") await db.execute("PRAGMA foreign_keys = ON")
# Manual cascade for event_tags (no ON DELETE on junction FK in some SQLite versions) # Manual cascade for event_tags (no ON DELETE on junction FK in some SQLite versions)
await db.execute("DELETE FROM calendar_event_tags WHERE event_id = ?", (event_id,)) await db.execute("DELETE FROM calendar_event_tags WHERE event_id = ?", (event_id,))
@ -436,6 +458,9 @@ async def delete_event(event_id: str, db_path: str | Path | None = None) -> bool
async def insert_event_type(et: EventType, db_path: str | Path | None = None) -> None: async def insert_event_type(et: EventType, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_event_types (id, user_id, name, color, is_default) " "INSERT INTO calendar_event_types (id, user_id, name, color, is_default) "
"VALUES (?, ?, ?, ?, ?)", "VALUES (?, ?, ?, ?, ?)",
@ -447,6 +472,9 @@ async def insert_event_type(et: EventType, db_path: str | Path | None = None) ->
async def list_event_types(user_id: str, db_path: str | Path | None = None) -> list[EventType]: async def list_event_types(user_id: str, db_path: str | Path | None = None) -> list[EventType]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_event_types WHERE user_id = ? ORDER BY name", "SELECT * FROM calendar_event_types WHERE user_id = ? ORDER BY name",
@ -477,6 +505,9 @@ async def update_event_type(
params.append(type_id) params.append(type_id)
sql = f"UPDATE calendar_event_types SET {', '.join(set_clauses)} WHERE id = ?" sql = f"UPDATE calendar_event_types SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params)) cursor = await db.execute(sql, tuple(params))
await db.commit() await db.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@ -485,6 +516,9 @@ async def update_event_type(
async def delete_event_type(type_id: str, db_path: str | Path | None = None) -> bool: async def delete_event_type(type_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute("DELETE FROM calendar_event_types WHERE id = ?", (type_id,)) cursor = await db.execute("DELETE FROM calendar_event_types WHERE id = ?", (type_id,))
await db.commit() await db.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@ -498,6 +532,9 @@ async def delete_event_type(type_id: str, db_path: str | Path | None = None) ->
async def insert_tag(tag: Tag, db_path: str | Path | None = None) -> None: async def insert_tag(tag: Tag, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_tags (id, user_id, name) VALUES (?, ?, ?)", "INSERT INTO calendar_tags (id, user_id, name) VALUES (?, ?, ?)",
(tag.id, tag.user_id, tag.name), (tag.id, tag.user_id, tag.name),
@ -508,6 +545,9 @@ async def insert_tag(tag: Tag, db_path: str | Path | None = None) -> None:
async def list_tags(user_id: str, db_path: str | Path | None = None) -> list[Tag]: async def list_tags(user_id: str, db_path: str | Path | None = None) -> list[Tag]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_tags WHERE user_id = ? ORDER BY name", "SELECT * FROM calendar_tags WHERE user_id = ? ORDER BY name",
@ -520,6 +560,9 @@ async def list_tags(user_id: str, db_path: str | Path | None = None) -> list[Tag
async def delete_tag(tag_id: str, db_path: str | Path | None = None) -> bool: async def delete_tag(tag_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute("DELETE FROM calendar_event_tags WHERE tag_id = ?", (tag_id,)) await db.execute("DELETE FROM calendar_event_tags WHERE tag_id = ?", (tag_id,))
cursor = await db.execute("DELETE FROM calendar_tags WHERE id = ?", (tag_id,)) cursor = await db.execute("DELETE FROM calendar_tags WHERE id = ?", (tag_id,))
await db.commit() await db.commit()
@ -534,6 +577,9 @@ async def delete_tag(tag_id: str, db_path: str | Path | None = None) -> bool:
async def add_tag_to_event(event_id: str, tag_id: str, db_path: str | Path | None = None) -> None: async def add_tag_to_event(event_id: str, tag_id: str, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT OR IGNORE INTO calendar_event_tags (event_id, tag_id) VALUES (?, ?)", "INSERT OR IGNORE INTO calendar_event_tags (event_id, tag_id) VALUES (?, ?)",
(event_id, tag_id), (event_id, tag_id),
@ -546,6 +592,9 @@ async def remove_tag_from_event(
) -> None: ) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"DELETE FROM calendar_event_tags WHERE event_id = ? AND tag_id = ?", "DELETE FROM calendar_event_tags WHERE event_id = ? AND tag_id = ?",
(event_id, tag_id), (event_id, tag_id),
@ -556,6 +605,9 @@ async def remove_tag_from_event(
async def get_event_tags(event_id: str, db_path: str | Path | None = None) -> list[Tag]: async def get_event_tags(event_id: str, db_path: str | Path | None = None) -> list[Tag]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT t.* FROM calendar_tags t " "SELECT t.* FROM calendar_tags t "
@ -575,6 +627,9 @@ async def get_event_tags(event_id: str, db_path: str | Path | None = None) -> li
async def insert_reminder_rule(rule: ReminderRule, db_path: str | Path | None = None) -> None: async def insert_reminder_rule(rule: ReminderRule, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_reminder_rules " "INSERT INTO calendar_reminder_rules "
"(id, event_id, event_type_id, offset_minutes, channels) " "(id, event_id, event_type_id, offset_minutes, channels) "
@ -595,6 +650,9 @@ async def list_reminder_rules_for_event(
) -> list[ReminderRule]: ) -> list[ReminderRule]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_reminder_rules WHERE event_id = ?", "SELECT * FROM calendar_reminder_rules WHERE event_id = ?",
@ -609,6 +667,9 @@ async def list_reminder_rules_for_type(
) -> list[ReminderRule]: ) -> list[ReminderRule]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_reminder_rules WHERE event_type_id = ?", "SELECT * FROM calendar_reminder_rules WHERE event_type_id = ?",
@ -621,6 +682,9 @@ async def list_reminder_rules_for_type(
async def delete_reminder_rule(rule_id: str, db_path: str | Path | None = None) -> bool: async def delete_reminder_rule(rule_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute("DELETE FROM calendar_reminder_rules WHERE id = ?", (rule_id,)) cursor = await db.execute("DELETE FROM calendar_reminder_rules WHERE id = ?", (rule_id,))
await db.commit() await db.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@ -636,6 +700,9 @@ async def insert_reminder_delivery(
) -> None: ) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_reminder_deliveries " "INSERT INTO calendar_reminder_deliveries "
"(id, reminder_rule_id, event_id, scheduled_time, status, channel, " "(id, reminder_rule_id, event_id, scheduled_time, status, channel, "
@ -655,16 +722,27 @@ async def insert_reminder_delivery(
async def get_pending_deliveries( async def get_pending_deliveries(
event_id: str, reminder_rule_id: str, db_path: str | Path | None = None event_id: str,
reminder_rule_id: str,
db_path: str | Path | None = None,
status: str = "sent",
) -> list[ReminderDelivery]: ) -> list[ReminderDelivery]:
"""Check idempotency — return existing deliveries for an event+rule.""" """Check idempotency — return existing deliveries for an event+rule.
By default only ``sent`` deliveries are returned, so the scheduler's
idempotency check skips rules that already succeeded. Stuck ``pending``
or ``failed`` deliveries are ignored, allowing retry on the next scan.
"""
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_reminder_deliveries " "SELECT * FROM calendar_reminder_deliveries "
"WHERE event_id = ? AND reminder_rule_id = ?", "WHERE event_id = ? AND reminder_rule_id = ? AND status = ?",
(event_id, reminder_rule_id), (event_id, reminder_rule_id, status),
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
return [_row_to_reminder_delivery(row) for row in rows] return [_row_to_reminder_delivery(row) for row in rows]
@ -678,6 +756,9 @@ async def update_delivery_status(
) -> bool: ) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute( cursor = await db.execute(
"UPDATE calendar_reminder_deliveries " "UPDATE calendar_reminder_deliveries "
"SET status = ?, attempts = attempts + 1, last_error = ? " "SET status = ?, attempts = attempts + 1, last_error = ? "
@ -698,6 +779,9 @@ async def insert_external_config(
) -> None: ) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_external_configs " "INSERT INTO calendar_external_configs "
"(id, user_id, provider, credentials, sync_frequency, sync_scope, " "(id, user_id, provider, credentials, sync_frequency, sync_scope, "
@ -721,6 +805,9 @@ async def list_external_configs(
) -> list[ExternalCalendarConfig]: ) -> list[ExternalCalendarConfig]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_external_configs WHERE user_id = ?", "SELECT * FROM calendar_external_configs WHERE user_id = ?",
@ -757,6 +844,9 @@ async def update_external_config(
params.append(config_id) params.append(config_id)
sql = f"UPDATE calendar_external_configs SET {', '.join(set_clauses)} WHERE id = ?" sql = f"UPDATE calendar_external_configs SET {', '.join(set_clauses)} WHERE id = ?"
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute(sql, tuple(params)) cursor = await db.execute(sql, tuple(params))
await db.commit() await db.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0
@ -765,6 +855,9 @@ async def update_external_config(
async def delete_external_config(config_id: str, db_path: str | Path | None = None) -> bool: async def delete_external_config(config_id: str, db_path: str | Path | None = None) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute( cursor = await db.execute(
"DELETE FROM calendar_external_configs WHERE id = ?", (config_id,) "DELETE FROM calendar_external_configs WHERE id = ?", (config_id,)
) )
@ -780,6 +873,9 @@ async def delete_external_config(config_id: str, db_path: str | Path | None = No
async def insert_invitation(invitation: Invitation, db_path: str | Path | None = None) -> None: async def insert_invitation(invitation: Invitation, db_path: str | Path | None = None) -> None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
await db.execute( await db.execute(
"INSERT INTO calendar_invitations " "INSERT INTO calendar_invitations "
"(id, event_id, inviter_user_id, invitee_email, status, responded_at) " "(id, event_id, inviter_user_id, invitee_email, status, responded_at) "
@ -801,6 +897,9 @@ async def get_invitation(
) -> Invitation | None: ) -> Invitation | None:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_invitations WHERE id = ?", (invitation_id,) "SELECT * FROM calendar_invitations WHERE id = ?", (invitation_id,)
@ -814,6 +913,9 @@ async def list_invitations(
) -> list[Invitation]: ) -> list[Invitation]:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT * FROM calendar_invitations WHERE invitee_email = ? ORDER BY responded_at DESC", "SELECT * FROM calendar_invitations WHERE invitee_email = ? ORDER BY responded_at DESC",
@ -831,6 +933,9 @@ async def update_invitation_status(
) -> bool: ) -> bool:
path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute("PRAGMA busy_timeout = 5000")
await db.execute("PRAGMA foreign_keys = ON")
cursor = await db.execute( cursor = await db.execute(
"UPDATE calendar_invitations SET status = ?, responded_at = ? WHERE id = ?", "UPDATE calendar_invitations SET status = ?, responded_at = ? WHERE id = ?",
(status, responded_at, invitation_id), (status, responded_at, invitation_id),

View File

@ -6,6 +6,7 @@ All times are UTC (see KTD-11).
from __future__ import annotations from __future__ import annotations
import itertools
from datetime import datetime, timezone from datetime import datetime, timezone
from dateutil.rrule import rrulestr from dateutil.rrule import rrulestr
@ -46,7 +47,9 @@ def expand_rrule(
# rrulestr expects the RRULE to have a DTSTART context. # rrulestr expects the RRULE to have a DTSTART context.
# We prepend DTSTART to ensure the rule starts from the event's start time. # We prepend DTSTART to ensure the rule starts from the event's start time.
full_rule = f"DTSTART:{start_dt.strftime('%Y%m%dT%H%M%SZ')}\nRRULE:{rrule_str}" full_rule = (
f"DTSTART:{start_dt.astimezone(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}\nRRULE:{rrule_str}"
)
rule = rrulestr(full_rule) rule = rrulestr(full_rule)
if range_start is not None and range_end is not None: if range_start is not None and range_end is not None:
@ -61,7 +64,9 @@ def expand_rrule(
re_ = _parse_dt(range_end) re_ = _parse_dt(range_end)
occurrences = [dt for dt in rule if dt < re_] occurrences = [dt for dt in rule if dt < re_]
else: else:
occurrences = list(rule) # ponytail: 1000-occurrence ceiling for unbounded rules (FREQ=DAILY
# without COUNT/UNTIL). Upgrade path: accept a max_occurrences param.
occurrences = list(itertools.islice(rule, 1000))
# Convert back to ISO 8601 UTC strings # Convert back to ISO 8601 UTC strings
return [dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00") for dt in occurrences] return [dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00") for dt in occurrences]

View File

@ -7,9 +7,12 @@ patching module imports.
from __future__ import annotations from __future__ import annotations
import ipaddress
import logging import logging
import socket
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from urllib.parse import urlparse
from agentkit.calendar.models import CalendarEvent from agentkit.calendar.models import CalendarEvent
@ -21,13 +24,46 @@ class SmtpConfig:
"""SMTP server configuration for the email reminder channel.""" """SMTP server configuration for the email reminder channel."""
host: str = "localhost" host: str = "localhost"
port: int = 25 # ponytail: STARTTLS on 587 is the modern default; plaintext port 25 is
# only appropriate for local MTA relays. Upgrade: implicit TLS on 465.
port: int = 587
username: str | None = None username: str | None = None
password: str | None = None password: str | None = None
use_tls: bool = False use_tls: bool = True
from_email: str = "noreply@agentkit.local" from_email: str = "noreply@agentkit.local"
def _is_safe_webhook_url(url: str) -> bool:
"""Validate webhook URL to prevent SSRF attacks.
ponytail: Basic SSRF guard blocks private/internal IP ranges (RFC 1918),
loopback, link-local (169.254.x includes cloud metadata endpoints), and
non-http(s) schemes. Uses blocking socket.getaddrinfo; upgrade to
asyncio.getaddrinfo if webhook dispatch is on a hot path.
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname == "localhost" or hostname.startswith("169.254."):
return False
try:
infos = socket.getaddrinfo(hostname, None)
for _, _, _, _, sockaddr in infos:
addr = sockaddr[0]
try:
ip = ipaddress.ip_address(addr)
except ValueError:
continue
if ip.is_private or ip.is_loopback or ip.is_link_local:
return False
except socket.gaierror:
pass # DNS resolution failed — let httpx handle the connection error
return True
class ReminderDispatcher: class ReminderDispatcher:
"""Dispatch reminders via client push, email, and webhook channels. """Dispatch reminders via client push, email, and webhook channels.
@ -105,6 +141,9 @@ class ReminderDispatcher:
async def _send_webhook(self, event: CalendarEvent, user_id: str) -> bool: async def _send_webhook(self, event: CalendarEvent, user_id: str) -> bool:
if self._webhook_url is None: if self._webhook_url is None:
return False return False
if not _is_safe_webhook_url(self._webhook_url):
logger.warning("Webhook URL blocked: private/internal address")
return False
import httpx import httpx
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:

View File

@ -126,7 +126,7 @@ class ReminderScheduler:
Idempotent: if any delivery already exists for this event+rule, skip. Idempotent: if any delivery already exists for this event+rule, skip.
""" """
existing = await get_pending_deliveries(event.id, rule.id, self._db_path) existing = await get_pending_deliveries(event.id, rule.id, self._db_path, status="sent")
if existing: if existing:
return 0 return 0

View File

@ -20,6 +20,7 @@ from agentkit.calendar.db import (
add_tag_to_event, add_tag_to_event,
delete_event as db_delete_event, delete_event as db_delete_event,
get_event as db_get_event, get_event as db_get_event,
get_invitation as db_get_invitation,
insert_event, insert_event,
insert_event_type, insert_event_type,
insert_invitation, insert_invitation,
@ -47,6 +48,16 @@ from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_iso(dt_str: str) -> None:
"""Validate that *dt_str* is a parseable ISO 8601 string."""
if not dt_str:
raise ValueError("start_time must be a valid ISO 8601 string")
try:
datetime.fromisoformat(dt_str)
except ValueError:
raise ValueError(f"Invalid ISO 8601 format: {dt_str}")
def _parse_dt(dt_str: str) -> datetime: def _parse_dt(dt_str: str) -> datetime:
"""Parse ISO 8601 string to timezone-aware datetime (UTC).""" """Parse ISO 8601 string to timezone-aware datetime (UTC)."""
dt = datetime.fromisoformat(dt_str) dt = datetime.fromisoformat(dt_str)
@ -97,6 +108,7 @@ class CalendarService:
tag_ids: list[str] | None = None, tag_ids: list[str] | None = None,
) -> CalendarEvent: ) -> CalendarEvent:
"""Create a calendar event with UUID, timestamps, tags, and cloned reminders.""" """Create a calendar event with UUID, timestamps, tags, and cloned reminders."""
_validate_iso(start_time)
now = _now_iso() now = _now_iso()
event = CalendarEvent( event = CalendarEvent(
id=uuid.uuid4().hex, id=uuid.uuid4().hex,
@ -117,22 +129,35 @@ class CalendarService:
) )
await insert_event(event, self.db_path) await insert_event(event, self.db_path)
# Link tags if provided # Link tags if provided — skip any that don't belong to the user
if tag_ids: if tag_ids:
user_tags = await self.list_tags(user_id)
user_tag_ids = {t.id for t in user_tags}
for tag_id in tag_ids: for tag_id in tag_ids:
if tag_id not in user_tag_ids:
logger.debug("Skipping tag %s — not owned by user %s", tag_id, user_id)
continue
await add_tag_to_event(event.id, tag_id, self.db_path) await add_tag_to_event(event.id, tag_id, self.db_path)
# Clone type-level default reminder rules to the event # Clone type-level default reminder rules to the event — skip if type
# doesn't belong to the user
if event_type_id: if event_type_id:
type_rules = await list_reminder_rules_for_type(event_type_id, self.db_path) user_types = await self.list_event_types(user_id)
for rule in type_rules: user_type_ids = {t.id for t in user_types}
cloned = dataclasses.replace( if event_type_id not in user_type_ids:
rule, logger.debug(
id=uuid.uuid4().hex, "Skipping event_type %s — not owned by user %s", event_type_id, user_id
event_id=event.id,
event_type_id=None,
) )
await insert_reminder_rule(cloned, self.db_path) else:
type_rules = await list_reminder_rules_for_type(event_type_id, self.db_path)
for rule in type_rules:
cloned = dataclasses.replace(
rule,
id=uuid.uuid4().hex,
event_id=event.id,
event_type_id=None,
)
await insert_reminder_rule(cloned, self.db_path)
logger.info(f"Created event {event.id} ({title}) for user {user_id}") logger.info(f"Created event {event.id} ({title}) for user {user_id}")
return event return event
@ -168,12 +193,23 @@ class CalendarService:
for event in events: for event in events:
if event.rrule: if event.rrule:
# Expand recurring event within [start, end] range # Expand recurring event within [start, end] range
occurrences = expand_rrule( try:
event.rrule, occurrences = expand_rrule(
event.start_time, event.rrule,
range_start=start, event.start_time,
range_end=end, range_start=start,
) range_end=end,
)
except ValueError:
# ponytail: malformed RRULE → fall back to single occurrence
# so one bad event doesn't crash the whole list. Upgrade
# path: surface a validation error at create_event time.
logger.warning(
"Malformed RRULE %r for event %s; falling back to start_time",
event.rrule,
event.id,
)
occurrences = [event.start_time]
for occ_start_str in occurrences: for occ_start_str in occurrences:
occ = self._make_occurrence(event, occ_start_str) occ = self._make_occurrence(event, occ_start_str)
result.append(occ) result.append(occ)
@ -216,6 +252,25 @@ class CalendarService:
"""List all event types for a user.""" """List all event types for a user."""
return await db_list_event_types(user_id, self.db_path) return await db_list_event_types(user_id, self.db_path)
async def get_event_type(self, type_id: str) -> EventType | None:
"""Return a single event type by id, or None."""
async with aiosqlite.connect(str(self.db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM calendar_event_types WHERE id = ?",
(type_id,),
)
row = await cursor.fetchone()
if row is None:
return None
return EventType(
id=row["id"],
user_id=row["user_id"],
name=row["name"],
color=row["color"],
is_default=bool(row["is_default"]),
)
async def create_event_type( async def create_event_type(
self, self,
user_id: str, user_id: str,
@ -284,6 +339,10 @@ class CalendarService:
self.db_path, self.db_path,
) )
async def get_invitation(self, invitation_id: str) -> Invitation | None:
"""Return a single invitation by id, or None."""
return await db_get_invitation(invitation_id, self.db_path)
async def list_invitations(self, invitee_email: str) -> list[Invitation]: async def list_invitations(self, invitee_email: str) -> list[Invitation]:
"""List all invitations for an invitee email.""" """List all invitations for an invitee email."""
return await db_list_invitations(invitee_email, self.db_path) return await db_list_invitations(invitee_email, self.db_path)
@ -298,11 +357,13 @@ class CalendarService:
Only ``username`` and ``email`` are returned never user_id or Only ``username`` and ``email`` are returned never user_id or
password fields (G5/A3 least-privilege user search). password fields (G5/A3 least-privilege user search).
""" """
pattern = f"%{q}%" escaped_q = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
pattern = f"%{escaped_q}%"
async with aiosqlite.connect(str(self.auth_db_path)) as db: async with aiosqlite.connect(str(self.auth_db_path)) as db:
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
cursor = await db.execute( cursor = await db.execute(
"SELECT username, email FROM users WHERE username LIKE ? OR email LIKE ? LIMIT 10", "SELECT username, email FROM users "
"WHERE username LIKE ? ESCAPE '\\' OR email LIKE ? ESCAPE '\\' LIMIT 10",
(pattern, pattern), (pattern, pattern),
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()

View File

@ -98,6 +98,9 @@ class CalDAVSyncProvider(AbstractSyncProvider):
url=creds.get("url", ""), url=creds.get("url", ""),
username=creds.get("username", ""), username=creds.get("username", ""),
password=creds.get("password", ""), password=creds.get("password", ""),
# ponytail: 30s timeout prevents indefinite hangs on unreachable
# CalDAV servers. Upgrade: make configurable per-config.
timeout=30,
) )
def _get_calendar(self, config: ExternalCalendarConfig) -> Any: def _get_calendar(self, config: ExternalCalendarConfig) -> Any:
@ -242,6 +245,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
if remote_lm > local_lm: if remote_lm > local_lm:
# Remote wins → update local # Remote wins → update local
await self._notify_conflict(local, remote, winner="remote")
fields = { fields = {
"title": remote.title, "title": remote.title,
"description": remote.description, "description": remote.description,

View File

@ -65,6 +65,11 @@ class ICSProvider:
Returns ``{"imported": N, "skipped": M, "errors": [...]}``. Returns ``{"imported": N, "skipped": M, "errors": [...]}``.
Raises ``ValueError`` if the ICS content cannot be parsed at all. Raises ``ValueError`` if the ICS content cannot be parsed at all.
""" """
# ponytail: hard size/count caps to prevent DoS via crafted ICS.
# Upgrade: stream-parse for unbounded inputs.
if len(ics_bytes) > 10 * 1024 * 1024:
raise ValueError("ICS file too large (max 10MB)")
imported = 0 imported = 0
skipped = 0 skipped = 0
errors: list[str] = [] errors: list[str] = []
@ -74,7 +79,12 @@ class ICSProvider:
except Exception as e: except Exception as e:
raise ValueError(f"Failed to parse ICS: {e}") from e raise ValueError(f"Failed to parse ICS: {e}") from e
for component in cal.walk("VEVENT"): components = list(cal.walk("VEVENT"))
# ponytail: hard size/count caps to prevent DoS via crafted ICS.
if len(components) > 10000:
raise ValueError("Too many events in ICS file (max 10000)")
for component in components:
try: try:
uid = str(component.get("UID", "") or "") or None uid = str(component.get("UID", "") or "") or None
@ -93,6 +103,9 @@ class ICSProvider:
continue continue
start_str, is_all_day = _extract_dt(component, "DTSTART") start_str, is_all_day = _extract_dt(component, "DTSTART")
if not start_str:
errors.append("VEVENT missing DTSTART, skipped")
continue
end_str, _ = _extract_dt(component, "DTEND") end_str, _ = _extract_dt(component, "DTEND")
if not end_str: if not end_str:
end_str = start_str end_str = start_str
@ -150,7 +163,7 @@ class ICSProvider:
def _event_to_vevent(self, event: CalendarEvent) -> Event: def _event_to_vevent(self, event: CalendarEvent) -> Event:
"""Convert a :class:`CalendarEvent` to an icalendar ``Event`` component.""" """Convert a :class:`CalendarEvent` to an icalendar ``Event`` component."""
vevent = Event() vevent = Event()
vevent.add("uid", event.id) vevent.add("uid", event.external_id or event.id)
vevent.add("summary", event.title) vevent.add("summary", event.title)
start_dt = _parse_iso(event.start_time) start_dt = _parse_iso(event.start_time)

View File

@ -103,6 +103,11 @@ class SyncManager:
pulled = await provider.pull_changes(config, since=since) pulled = await provider.pull_changes(config, since=since)
# 2. Push local changes → remote (events modified since last_sync) # 2. Push local changes → remote (events modified since last_sync)
# ponytail: reset `since` to now after pull so pulled events (whose
# last_modified was set to the remote timestamp) aren't pushed back to
# remote, creating a sync loop. Upgrade: filter by pulled event IDs for
# sub-second accuracy.
since = datetime.now(timezone.utc).isoformat()
local_events = await self._get_modified_events(config, since) local_events = await self._get_modified_events(config, since)
pushed = await provider.push_changes(config, local_events) pushed = await provider.push_changes(config, local_events)

View File

@ -20,7 +20,7 @@ from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, quote, urlparse
import httpx import httpx
@ -245,6 +245,18 @@ class OutlookSyncProvider(AbstractSyncProvider):
"scope": DEFAULT_SCOPE, "scope": DEFAULT_SCOPE,
}, },
) )
if resp.status_code in (400, 401):
logger.error("Outlook refresh token expired or revoked (status=%s)", resp.status_code)
if self._notify is not None:
await self._notify(
"calendar_sync_error",
{
"config_id": config.id,
"message": "Outlook authentication expired, please re-authenticate",
},
)
# ponytail: re-auth UI is in U12. Upgrade: trigger OAuth re-flow automatically.
raise RuntimeError("Outlook refresh token expired — re-authentication required")
resp.raise_for_status() resp.raise_for_status()
payload = resp.json() payload = resp.json()
creds["access_token"] = payload["access_token"] creds["access_token"] = payload["access_token"]
@ -338,7 +350,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
def _build_delta_url(self, config: ExternalCalendarConfig) -> str: def _build_delta_url(self, config: ExternalCalendarConfig) -> str:
"""Build the delta query URL. Uses sync_token if present (incremental).""" """Build the delta query URL. Uses sync_token if present (incremental)."""
if config.sync_token: if config.sync_token:
return f"{GRAPH_BASE}/me/calendarView/delta?$deltaToken={config.sync_token}" return f"{GRAPH_BASE}/me/calendarView/delta?$deltaToken={quote(config.sync_token, safe='')}"
# Initial sync — use date range to scope the fetch # Initial sync — use date range to scope the fetch
start = (datetime.now(timezone.utc) - timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%SZ") start = (datetime.now(timezone.utc) - timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%SZ")
end = (datetime.now(timezone.utc) + timedelta(days=90)).strftime("%Y-%m-%dT%H:%M:%SZ") end = (datetime.now(timezone.utc) + timedelta(days=90)).strftime("%Y-%m-%dT%H:%M:%SZ")
@ -406,6 +418,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
if remote_lm > local_lm: if remote_lm > local_lm:
# Remote wins → update local # Remote wins → update local
await self._notify_conflict(local, remote, winner="remote")
fields = { fields = {
"title": remote.title, "title": remote.title,
"description": remote.description, "description": remote.description,

View File

@ -82,8 +82,6 @@ class UpdateEventRequest(BaseModel):
event_type_id: str | None = None event_type_id: str | None = None
rrule: str | None = None rrule: str | None = None
model_config = {"extra": "allow"}
class CreateEventTypeRequest(BaseModel): class CreateEventTypeRequest(BaseModel):
name: str name: str
@ -193,7 +191,7 @@ async def update_event(
if event.user_id != user["user_id"]: if event.user_id != user["user_id"]:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
# Build fields dict from non-None values (extra fields allowed by model_config) # Build fields dict from non-None values
fields: dict[str, Any] = { fields: dict[str, Any] = {
name: value name: value
for name, value in body.model_dump(exclude_unset=True).items() for name, value in body.model_dump(exclude_unset=True).items()
@ -273,6 +271,12 @@ async def respond_to_invitation(
) )
service = _get_calendar_service(request) service = _get_calendar_service(request)
invitation = await service.get_invitation(invitation_id)
if invitation is None:
raise HTTPException(status_code=404, detail="Invitation not found")
email = await service.get_user_email(user["user_id"])
if email is None or invitation.invitee_email != email:
raise HTTPException(status_code=403, detail="Only the invitee can respond")
updated = await service.respond_to_invitation(invitation_id, body.status) updated = await service.respond_to_invitation(invitation_id, body.status)
if not updated: if not updated:
raise HTTPException(status_code=404, detail="Invitation not found") raise HTTPException(status_code=404, detail="Invitation not found")
@ -363,6 +367,11 @@ async def update_event_type(
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Update specific fields of an event type.""" """Update specific fields of an event type."""
service = _get_calendar_service(request) service = _get_calendar_service(request)
et = await service.get_event_type(type_id)
if et is None:
raise HTTPException(status_code=404, detail="Event type not found")
if et.user_id != user["user_id"]:
raise HTTPException(status_code=403, detail="Access denied")
fields: dict[str, Any] = {} fields: dict[str, Any] = {}
for name, value in body.model_dump(exclude_unset=True).items(): for name, value in body.model_dump(exclude_unset=True).items():
if value is not None: if value is not None:

View File

@ -238,6 +238,13 @@ class CalendarTool(Tool):
if not user_id: if not user_id:
return {"success": False, "error": "Missing required field: user_id"} return {"success": False, "error": "Missing required field: user_id"}
# Ownership check
event = await self._service.get_event(event_id)
if event is None:
return {"success": False, "error": "Event not found"}
if event.user_id != user_id:
return {"success": False, "error": "Permission denied"}
# Build fields dict from updatable params (only those explicitly provided) # Build fields dict from updatable params (only those explicitly provided)
updatable = ["title", "description", "start_time", "end_time", "location", "is_all_day"] updatable = ["title", "description", "start_time", "end_time", "location", "is_all_day"]
fields: dict[str, Any] = {} fields: dict[str, Any] = {}
@ -268,6 +275,13 @@ class CalendarTool(Tool):
if not user_id: if not user_id:
return {"success": False, "error": "Missing required field: user_id"} return {"success": False, "error": "Missing required field: user_id"}
# Ownership check
event = await self._service.get_event(event_id)
if event is None:
return {"success": False, "error": "Event not found"}
if event.user_id != user_id:
return {"success": False, "error": "Permission denied"}
try: try:
deleted = await self._service.delete_event(event_id) deleted = await self._service.delete_event(event_id)
if not deleted: if not deleted:

View File

@ -151,7 +151,7 @@ async def test_failed_delivery_retries_up_to_3_times(db_path: Path) -> None:
assert dispatcher.dispatch.call_count == 3 # type: ignore assert dispatcher.dispatch.call_count == 3 # type: ignore
deliveries = await get_pending_deliveries("evt-1", "rule-1", db_path) deliveries = await get_pending_deliveries("evt-1", "rule-1", db_path, status="failed")
assert len(deliveries) == 1 assert len(deliveries) == 1
assert deliveries[0].attempts == 3 assert deliveries[0].attempts == 3
assert deliveries[0].status == "failed" assert deliveries[0].status == "failed"