Replace plain recursive chunker with legal-aware chunking that: - Detects legal section headers (§, Art., Section, Chapter, Annex) - Adds section context prefix to every chunk - Splits on paragraph boundaries then sentence boundaries - Protects DE + EN abbreviations (80+ patterns) from false splits - Supports language detection for locale-specific processing - Force-splits overlong sentences at word boundaries The old plain_recursive API option is removed — all non-semantic strategies now route through chunk_text_legal(). Includes 40 tests covering header detection, abbreviation protection, sentence splitting, and legal chunking behavior. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
289 lines · 11 KiB · Python
"""
|
|
Tests for the legal-aware chunking pipeline.
|
|
|
|
Covers:
|
|
- Legal section header detection (§, Art., Section, Chapter, Annex)
|
|
- Section context prefix in every chunk
|
|
- Paragraph boundary splitting
|
|
- Sentence splitting with DE and EN abbreviation protection
|
|
- Overlap between chunks
|
|
- Fallback for non-legal text
|
|
- Long sentence force-splitting
|
|
"""
|
|
|
|
import pytest
|
|
from main import (
|
|
chunk_text_legal,
|
|
chunk_text_recursive,
|
|
chunk_text_semantic,
|
|
_extract_section_header,
|
|
_split_sentences,
|
|
_detect_language,
|
|
_protect_abbreviations,
|
|
_restore_abbreviations,
|
|
)
|
|
|
|
|
|
# =========================================================================
|
|
# Section header detection
|
|
# =========================================================================
|
|
|
|
class TestSectionHeaderDetection:
    """_extract_section_header recognizes DE and EN legal section markers."""

    # --- German-style headers ---

    def test_german_paragraph(self):
        result = _extract_section_header("§ 25 Informationspflichten")
        assert result is not None

    def test_german_paragraph_with_letter(self):
        result = _extract_section_header("§ 5a Elektronischer Geschaeftsverkehr")
        assert result is not None

    def test_german_artikel(self):
        result = _extract_section_header("Artikel 5 Grundsaetze")
        assert result is not None

    def test_german_abschnitt(self):
        result = _extract_section_header("Abschnitt 3 Pflichten")
        assert result is not None

    def test_german_kapitel(self):
        result = _extract_section_header("Kapitel 2 Anwendungsbereich")
        assert result is not None

    def test_german_anhang(self):
        result = _extract_section_header("Anhang III Hochrisiko-KI")
        assert result is not None

    # --- English-style headers ---

    def test_english_article(self):
        result = _extract_section_header("Article 12 Transparency")
        assert result is not None

    def test_article_abbreviated(self):
        result = _extract_section_header("Art. 3 Definitions")
        assert result is not None

    def test_english_section(self):
        result = _extract_section_header("Section 4.2 Risk Assessment")
        assert result is not None

    def test_chapter(self):
        result = _extract_section_header("Chapter 5 Obligations")
        assert result is not None

    def test_annex_roman(self):
        result = _extract_section_header("Annex XII Technical Documentation")
        assert result is not None

    def test_part(self):
        result = _extract_section_header("Part III Requirements")
        assert result is not None

    def test_markdown_heading(self):
        result = _extract_section_header("## 3.1 Overview")
        assert result is not None

    # --- Plain text must be rejected ---

    def test_normal_text_not_header(self):
        result = _extract_section_header("This is a normal sentence.")
        assert result is None

    def test_short_caps_not_header(self):
        result = _extract_section_header("OK")
        assert result is None
|
|
|
|
|
|
# =========================================================================
|
|
# Language detection
|
|
# =========================================================================
|
|
|
|
class TestLanguageDetection:
    """_detect_language distinguishes German from English legal prose."""

    def test_german_text(self):
        sample = (
            "Die Verordnung ist für alle Mitgliedstaaten verbindlich und "
            "gilt nach dem Grundsatz der unmittelbaren Anwendbarkeit."
        )
        assert _detect_language(sample) == 'de'

    def test_english_text(self):
        sample = (
            "This regulation shall be binding in its entirety and "
            "directly applicable in all Member States."
        )
        assert _detect_language(sample) == 'en'
