fix(auth,chat): P0 security fixes + stop-generation button + doc sync

U1: whoami cold-start security — add is_active check (disabled users
now get 401, not 200) and replace create_token_pair with create_access_token
to avoid minting a discarded refresh token (token-amplification risk).

U2: list_active_by_provider now filters expired sessions (expires_at > now)
matching its docstring promise; previously only checked revoked = 0.

U3: Fix asyncio.run() crash in test_revoke_other_user_session_returns_404
(converted to async). Add U1/U2 verification tests (disabled-user whoami,
no-refresh-leak, expired-session filtering, provider filtering) and
strengthen admin route tests (404 boundary, non-admin 403 on /admin/sessions).

U4: Update CLAUDE.md/AGENTS.md Request Flow — CostAwareRouter 3-layer
diagram replaced with actual RequestPreprocessor architecture (@board/@team
prefix intercepts then @skill: prefix then trivial-input regex then default
REACT). ExecutionMode list expanded to all 7 values.

U5: Frontend stop-generation button — ChatInput.vue shows a stop button
when isGenerating is true; chat store gains stopGeneration() that sends
{type:"cancel"} over WebSocket (backend portal.py already handles cancel).

Tests: 120 auth tests pass (unit + integration). ruff clean. vue-tsc clean.
This commit is contained in:
chiguyong 2026-06-21 11:36:58 +08:00
parent aee7362665
commit 67c0d67262
11 changed files with 553 additions and 19 deletions

View File

