geo/tests/test_citation_engine.py

127 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
from app.workers.citation_engine import BrandMatcher, CompetitorDetector
def test_brand_matcher_exact():
    """Text containing the target brand verbatim is reported as an exact match."""
    outcome = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"]).match(
        "华为是一家伟大的科技公司"
    )

    assert outcome["cited"] is True
    assert outcome["match_type"] == "exact"
    assert outcome["confidence"] == 1.0
def test_brand_matcher_alias():
    """Text mentioning only a configured alias is reported as an alias match."""
    alias_text = "Huawei makes great phones"
    brand_matcher = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"])

    outcome = brand_matcher.match(alias_text)

    assert outcome["cited"] is True
    assert outcome["match_type"] == "alias"
    assert outcome["confidence"] == 0.9
def test_brand_matcher_fuzzy():
    """A near-miss spelling ("华伟" for "华为") is expected to match fuzzily."""
    near_miss_text = "华伟 是一家科技公司"

    outcome = BrandMatcher(target_brand="华为").match(near_miss_text)

    assert outcome["cited"] is True
    assert outcome["match_type"] == "fuzzy"
    # The fuzzy confidence only has to clear the 0.4 mark checked here.
    assert outcome["confidence"] > 0.4
def test_brand_matcher_no_match():
    """Unrelated text yields no citation, no match type and zero confidence."""
    unrelated_text = "这是一段完全不相关的文本,没有任何品牌信息"

    outcome = BrandMatcher(target_brand="华为").match(unrelated_text)

    assert outcome["cited"] is False
    assert outcome["match_type"] is None
    assert outcome["confidence"] == 0.0
def test_competitor_detector():
    """The other brand in the text is detected; the target brand is excluded."""
    sentence = "中国平安和中国人寿都是大型保险公司"

    found = CompetitorDetector().detect(sentence, target_brand="中国平安")

    assert "中国人寿" in found
    assert "中国平安" not in found
def test_citation_position():
    """A brand mentioned on the second line is reported at position 2."""
    lines = ["第一段介绍市场情况", "第二段提到华为的产品", "第三段是总结"]
    document = "\n".join(lines)

    outcome = BrandMatcher(target_brand="华为").match(document)

    assert outcome["cited"] is True
    assert outcome["position"] == 2
    assert outcome["citation_text"] is not None
# ---------------------------------------------------------------------------
# 补充测试
# ---------------------------------------------------------------------------
def test_brand_matcher_multiple_aliases():
    """With several aliases configured, any one of them should match."""
    alias_list = ["Huawei", "HW", "Honor"]
    brand_matcher = BrandMatcher(target_brand="华为", brand_aliases=alias_list)

    # Second alias in the list.
    hw_outcome = brand_matcher.match("HW released a new chip")
    assert hw_outcome["cited"] is True
    assert hw_outcome["match_type"] == "alias"
    assert hw_outcome["confidence"] == 0.9

    # Third alias in the list.
    honor_outcome = brand_matcher.match("Honor phones are popular")
    assert honor_outcome["cited"] is True
    assert honor_outcome["match_type"] == "alias"
def test_brand_matcher_fuzzy_threshold_boundary():
    """Fuzzy-match threshold boundary: inputs on either side of the 0.4 mark."""
    brand_matcher = BrandMatcher(target_brand="华为")

    # "华伟" vs "华为": ratio expected to be roughly 0.5, i.e. above 0.4.
    close_outcome = brand_matcher.match("华伟 是一家科技公司")
    assert close_outcome["cited"] is True
    assert close_outcome["match_type"] == "fuzzy"
    assert close_outcome["confidence"] > 0.4

    # "苹果" vs "华为": ratio expected to be very low, so no match at all.
    distant_outcome = brand_matcher.match("苹果科技公司")
    assert distant_outcome["cited"] is False
    assert distant_outcome["match_type"] is None
def test_brand_matcher_empty_text():
    """An empty string must not crash and must come back as cited=False."""
    outcome = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"]).match("")

    assert outcome["cited"] is False
    assert outcome["confidence"] == 0.0
    assert outcome["match_type"] is None
    assert outcome["position"] is None
def test_competitor_detector_multi_industry():
    """Competitor detection should recognise brands across industries."""
    competitor_detector = CompetitorDetector()
    mixed_text = "华为和小米是科技公司,工商银行和招商银行是银行"

    # Technology target: the other tech brand in the text is a competitor.
    tech_hits = competitor_detector.detect(mixed_text, target_brand="华为")
    assert "小米" in tech_hits
    assert "腾讯" not in tech_hits  # not present in the text
    assert "华为" not in tech_hits

    # Banking target: the other bank in the text is a competitor.
    finance_hits = competitor_detector.detect(mixed_text, target_brand="工商银行")
    assert "招商银行" in finance_hits
    assert "建设银行" not in finance_hits  # not present in the text
def test_citation_position_multiple_paragraphs():
    """Check the reported position when the brand sits on different lines."""
    brand_matcher = BrandMatcher(target_brand="华为")

    # (input text, expected 1-based line position of the brand mention)
    scenarios = (
        ("华为位于第一段\n第二段没有\n第三段也没有", 1),
        ("第一段没有\n第二段没有\n华为在第三段", 3),
        ("第一段\n第二段\n第三段\n最后一段提到华为", 4),
    )

    for document, expected_position in scenarios:
        outcome = brand_matcher.match(document)
        assert outcome["position"] == expected_position