|
|
|
|
|
|
# =========================================================================
|
|
# Abbreviation protection
|
|
# =========================================================================
|
|
|
|
class TestAbbreviationProtection:
    """Round-trip behavior of _protect_abbreviations / _restore_abbreviations."""

    def test_german_abbreviations(self):
        text = "gem. § 5 Abs. 1 bzw. § 6 Abs. 2 z.B. die Pflicht"
        masked = _protect_abbreviations(text)
        # Once every placeholder token is removed, no literal dot may remain.
        stripped = masked
        for token in ("<DOT>", "<DECIMAL>", "<ORD>", "<ABBR>"):
            stripped = stripped.replace(token, "")
        assert "." not in stripped
        restored = _restore_abbreviations(masked)
        assert "gem." in restored
        # Accept either the original casing or a lower-cased restoration.
        assert "z.B." in restored.replace("z.b.", "z.B.") or "z.b." in restored

    def test_english_abbreviations(self):
        text = "e.g. section 4.2, i.e. the requirements in vol. 1 ref. NIST SP 800-30."
        masked = _protect_abbreviations(text)
        # "e.g" and "i.e" should be protected
        restored = _restore_abbreviations(masked)
        assert "e.g." in restored

    def test_decimals_protected(self):
        text = "Version 3.14 of the specification requires 2.5 GB."
        masked = _protect_abbreviations(text)
        assert "<DECIMAL>" in masked
        restored = _restore_abbreviations(masked)
        assert "3.14" in restored
|
|
|
|
|
|
# =========================================================================
|
|
# Sentence splitting
|
|
# =========================================================================
|
|
|
|
class TestSentenceSplitting:
    """_split_sentences finds boundaries but leaves abbreviations intact."""

    def test_simple_german(self):
        parts = _split_sentences("Erster Satz. Zweiter Satz. Dritter Satz.")
        assert len(parts) >= 2

    def test_simple_english(self):
        parts = _split_sentences("First sentence. Second sentence. Third sentence.")
        assert len(parts) >= 2

    def test_german_abbreviation_not_split(self):
        # "Gem.", "Art." and "Abs." must not terminate the first sentence.
        text = "Gem. Art. 5 Abs. 1 DSGVO ist die Verarbeitung rechtmaessig. Der Verantwortliche muss dies nachweisen."
        parts = _split_sentences(text)
        intact = [p for p in parts if "Gem" in p and "DSGVO" in p]
        assert intact

    def test_english_abbreviation_not_split(self):
        # "e.g." and "vol." must not terminate their sentences.
        text = "See e.g. Section 4.2 for details. The standard also references vol. 1 of the NIST SP series."
        parts = _split_sentences(text)
        intact = [p for p in parts if "e.g" in p and "Section" in p]
        assert intact

    def test_exclamation_and_question(self):
        parts = _split_sentences("Is this valid? Yes it is! Continue processing.")
        assert len(parts) >= 2