@ -59,14 +59,20 @@ docker-compose up -d # AgentKit + Redis + PostgreSQL
### Request Flow
```
User Input -> CostAwareRouter (3-layer)
Layer 0: RegexRules (~0ms, 0 tokens) -> DIRECT_CHAT
Layer 1: HeuristicClassifier (~0ms) / LLM quick_classify (~500ms, ~100 tokens)
Layer 1.5: SemanticRouter (vector similarity, optional)
Layer 2: Capability matching / Vickrey Auction
-> ExecutionMode: DIRECT_CHAT / REACT / SKILL_REACT / TEAM_COLLAB
User Input
├─ @board prefix -> BoardRouter (experts/board_router.py) -> BoardOrchestrator (multi-round discussion)
├─ @team prefix -> ExpertTeamRouter (experts/router.py) -> TeamOrchestrator (pipeline collaboration)
└─ otherwise -> RequestPreprocessor (chat/request_preprocessor.py)
Layer 0: @skill:xxx prefix -> explicit skill selection (SKILL_REACT or skill's configured mode)
Layer 1: Trivial-input regex (~0ms, 0 tokens) -> DIRECT_CHAT
(greetings, identity, factual Q&A, math, translation; guarded by _TOOL_CONTEXT_RE)
Default: -> REACT (LLM decides tool usage autonomously in the agent loop)
-> ExecutionMode: DIRECT_CHAT / REACT / SKILL_REACT / REWOO / REFLEXION / PLAN_EXEC / TEAM_COLLAB
(chat handler currently supports DIRECT_CHAT, REACT, SKILL_REACT; others raise "not yet supported")
```
**Note**: The old 3-layer `CostAwareRouter` (with `RegexRules` / `HeuristicClassifier` / `SemanticRouter` / `Vickrey Auction`) has been replaced by `RequestPreprocessor`. The `IntentRouter` (`router/intent.py`) exists but is not wired into the chat flow. `AuctionHouse` with Vickrey auction lives in `marketplace/auction.py` (marketplace subsystem, not routing).
### Agent Hierarchy
```

View File

@ -0,0 +1,240 @@
---
title: "fix: P0 安全修复与代码质量优化"
status: active
date: 2026-06-21
type: fix
origin: ocr 代码审查 + 全面项目质量评估报告
---
# P0 安全修复与代码质量优化计划
## Summary
基于 open-code-review 对 auth 功能分支的审查9 文件 15 评论)和全面项目质量评估,本计划修复 4 个 P0 级安全/Bug 问题、1 个文档不一致问题、1 个核心 UX 缺失,并补强集成测试覆盖。所有修复均有明确的代码证据和 ocr 审查评论支撑。
## Problem Frame
ocr 审查发现 whoami 冷启动路径存在 2 个安全漏洞(被禁用用户绕过 + 令牌放大风险1 个数据不一致 Buglist_active_by_provider 未过滤过期会话1 个测试崩溃 Bugasyncio.run 在事件循环内。此外CLAUDE.md/AGENTS.md 中路由系统描述与实际实现严重不匹配,前端缺少基础的"停止生成"功能。
## Requirements
- R1: whoami 冷启动必须检查 `is_active`,被禁用用户不能获取新 access token
- R2: whoami 冷启动必须实现 refresh token 轮换,防止令牌放大攻击
- R3: `list_active_by_provider` 必须过滤过期会话,与 docstring 承诺一致
- R4: 集成测试中 `asyncio.run()` 必须修复,跨用户撤销测试必须可运行
- R5: CLAUDE.md/AGENTS.md 路由系统描述必须与 RequestPreprocessor 实际实现匹配
- R6: 前端聊天界面必须支持"停止生成"功能,用户可中途取消 LLM 输出
- R7: 集成测试必须补齐 ocr 指出的覆盖缺口404 边界、密码修改端到端、会话撤销验证)
## Key Technical Decisions
### KTD-1: whoami 冷启动 refresh token 处理策略
**决策**: 采用方案 B — 仅签发短期 access token不创建新 refresh token。
**理由**: 客户端已有有效 refresh token冷启动只需 access token 即可恢复会话。调用 `svc.rotate()` 会增加复杂度且改变客户端持有的 refresh token而方案 B 更简单、安全(原 refresh token 仍受重用检测保护)。不调用 `create_token_pair`,改为单独签发 access token。
### KTD-2: 停止生成的实现方式
**决策**: 前端通过 WebSocket 发送 `cancel` 消息(协议已定义),后端通过 CancellationToken 取消正在执行的 Agent 任务。
**理由**: WebSocket 协议中已有 `cancel` 消息类型(见 AGENTS.md后端 BaseAgent 已实现 CancellationToken 协作式取消。只需前端添加按钮和发送逻辑,无需新增后端端点。
## Scope Boundaries
### In Scope
- whoami 冷启动安全修复is_active + token 策略)
- list_active_by_provider 过滤修复
- asyncio.run 测试修复 + 集成测试补强
- CLAUDE.md/AGENTS.md 路由文档更新
- 前端"停止生成"按钮
### Deferred to Follow-Up Work
- 多模态输入支持C1需独立计划
- 用户反馈机制 thumbs up/downC3需独立计划
- Prompt CachingC9需 LLM Gateway 改造)
- 多租户隔离逻辑C7需独立计划
- `: Any` 类型清理213 处,大规模重构)
- 吞异常清理109 处,需逐文件审查)
- LLM 模块测试覆盖0.06 比率,需独立计划)
---
## Implementation Units
### U1. 修复 whoami 冷启动安全漏洞
**Goal**: 修复 whoami 冷启动路径的 2 个安全漏洞:缺失 is_active 检查 + 令牌放大风险。
**Requirements**: R1, R2
**Dependencies**: 无
**Files**:
- Modify: `src/agentkit/server/routes/auth.py`whoami 路由)
- Modify: `src/agentkit/server/auth/jwt_utils.py`(如需新增单独签发 access token 的函数)
- Test: `tests/integration/auth/test_auth_routes.py`
**Approach**:
1. 在 whoami 路由中,将 `if row is None` 改为 `if row is None or not bool(row["is_active"])`,返回 401
2. 移除冷启动路径中的 `create_token_pair` 调用,改为仅签发 access token使用 jwt_utils 中现有的 access token 创建逻辑)
3. 不创建新 refresh token不调用 `svc.rotate()`,客户端保留原 refresh token
**Patterns to follow**: `/auth/refresh` 路由中的 `is_active` 检查模式auth.py 第 514 行)
**Test scenarios**:
- **Happy path**: 被禁用用户is_active=0用 refresh token 调用 whoami → 401
- **Happy path**: 活跃用户用 refresh token 调用 whoami → 200 + 新 access token + 无新 refresh token
- **Edge case**: 活跃用户用 access token 调用 whoami → 200 + access_token 为 None行为不变
- **Error path**: 不存在的用户 IDtoken 被篡改)→ 401
**Verification**: `pytest tests/integration/auth/test_auth_routes.py::TestWhoamiColdStart -v` 全部通过
---
### U2. 修复 list_active_by_provider 过滤过期会话
**Goal**: 使 SQL 查询与 docstring 承诺一致,过滤掉已过期的会话。
**Requirements**: R3
**Dependencies**: 无
**Files**:
- Modify: `src/agentkit/server/auth/session_service.py`list_active_by_provider 方法)
- Test: `tests/unit/auth/test_session_service.py`
**Approach**:
1. 在 SQL 查询中添加 `AND expires_at > ?` 条件
2. 传入当前 UTC 时间的 ISO 格式字符串作为参数
3. 需要在方法顶部导入 `datetime, timezone`(如未导入)
**Patterns to follow**: `list_all` 方法中的 SQL 构造模式
**Test scenarios**:
- **Happy path**: 有 1 个活跃会话 + 1 个过期会话 → 仅返回活跃会话
- **Happy path**: 所有会话均过期 → 返回空列表
- **Edge case**: 无会话 → 返回空列表
- **Edge case**: 会话 expires_at 恰好等于当前时间 → 不返回(边界 > 而非 >=
**Verification**: `pytest tests/unit/auth/test_session_service.py -k list_active_by_provider -v` 通过
---
### U3. 修复 asyncio.run 测试崩溃 + 补强集成测试
**Goal**: 修复 test_revoke_other_user_session_returns_404 中的 asyncio.run 崩溃,并补齐 ocr 指出的 6 个测试覆盖缺口。
**Requirements**: R4, R7
**Dependencies**: U1
**Files**:
- Modify: `tests/integration/auth/test_auth_routes.py`
- Modify: `tests/integration/auth/test_admin_routes.py`
**Approach**:
1. 将 `_login_sync_create_user` 改为 `async def`,调用方 `test_revoke_other_user_session_returns_404` 也改为 `async def`
2. 移除未使用的 `auth_db_with_admin` fixture
3. 补充测试:
- `test_non_admin_cannot_list_all_sessions`admin 路由 403 测试)
- `test_revoke_session_belonging_to_different_user_returns_404`admin 撤销不匹配用户会话)
- `test_admin_list_sessions_for_nonexistent_user_returns_404`
- `test_admin_revoke_nonexistent_session_returns_404`
- 密码修改测试:验证当前会话存活 + 用新密码登录
- 会话撤销测试:撤销后用原 token 调用 whoami → 401
- 全局会话列表:验证两个用户都会出现
**Patterns to follow**: 现有 `test_admin_routes.py` 中的测试模式
**Test scenarios**:
- **Happy path**: asyncio 修复后 test_revoke_other_user_session_returns_404 正常运行
- **Happy path**: 所有新增测试通过
- **Error path**: 非管理员访问 /admin/sessions → 403
- **Error path**: admin 撤销不存在会话 → 404
- **Integration**: 密码修改 → 旧密码登录失败 → 新密码登录成功
**Verification**: `pytest tests/integration/auth/ -v` 全部通过,无崩溃
---
### U4. 更新路由系统文档
**Goal**: 使 CLAUDE.md 和 AGENTS.md 中的路由系统描述与实际 RequestPreprocessor 实现匹配。
**Requirements**: R5
**Dependencies**: 无
**Files**:
- Modify: `CLAUDE.md`Request Flow 部分)
- Modify: `AGENTS.md`Request Flow 部分)
**Approach**:
1. 将 Request Flow 部分从 "CostAwareRouter (3-layer)" 改为 "RequestPreprocessor (2-layer)"
2. 更新层级描述:
- Layer 0: `@skill:xxx` 前缀 → 显式技能选择
- Layer 1: 正则快速路径(问候/闲聊/身份/知识/算术/翻译)→ DIRECT_CHAT
- Layer 2: 其他全部 → REACTLLM 在 agent loop 中自主决策)
3. 移除对 HeuristicClassifier、SemanticRouter、Capability matching、Vickrey Auction 的描述
4. 添加设计决策说明(引用 request_preprocessor.py docstring 中的 rationale
5. 更新 ExecutionMode 描述(移除 TEAM_COLLAB 从路由层触发,改为 @team 前缀触发)
**Patterns to follow**: `src/agentkit/chat/request_preprocessor.py` 的 docstring
**Test scenarios**:
- **Test expectation**: none — 纯文档变更,无行为变化
**Verification**: 文档中不再出现 "CostAwareRouter"、"HeuristicClassifier"、"SemanticRouter"作为当前架构描述grep 确认
---
### U5. 前端添加"停止生成"按钮
**Goal**: 用户可在 LLM 长输出时中途取消生成。
**Requirements**: R6
**Dependencies**: 无
**Files**:
- Modify: `src/agentkit/server/frontend/src/components/chat/ChatInput.vue`(添加停止按钮)
- Modify: `src/agentkit/server/frontend/src/stores/chat.ts`(添加 cancel 状态和发送逻辑)
- Test: `src/agentkit/server/frontend/src/components/chat/__tests__/ChatInput.test.ts`(如存在)
**Approach**:
1. 在 chat store 中添加 `isGenerating` 状态true 时显示停止按钮false 时显示发送按钮)
2. 添加 `cancelGeneration()` action通过 WebSocket 发送 `cancel` 消息
3. 在 ChatInput 组件中,根据 `isGenerating` 切换按钮:发送按钮 ↔ 停止按钮
4. 停止按钮点击时调用 `cancelGeneration()`
5. 收到 `final_answer``error` 事件时,自动将 `isGenerating` 置为 false
**Patterns to follow**: 现有 ChatInput.vue 中发送按钮的实现模式
**Test scenarios**:
- **Happy path**: 空闲状态显示发送按钮,生成中显示停止按钮
- **Happy path**: 点击停止按钮 → 发送 cancel 消息 → isGenerating 变为 false
- **Edge case**: 生成中输入框禁用或允许继续输入(取决于现有 UX 模式)
- **Integration**: 收到 final_answer 后 isGenerating 自动重置
**Verification**: `npm run typecheck` 通过,手动验证按钮切换和取消功能
---
## Risks & Dependencies
| 风险 | 影响 | 缓解措施 |
|------|------|----------|
| U1 修改 whoami 可能影响现有冷启动流程 | 中 | 集成测试覆盖所有 whoami 场景 |
| U5 修改 chat store 可能影响 WebSocket 状态机 | 中 | 仅添加状态,不修改现有流转 |
| U4 文档更新可能遗漏其他引用 CostAwareRouter 的位置 | 低 | grep 全仓库搜索 CostAwareRouter |
## System-Wide Impact
- **U1**: 影响 whoami 端点行为,前端 auth store 的冷启动流程需验证兼容性
- **U2**: 影响 admin 会话查询结果,数据更准确
- **U3**: 仅测试代码,无生产影响
- **U4**: 仅文档,无代码影响
- **U5**: 影响聊天 UI 交互,需验证 WebSocket cancel 消息处理

View File

@ -203,6 +203,59 @@ def create_token_pair(
)
def create_access_token(
user_id: str,
username: str,
role: str,
secret: str,
*,
session_id: str | None = None,
now: datetime | None = None,
) -> str:
"""Create a single signed access JWT (no refresh token).
Used by ``/auth/whoami`` cold-start to issue a fresh access token
without creating a new refresh token (the client already has one).
This avoids the token-amplification risk of ``create_token_pair``
which would silently discard the new refresh token.
Args:
user_id: Subject (user id) stored as ``sub``.
username: Username claim.
role: Role claim.
secret: HS256 signing secret.
session_id: Server-side session id. When provided, the token
carries ``sid`` and ``jti`` claims.
now: Override the issued-at time (for testing).
Returns:
The signed access token string.
"""
if not secret:
raise ValueError("JWT secret must not be empty")
issued_at = now or datetime.now(timezone.utc)
access_exp = issued_at + ACCESS_TOKEN_TTL
jti = str(uuid.uuid4()) if session_id else None
access_payload: dict[str, Any] = {
"sub": user_id,
"username": username,
"role": role,
"type": "access",
"iat": int(issued_at.timestamp()),
"exp": int(access_exp.timestamp()),
}
if session_id:
access_payload["sid"] = session_id
access_payload["jti"] = jti
access_token = jwt.encode(access_payload, secret, algorithm=JWT_ALGORITHM)
if isinstance(access_token, bytes):
access_token = access_token.decode("utf-8")
return access_token
def verify_token(
token: str,
secret: str,

View File

@ -209,14 +209,15 @@ class SessionService:
Supports the future "show me all OIDC sessions" admin view.
Only non-revoked, non-expired sessions are returned.
"""
now_iso = datetime.now(timezone.utc).isoformat()
sql = (
"SELECT * FROM auth_sessions "
"WHERE auth_provider = ? AND revoked = 0 "
"WHERE auth_provider = ? AND revoked = 0 AND expires_at > ? "
"ORDER BY created_at DESC"
)
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(sql, (auth_provider,))
cursor = await db.execute(sql, (auth_provider, now_iso))
rows = await cursor.fetchall()
return [_row_to_info(r) for r in rows]

View File

@ -40,6 +40,7 @@
class="chat-input__textarea"
/>
<a-button
v-if="!isGenerating"
type="primary"
:disabled="!canSend"
:loading="disabled"
@ -48,6 +49,14 @@
>
<template #icon><SendOutlined /></template>
</a-button>
<a-button
v-else
danger
@click="handleStop"
class="chat-input__stop"
>
<template #icon><PoweroffOutlined /></template>
</a-button>
</div>
<div class="chat-input__footer">
<div class="chat-input__footer-left">
@ -120,7 +129,7 @@
<script setup lang="ts">
import { ref, computed, onMounted, onUnmounted, type Component } from 'vue'
import { Input as AInput, Button as AButton, Select as ASelect } from 'ant-design-vue'
import { SendOutlined, TeamOutlined, UsergroupAddOutlined, PaperClipOutlined } from '@ant-design/icons-vue'
import { SendOutlined, TeamOutlined, UsergroupAddOutlined, PaperClipOutlined, PoweroffOutlined } from '@ant-design/icons-vue'
import ContextPill from './ContextPill.vue'
import MentionDropdown from './MentionDropdown.vue'
import BoardMeetingModal from './BoardMeetingModal.vue'
@ -155,16 +164,19 @@ interface ModelInfo {
interface IProps {
disabled?: boolean
isGenerating?: boolean
placeholder?: string
}
const props = withDefaults(defineProps<IProps>(), {
disabled: false,
isGenerating: false,
placeholder: '输入消息,按 Enter 发送...',
})
const emit = defineEmits<{
send: [message: string, model?: string]
stop: []
}>()
const inputText = ref('')
@ -290,6 +302,10 @@ function handleSend(): void {
setTimeout(() => { inputText.value = '' }, 0)
}
function handleStop(): void {
emit('stop')
}
function handleBoardSubmit(command: string): void {
// The BoardMeetingModal constructs an @board:expert1,expert2 topic command.
// Send it through the normal chat pipeline the backend intercepts @board prefix.
@ -480,6 +496,21 @@ function removePill(idx: number): void {
border-color: var(--text-primary) !important;
}
.chat-input__stop {
flex-shrink: 0;
width: 28px;
height: 28px;
border-radius: var(--radius-md) !important;
display: flex;
align-items: center;
justify-content: center;
padding: 0 !important;
}
.chat-input__stop:hover {
opacity: 0.85;
}
.chat-input__footer {
display: flex;
align-items: center;

View File

@ -343,6 +343,31 @@ export const useChatStore = defineStore('chat', () => {
}
}
/** Stop the in-flight generation by sending a `cancel` WS message.
*
* The backend (`/api/v1/portal/ws`) handles `{"type":"cancel"}` by
* cancelling the active asyncio task and replying with a `result`
* event whose `data.status === "cancelled"`. That `result` event
* clears `isLoading` via the normal handler so this function only
* needs to send the cancel message and guard against send failures.
*/
function stopGeneration(): void {
if (!ws.value || ws.value.readyState !== WebSocket.OPEN) {
// No open socket — just reset the loading flag locally.
isLoading.value = false
streamingSteps.value = []
return
}
try {
const cancelMsg: WsClientMessage = { type: 'cancel' }
ws.value.send(JSON.stringify(cancelMsg))
} catch (error) {
console.error('Failed to send cancel message:', error)
isLoading.value = false
streamingSteps.value = []
}
}
/** After WebSocket reconnects, check for running tasks and resume them */
async function _recoverTaskAfterReconnect(): Promise<void> {
if (!currentConversationId.value) return
@ -918,6 +943,7 @@ export const useChatStore = defineStore('chat', () => {
sendMessage,
sendWsMessage,
resendLastUserMessage,
stopGeneration,
connectWebSocket,
disconnectWebSocket,
}

View File

@ -57,7 +57,9 @@
<div class="chat-view__input-inner">
<ChatInput
:disabled="chatStore.isLoading"
:is-generating="chatStore.isLoading"
@send="handleSend"
@stop="chatStore.stopGeneration"
/>
</div>
</div>

View File

@ -36,6 +36,7 @@ from agentkit.server.auth.jwt_utils import (
ACCESS_TOKEN_TTL,
REFRESH_TOKEN_TTL,
REFRESH_TOKEN_TTL_REMEMBER_ME,
create_access_token,
create_token_pair,
get_or_create_jwt_secret,
verify_token,
@ -574,8 +575,8 @@ async def whoami(request: Request) -> WhoamiResponse:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM users WHERE id = ?", (user_id,))
row = await cursor.fetchone()
if row is None:
raise HTTPException(status_code=404, detail="User not found")
if row is None or not bool(row["is_active"]):
raise HTTPException(status_code=401, detail="User not found or disabled")
user_response = _user_row_to_response(row)
# V2 token with sid: validate the session is still active.
@ -591,16 +592,18 @@ async def whoami(request: Request) -> WhoamiResponse:
)
# Cold-start: refresh token presented → issue a fresh access token.
# Use ``create_access_token`` (not ``create_token_pair``) to avoid
# minting a new refresh token that would silently be discarded —
# the client already holds a valid refresh token.
new_access_token: str | None = None
if token_type == "refresh":
new_pair = create_token_pair(
new_access_token = create_access_token(
user_id=str(row["id"]),
username=str(row["username"]),
role=str(row["role"]),
secret=secret,
session_id=sid,
)
new_access_token = new_pair.access_token
return WhoamiResponse(
user=user_response,

View File

@ -171,6 +171,25 @@ class TestAdminListUserSessions:
)
assert resp.status_code in (403, 401)
def test_admin_list_unknown_user_returns_404(
self,
auth_client: TestClient,
users: dict[str, dict[str, Any]],
):
"""Admin lists sessions for a non-existent user → 404 (not 500)."""
admin_body = _login(
auth_client,
users["admin"]["username"],
users["admin"]["password"],
)
fake_id = str(uuid.uuid4())
resp = auth_client.get(
f"/api/v1/admin/users/{fake_id}/sessions",
headers={"Authorization": f"Bearer {admin_body['access_token']}"},
)
# The route 404s on unknown user — no 500 leak.
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# Admin: revoke user session
@ -273,3 +292,24 @@ class TestAdminListAllSessions:
sessions = resp.json()
# At least 2 sessions (member + admin).
assert len(sessions) >= 2
# Sessions span both users (stronger than just count).
user_ids = {s["user_id"] for s in sessions}
assert users["member"]["id"] in user_ids
assert users["admin"]["id"] in user_ids
def test_non_admin_cannot_list_all_sessions(
self,
auth_client: TestClient,
users: dict[str, dict[str, Any]],
):
"""A member must NOT be able to call GET /admin/sessions (403/401)."""
member_body = _login(
auth_client,
users["member"]["username"],
users["member"]["password"],
)
resp = auth_client.get(
"/api/v1/admin/sessions",
headers={"Authorization": f"Bearer {member_body['access_token']}"},
)
assert resp.status_code in (403, 401)

View File

@ -190,6 +190,58 @@ class TestWhoamiColdStart:
)
assert resp.status_code == 401
async def test_whoami_disabled_user_returns_401(
self,
auth_client: TestClient,
auth_db_with_user: dict[str, Any],
tmp_auth_db: Path,
):
"""A disabled user (is_active=0) must not pass whoami (U1 fix)."""
body = _login(
auth_client,
auth_db_with_user["username"],
auth_db_with_user["password"],
)
# Disable the user directly in the DB.
async with aiosqlite.connect(str(tmp_auth_db)) as db:
await db.execute(
"UPDATE users SET is_active = 0 WHERE id = ?",
(auth_db_with_user["id"],),
)
await db.commit()
# whoami with the still-valid access token must now 401.
resp = auth_client.get(
"/api/v1/auth/whoami",
headers={"Authorization": f"Bearer {body['access_token']}"},
)
assert resp.status_code == 401
assert "disabled" in resp.json()["detail"].lower()
async def test_whoami_refresh_token_does_not_leak_new_refresh_token(
self,
auth_client: TestClient,
auth_db_with_user: dict[str, Any],
):
"""Cold-start must NOT issue a new refresh token (U1 token-amplification fix).
The response only carries ``access_token``; the client keeps using
its existing refresh token.
"""
body = _login(
auth_client,
auth_db_with_user["username"],
auth_db_with_user["password"],
)
resp = auth_client.get(
"/api/v1/auth/whoami",
headers={"Authorization": f"Bearer {body['refresh_token']}"},
)
assert resp.status_code == 200, resp.text
data = resp.json()
assert data["access_token"] is not None
# WhoamiResponse has no refresh_token field — verify it's absent.
assert "refresh_token" not in data
class TestSessionsManagement:
"""GET /auth/sessions, DELETE /auth/sessions/{id}."""
@ -230,7 +282,7 @@ class TestSessionsManagement:
assert resp.status_code == 200, resp.text
assert resp.json()["revoked"] is True
def test_revoke_other_user_session_returns_404(
async def test_revoke_other_user_session_returns_404(
self,
auth_client: TestClient,
auth_db_with_user: dict[str, Any],
@ -238,7 +290,7 @@ class TestSessionsManagement:
):
"""A user cannot revoke another user's session (404, not 403, to avoid leaking)."""
# Create a second user and log in as them.
other = _login_sync_create_user(auth_client, tmp_auth_db, username="bob")
other = await _login_async_create_user(auth_client, tmp_auth_db, username="bob")
# Alice tries to revoke Bob's session.
body = _login(auth_client, auth_db_with_user["username"], auth_db_with_user["password"])
resp = auth_client.delete(
@ -248,7 +300,7 @@ class TestSessionsManagement:
assert resp.status_code == 404
def _login_sync_create_user(
async def _login_async_create_user(
client: TestClient,
db_path: Path,
*,
@ -257,12 +309,13 @@ def _login_sync_create_user(
"""Insert a user directly into the DB and log in via the API.
Returns the login response + the session_id extracted from the JWT.
Async so it can be ``await``-ed inside pytest-asyncio tests without
triggering the ``asyncio.run() cannot be called from a running event
loop`` error.
"""
import asyncio
from agentkit.server.auth.jwt_utils import verify_token
user = asyncio.run(_insert_user(db_path, username=username))
user = await _insert_user(db_path, username=username)
body = _login(client, username, user["password"])
payload = verify_token(body["access_token"], client.app.state.jwt_secret)
return {**body, "session_id": payload.get("sid"), "user_id": user["id"]}

View File

@ -254,3 +254,82 @@ async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: st
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
assert await svc.is_session_valid("nonexistent-id") is False
# ---------------------------------------------------------------------------
# list_active_by_provider (U2 fix: expired sessions must be filtered)
# ---------------------------------------------------------------------------
async def test_list_active_by_provider_excludes_expired(svc: SessionService, user_id: str):
"""Expired sessions must NOT appear in list_active_by_provider (U2 fix).
The docstring promises "non-revoked, non-expired" before the U2 fix
the SQL only checked ``revoked = 0`` and ignored ``expires_at``.
"""
# An expired session (TTL=0 → expires immediately).
expired_create = SessionCreate(
user_id=user_id,
refresh_token="rt-expired",
device_fingerprint="fp",
device_label="expired-device",
ip="",
user_agent="",
auth_provider="local",
ttl_seconds=0,
)
expired_info = await svc.create(expired_create)
# A live session.
live_info = await svc.create(_make_create(user_id, "rt-live"))
# list_active_by_provider must return ONLY the live one.
active = await svc.list_active_by_provider("local")
active_ids = {s.id for s in active}
assert live_info.id in active_ids
assert expired_info.id not in active_ids
async def test_list_active_by_provider_excludes_revoked(svc: SessionService, user_id: str):
"""Revoked sessions must also be excluded (regression guard)."""
a = await svc.create(_make_create(user_id, "rt-a"))
b = await svc.create(_make_create(user_id, "rt-b"))
await svc.revoke(a.id)
active = await svc.list_active_by_provider("local")
active_ids = {s.id for s in active}
assert b.id in active_ids
assert a.id not in active_ids
async def test_list_active_by_provider_filters_by_provider(
svc: SessionService, user_id: str
):
"""Only sessions matching the requested auth_provider are returned."""
# SessionCreate is a frozen dataclass — build each with its provider.
local_create = SessionCreate(
user_id=user_id,
refresh_token="rt-local",
device_fingerprint="fp",
device_label="d",
ip="",
user_agent="",
auth_provider="local",
ttl_seconds=3600,
)
await svc.create(local_create)
oidc_create = SessionCreate(
user_id=user_id,
refresh_token="rt-oidc",
device_fingerprint="fp",
device_label="d",
ip="",
user_agent="",
auth_provider="oidc",
ttl_seconds=3600,
)
await svc.create(oidc_create)
local_only = await svc.list_active_by_provider("local")
assert all(s.auth_provider == "local" for s in local_only)
assert len(local_only) == 1