"""U4 tests — jieba tokenization plus tsvector writes and queries.

Scenarios covered:

1. tokenize: Chinese text is segmented and the tokens are joined with spaces.
2. build_tsquery: builds a tsquery with AND semantics, filters out empty
   tokens, and escapes special characters.
3. write_search_vector: calls ``session.execute`` to run the UPDATE SQL.
4. write_search_vector_batch: writes a batch of chunks.
"""

from __future__ import annotations

from unittest.mock import AsyncMock

from agentkit.rag_platform.fulltext import (
    KB_CHUNKS_TABLE,
    build_tsquery,
    tokenize,
    write_search_vector,
    write_search_vector_batch,
)


class TestTokenize:
    """Tests for jieba-based tokenization."""

    def test_chinese_text_tokenized(self):
        """Chinese text is segmented and the tokens are joined by spaces."""
        segmented = tokenize("我爱自然语言处理")
        assert isinstance(segmented, str)
        # Precise-mode segmentation yields several tokens, so a separator appears.
        assert " " in segmented
        # "自然语言" is a substring of "自然语言处理", so this single check
        # accepts either segmentation of the key phrase.
        assert "自然语言" in segmented

    def test_english_text_preserved(self):
        """English words survive tokenization (split on whitespace)."""
        segmented = tokenize("hello world")
        for word in ("hello", "world"):
            assert word in segmented

    def test_mixed_text(self):
        """Mixed Chinese/English input produces a non-empty string."""
        segmented = tokenize("RAG 检索增强生成")
        assert isinstance(segmented, str)
        assert segmented  # non-empty

    def test_empty_string(self):
        """An empty input produces an empty output."""
        assert tokenize("") == ""


class TestBuildTsquery:
    """Tests for tsquery construction."""

    @staticmethod
    def _terms(tsquery: str) -> list[str]:
        """Split *tsquery* on ``&`` and strip whitespace from each term.

        An empty entry in the result means the query had a leading, trailing
        or doubled ``&`` — i.e. an empty token slipped through. Unlike a
        literal ``" & & "`` search, this also catches ``"a &  & b"``.
        """
        return [term.strip() for term in tsquery.split("&")]

    def test_chinese_query_and_semantics(self):
        """A Chinese query's tokens are joined with ``&`` (AND semantics)."""
        query = "自然语言处理"
        result = build_tsquery(query)
        assert isinstance(result, str)
        assert result
        # No empty terms anywhere (start, end, or between operators).
        assert all(self._terms(result))
        # The AND connector is only expected when segmentation yields
        # more than one token.
        if len(tokenize(query).split()) > 1:
            assert " & " in result

    def test_filters_empty_tokens(self):
        """Whitespace-only tokens from runs of spaces are dropped."""
        # A run of spaces can yield whitespace-only tokens; they must not
        # become empty terms between ``&`` operators.
        result = build_tsquery("hello   world")
        terms = self._terms(result)
        assert all(terms)
        assert len(terms) == 2

    def test_escapes_special_chars(self):
        """to_tsquery operator characters in user input are neutralised."""
        # Covers & | ! ( ) — they should be replaced with spaces and filtered.
        result = build_tsquery("test & injection | attempt ! (group)")
        # No foreign operator may survive into the query.
        for char in "|!()":
            assert char not in result
        # Every remaining "&" must be a connector between non-empty terms;
        # an unescaped input "&" would show up as an empty term here.
        assert all(self._terms(result))
        for word in ("test", "injection", "attempt", "group"):
            assert word in result

    def test_empty_query_returns_empty(self):
        """An empty query returns an empty string."""
        assert build_tsquery("") == ""

    def test_whitespace_only_returns_empty(self):
        """A whitespace-only query returns an empty string."""
        for blank in (" ", "   "):
            assert build_tsquery(blank) == ""


class TestWriteSearchVector:
    """Tests for writing ``search_vector``."""

    @staticmethod
    def _make_session() -> AsyncMock:
        """Return a mock async session whose ``execute`` can be awaited."""
        session = AsyncMock()
        session.execute = AsyncMock()
        return session

    @staticmethod
    def _split_execute_call(call) -> tuple[str, dict]:
        """Return ``(sql_text, bind_params)`` from a recorded ``execute`` await.

        Accepts both ``execute(stmt, params)`` and ``execute(stmt, params=...)``.
        For any other keyword style, it falls back to the raw keyword dict.
        """
        statement = call.args[0] if call.args else call.kwargs["statement"]
        if len(call.args) > 1:
            params = call.args[1]
        else:
            params = call.kwargs.get("params", call.kwargs)
        return str(statement), params

    async def test_write_calls_execute_with_correct_sql(self):
        """write_search_vector awaits session.execute once with an UPDATE."""
        session = self._make_session()

        await write_search_vector(session, "chunk-001", "测试内容")

        session.execute.assert_awaited_once()
        sql, params = self._split_execute_call(session.execute.await_args)
        assert "UPDATE" in sql.upper()
        assert params["chunk_id"] == "chunk-001"
        # After tokenize, jieba is expected to split "测试内容" into "测试 内容".
        assert "测试" in params["tokens"]
        assert "内容" in params["tokens"]

    async def test_write_uses_correct_table_name(self):
        """write_search_vector targets the configured table and column."""
        session = self._make_session()

        await write_search_vector(session, "c1", "content")

        sql, _ = self._split_execute_call(session.execute.await_args)
        assert KB_CHUNKS_TABLE in sql
        # Expected form: to_tsvector('simple', ...) written into search_vector.
        assert "to_tsvector" in sql
        assert "simple" in sql
        assert "search_vector" in sql


class TestWriteSearchVectorBatch:
    """Tests for batch writing."""

    @staticmethod
    def _make_session() -> AsyncMock:
        """Return a mock async session whose ``execute`` can be awaited."""
        session = AsyncMock()
        session.execute = AsyncMock()
        return session

    @staticmethod
    def _bound_params(call) -> dict:
        """Return the bind-parameter dict from a recorded ``execute`` await.

        Accepts both ``execute(stmt, params)`` and ``execute(stmt, params=...)``.
        """
        if len(call.args) > 1:
            return call.args[1]
        return call.kwargs.get("params", call.kwargs)

    async def test_batch_writes_all_items(self):
        """A batch issues one write per item, covering every chunk in order."""
        session = self._make_session()
        items = [
            ("chunk-1", "内容一"),
            ("chunk-2", "内容二"),
            ("chunk-3", "内容三"),
        ]

        await write_search_vector_batch(session, items)

        # One execute per item...
        assert session.execute.await_count == len(items)
        # ...and each item's chunk_id actually reaches the database, in order.
        written_ids = [
            self._bound_params(call)["chunk_id"]
            for call in session.execute.await_args_list
        ]
        assert written_ids == [chunk_id for chunk_id, _ in items]

    async def test_batch_empty_items_no_calls(self):
        """An empty batch never touches session.execute."""
        session = self._make_session()

        await write_search_vector_batch(session, [])

        session.execute.assert_not_awaited()