fischer-agentkit/tests/unit/rag_platform/test_fulltext.py

156 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U4 测试 — jieba 分词 + tsvector 写入/查询。
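Test scenarios:
1. tokenize: Chinese text is segmented and the tokens are joined with spaces.
2. build_tsquery: builds a tsquery with AND semantics, filters out empty tokens, and escapes special characters.
3. write_search_vector: calls session.execute to run the UPDATE SQL.
4. write_search_vector_batch: writes a batch of items.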
"""
from __future__ import annotations
from unittest.mock import AsyncMock
from agentkit.rag_platform.fulltext import (
KB_CHUNKS_TABLE,
build_tsquery,
tokenize,
write_search_vector,
write_search_vector_batch,
)
class TestTokenize:
    """Tests for the jieba-backed ``tokenize`` helper."""

    def test_chinese_text_tokenized(self):
        """Chinese text is segmented and the tokens are joined with spaces."""
        joined = tokenize("我爱自然语言处理")
        # The sentence is expected to be cut into several tokens (the
        # original test notes jieba's precise mode), so the joined string
        # must contain at least one space separator.
        assert isinstance(joined, str)
        assert " " in joined
        # Either segmentation of the key phrase is acceptable.
        accepted_terms = ("自然语言", "自然语言处理")
        assert any(term in joined for term in accepted_terms)

    def test_english_text_preserved(self):
        """English words survive tokenization unchanged."""
        joined = tokenize("hello world")
        for word in ("hello", "world"):
            assert word in joined

    def test_mixed_text(self):
        """Mixed Chinese/English input produces a non-empty string."""
        joined = tokenize("RAG 检索增强生成")
        assert isinstance(joined, str)
        # Non-empty check; the isinstance assertion above guarantees a str.
        assert joined

    def test_empty_string(self):
        """An empty input yields an empty string."""
        tokenized = tokenize("")
        assert tokenized == ""
class TestBuildTsquery:
    """Tests for ``build_tsquery``: AND joining, empty-token filtering, escaping."""

    def test_chinese_query_and_semantics(self):
        """A Chinese query is joined with ``&`` (AND semantics)."""
        result = build_tsquery("自然语言处理")
        assert isinstance(result, str)
        assert len(result) > 0
        # No empty tokens: nothing may sit between two consecutive
        # separators, and the query may not start or end with a separator.
        assert " & & " not in result
        assert not result.startswith("& ")
        assert not result.endswith(" &")

    def test_filters_empty_tokens(self):
        """Empty tokens produced by repeated whitespace are filtered out."""
        # Multi-space input: the runs of spaces are what could yield empty
        # tokens, so the input must actually contain more than one space.
        result = build_tsquery("hello   world")
        # An empty token would show up as two adjacent separators.
        assert " & & " not in result

    def test_escapes_special_chars(self):
        """Special ``to_tsquery`` characters in the input do not survive."""
        # The characters of concern are & | ! ( ) : < > " ' \
        result = build_tsquery("test & injection | attempt")
        assert isinstance(result, str)
        # Every "&" in the output must be part of a " & " separator, i.e.
        # the user-supplied "&" was not passed through verbatim. This also
        # covers the case where the output contains no "&" at all (0 == 0).
        assert result.count("&") == result.count(" & ")
        # The injected OR operator must not leak into an AND-only query.
        assert "|" not in result

    def test_empty_query_returns_empty(self):
        """An empty query returns an empty string."""
        assert build_tsquery("") == ""

    def test_whitespace_only_returns_empty(self):
        """A whitespace-only query returns an empty string."""
        assert build_tsquery(" ") == ""
class TestWriteSearchVector:
    """Tests for ``write_search_vector`` using a mocked async session."""

    async def test_write_calls_execute_with_correct_sql(self):
        """``write_search_vector`` awaits ``session.execute`` once with the expected params."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector(session, "chunk-001", "测试内容")

        session.execute.assert_awaited_once()
        # The first positional argument is the SQL object and the second is
        # the parameter dict; fall back to kwargs when no second positional
        # argument was passed.
        awaited = session.execute.await_args
        positional = awaited.args
        if len(positional) > 1:
            bound = positional[1]
        else:
            bound = awaited.kwargs
        assert bound["chunk_id"] == "chunk-001"
        # After tokenization both words of the content are expected to be
        # present in the bound token string.
        token_text = bound["tokens"]
        assert "测试" in token_text
        assert "内容" in token_text

    async def test_write_uses_correct_table_name(self):
        """The emitted SQL references the chunk table and the tsvector column."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector(session, "c1", "content")

        statement = session.execute.await_args.args[0]
        sql_text = str(statement)
        # The SQL text must mention the table name, the to_tsvector call and
        # the search_vector column, checked in that order.
        for fragment in (KB_CHUNKS_TABLE, "to_tsvector", "search_vector"):
            assert fragment in sql_text
class TestWriteSearchVectorBatch:
    """Tests for ``write_search_vector_batch`` using a mocked async session."""

    async def test_batch_writes_all_items(self):
        """Every item in the batch results in one awaited ``execute`` call."""
        session = AsyncMock()
        session.execute = AsyncMock()
        batch = [
            ("chunk-1", "内容一"),
            ("chunk-2", "内容二"),
            ("chunk-3", "内容三"),
        ]
        # Captured before the call so the expectation is fixed at three.
        expected_awaits = len(batch)

        await write_search_vector_batch(session, batch)

        assert session.execute.await_count == expected_awaits

    async def test_batch_empty_items_no_calls(self):
        """An empty batch never awaits ``execute``."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector_batch(session, [])

        assert session.execute.await_count == 0