|
|
|
|
|
|
# =========================================================================
|
|
# Legal chunking
|
|
# =========================================================================
|
|
|
|
class TestChunkTextLegal:
    """End-to-end behavior of chunk_text_legal().

    Fixture fix: three tests previously relied on Python's implicit
    adjacent string-literal concatenation, which binds tighter than ``*``.
    That repeated the section header inside the body text (e.g.
    ``"header\\n\\nbody " * 20`` repeated the *header* 20 times), so the
    "header appears in every chunk" assertions passed trivially without
    exercising the chunker's prefix logic.  The fixtures below use an
    explicit ``+`` so each header occurs exactly once.
    """

    def test_small_text_single_chunk(self):
        # Text smaller than chunk_size must come back verbatim as one chunk.
        text = "Short text."
        chunks = chunk_text_legal(text, chunk_size=1024, overlap=128)
        assert len(chunks) == 1
        assert chunks[0] == "Short text."

    def test_section_header_as_prefix(self):
        # Header occurs exactly once in the input, so only the chunker's
        # context-prefixing can propagate it into later chunks.
        text = (
            "§ 25 Informationspflichten\n\n"
            + "Der Betreiber muss den Nutzer informieren. " * 20
        )
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        assert len(chunks) > 1
        # Every chunk should have the section prefix (bracketed or plain).
        for chunk in chunks:
            assert "[§ 25" in chunk or "§ 25" in chunk

    def test_article_prefix_english(self):
        text = "Article 12 Transparency\n\n" + "The provider shall ensure transparency of AI systems. " * 30
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "Article 12" in chunk

    def test_multiple_sections(self):
        # Three short sections; chunks should be attributable to >= 2 of them.
        text = (
            "§ 1 Anwendungsbereich\n\nDieses Gesetz gilt fuer alle Betreiber.\n\n"
            "§ 2 Begriffsbestimmungen\n\nIm Sinne dieses Gesetzes ist Betreiber, wer eine Anlage betreibt.\n\n"
            "§ 3 Pflichten\n\nDer Betreiber hat die Pflicht, die Anlage sicher zu betreiben."
        )
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        section_headers = set()
        for chunk in chunks:
            for marker in ("§ 1", "§ 2", "§ 3"):
                if f"[{marker}" in chunk:
                    section_headers.add(marker)
        assert len(section_headers) >= 2

    def test_paragraph_boundaries_respected(self):
        para1 = "First paragraph with enough text to matter. " * 5
        para2 = "Second paragraph also with content. " * 5
        text = para1.strip() + "\n\n" + para2.strip()
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        # Paragraphs should not be merged mid-sentence across chunk boundary.
        assert len(chunks) >= 2

    def test_overlap_present(self):
        text = "Sentence one about topic A. " * 10 + "\n\n" + "Sentence two about topic B. " * 10
        chunks = chunk_text_legal(text, chunk_size=200, overlap=50)
        if len(chunks) > 1:
            # Second chunk should repeat some words from the end of the first.
            end_of_first = chunks[0][-30:]
            overlap_words = set(end_of_first.split())
            second_start_words = set(chunks[1][:80].split())
            assert len(overlap_words & second_start_words) > 0

    def test_nist_style_sections(self):
        # Explicit '+' keeps each "Section x.y" header out of the repeated body.
        text = (
            "Section 2.1 Risk Framing\n\n"
            + ("Risk framing establishes the context for risk-based decisions. "
               "Organizations must define their risk tolerance. ") * 10
            + "\n\n"
            + "Section 2.2 Risk Assessment\n\n"
            + "Risk assessment identifies threats and vulnerabilities. " * 10
        )
        chunks = chunk_text_legal(text, chunk_size=400, overlap=0)
        assert any("Section 2.1" in c for c in chunks)
        assert any("Section 2.2" in c for c in chunks)

    def test_markdown_heading_as_context(self):
        # Heading occurs once; the chunker must carry it into every chunk.
        text = "## 3.1 Overview\n\n" + "This section provides an overview of the specification. " * 15
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "3.1 Overview" in chunk

    def test_empty_text(self):
        assert chunk_text_legal("", 1024, 128) == []

    def test_whitespace_only(self):
        assert chunk_text_legal(" \n\n ", 1024, 128) == []

    def test_long_sentence_force_split(self):
        # A single 2000-char "sentence" with no whitespace must be
        # force-split at roughly chunk_size boundaries.
        long_sentence = "A" * 2000
        chunks = chunk_text_legal(long_sentence, chunk_size=500, overlap=0)
        assert len(chunks) >= 4
        for chunk in chunks:
            assert len(chunk) <= 500 + 20  # small margin for prefix
|
|
|
|
|
|
# =========================================================================
|
|
# Legacy recursive chunking still works
|
|
# =========================================================================
|
|
|
|
class TestChunkTextRecursive:
    """The legacy recursive chunker keeps working after the refactor."""

    def test_basic_split(self):
        text = "Hello world. " * 200
        pieces = chunk_text_recursive(text, chunk_size=500, overlap=50)
        assert len(pieces) > 1
        # 600 = chunk_size plus some margin for overlap carry-over.
        assert all(len(piece) <= 600 for piece in pieces)

    def test_small_text(self):
        pieces = chunk_text_recursive("Short.", chunk_size=1024, overlap=128)
        assert pieces == ["Short."]
|
|
|
|
|
|
# =========================================================================
|
|
# Semantic chunking still works
|
|
# =========================================================================
|
|
|
|
class TestChunkTextSemantic:
    """Semantic chunking continues to behave as before the refactor."""

    def test_basic_split(self):
        text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence."
        pieces = chunk_text_semantic(text, chunk_size=50, overlap_sentences=1)
        assert len(pieces) >= 2

    def test_small_text(self):
        pieces = chunk_text_semantic("Short.", chunk_size=1024, overlap_sentences=1)
        assert pieces == ["Short."]
|