feat(embedding): implement legal-aware chunking pipeline

Replace plain recursive chunker with legal-aware chunking that:
- Detects legal section headers (§, Art., Section, Chapter, Annex)
- Adds section context prefix to every chunk
- Splits on paragraph boundaries then sentence boundaries
- Protects DE + EN abbreviations (80+ patterns) from false splits
- Adds a DE/EN language-detection helper (not yet used by the chunker itself)
- Force-splits overlong sentences at word boundaries

The old plain_recursive API option is removed — all non-semantic
strategies now route through chunk_text_legal().

Includes 40 tests covering header detection, abbreviation protection,
sentence splitting, and legal chunking behavior.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-22 09:18:23 +01:00
parent c1a8b9d936
commit 322e2d9cb3
2 changed files with 545 additions and 22 deletions

View File

@@ -251,14 +251,251 @@ async def rerank_cohere(query: str, documents: List[str], top_k: int = 5) -> Lis
# German abbreviations (general + legal citation forms) whose trailing dot
# does not end a sentence. Entries are lowercase and stored WITHOUT the final
# dot; inner dots are kept (e.g. 'z.b' stands for "z.B.").
GERMAN_ABBREVIATIONS = {
    'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max',
    'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw',
    # NOTE: every row must end with a comma. Adjacent string literals are
    # silently concatenated by Python ('zzgl' 'v.a' -> 'zzglv.a').
    'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'abschn',
    'anh', 'anl', 'aufl', 'bd', 'bes', 'bzgl', 'dgl', 'einschl', 'entspr',
    'erg', 'erl', 'gem', 'grds', 'hrsg', 'insb', 'ivm', 'kap', 'lit',
    'nachf', 'rdnr', 'rn', 'rz', 'ua', 'uvm', 'vorst', 'ziff',
}
# English abbreviations that don't end sentences (same storage convention)
ENGLISH_ABBREVIATIONS = {
    'e.g', 'i.e', 'etc', 'vs', 'al', 'approx', 'avg', 'dept', 'dr', 'ed',
    'est', 'fig', 'govt', 'inc', 'jr', 'ltd', 'max', 'min', 'mr', 'mrs',
    'ms', 'no', 'prof', 'pt', 'ref', 'rev', 'sec', 'sgt', 'sr', 'st',
    'vol', 'cf', 'ch', 'cl', 'col', 'corp', 'cpl', 'def', 'dist', 'div',
    'gen', 'hon', 'illus', 'intl', 'natl', 'org', 'para', 'pp', 'repr',
    'resp', 'supp', 'tech', 'temp', 'treas', 'univ',
}
# Combined abbreviations for both languages
ALL_ABBREVIATIONS = GERMAN_ABBREVIATIONS | ENGLISH_ABBREVIATIONS
# Regex pattern for legal section headers (§, Art., Article, Section, etc.)
import re
# Numeral that may follow a structural keyword (Annex, Part, Abschnitt, ...):
#   * a pure Roman numeral that ends at a word boundary ("III", "xii"),
#   * an upper-case Roman numeral plus one lower-case suffix letter ("IIa"),
#   * or an Arabic number ("3", also the "1" of "1a").
# The word boundary is essential: the pattern is compiled with IGNORECASE, so
# without it ordinary words starting with i/v/x/l/c would count as numerals
# and lines such as "Teil von ..." or "Part in ..." became section headers.
_SECTION_NUMERAL = r'(?:[IVXLC]+\b|(?-i:[IVXLC]+[a-z]\b)|\d+)'

# Matches the START of a line that opens a legal section. Only the leading
# keyword + number is matched; the rest of the line (title) is not consumed.
_LEGAL_SECTION_RE = re.compile(
    r'^(?:' + '|'.join((
        r'§\s*\d+',                              # § 25, § 5a
        r'Art(?:ikel|icle|\.)\s*\d+',            # Artikel 5, Article 12, Art. 3
        r'Section\s+\d+',                        # Section 4.2
        r'Abschnitt\s+' + _SECTION_NUMERAL,      # Abschnitt 3, Abschnitt III
        r'Kapitel\s+\d+',                        # Kapitel 2
        r'Chapter\s+\d+',                        # Chapter 3
        r'Anhang\s+' + _SECTION_NUMERAL,         # Anhang III
        r'Annex\s+' + _SECTION_NUMERAL,          # Annex XII
        r'TEIL\s+' + _SECTION_NUMERAL,           # TEIL II
        r'Part\s+' + _SECTION_NUMERAL,           # Part III
        r'Recital\s+\d+',                        # Recital 42
        r'Erwaegungsgrund\s+\d+',                # Erwaegungsgrund 26
    )) + r')',
    re.IGNORECASE | re.MULTILINE
)
# Regex for any heading-like line (Markdown ## or ALL-CAPS line)
# Generic (non-legal) heading shapes, anchored at the start of a line:
#   1. Markdown headings: one to six '#' followed by whitespace and text.
#   2. ALL-CAPS lines: an upper-case letter followed by at least five more
#      upper-case letters, whitespace or hyphens, up to the end of the line.
_HEADING_RE = re.compile(
    '^(?:' + '|'.join((
        r'#{1,6}\s+.+',
        r'[A-ZÄÖÜ][A-ZÄÖÜ\s\-]{5,}$',
    )) + ')',
    flags=re.MULTILINE,
)
def _detect_language(text: str) -> str:
    """Guess whether *text* is German ('de') or English ('en').

    Only the first 5000 characters are inspected, lower-cased. For each
    language a fixed list of marker words is checked; a marker counts once if
    it occurs surrounded by single spaces. German wins only with strictly more
    distinct markers than English, so ties (including empty text) yield 'en'.
    """
    german_words = ('der', 'die', 'das', 'und', 'ist', 'für', 'von',
                    'werden', 'nach', 'gemäß', 'sowie', 'durch')
    english_words = ('the', 'and', 'for', 'that', 'with', 'shall',
                     'must', 'should', 'which', 'from', 'this')
    window = text[:5000].lower()

    def marker_hits(words):
        # Number of distinct marker words present as space-delimited tokens.
        hits = 0
        for word in words:
            if f' {word} ' in window:
                hits += 1
        return hits

    if marker_hits(german_words) > marker_hits(english_words):
        return 'de'
    return 'en'
def _protect_abbreviations(text: str) -> str:
    """Mask dots that do not end a sentence so a later split ignores them.

    Replacements (all undone by ``_restore_abbreviations``):
      * the trailing dot of every entry in ``ALL_ABBREVIATIONS`` (matched
        case-insensitively at a word boundary) becomes ``<ABBR>`` and dots
        inside the abbreviation become ``<DOT>``; the original letter case of
        the matched text is kept,
      * a dot between two digits (``3.14``) becomes ``<DECIMAL>``,
      * a dot after a number and before whitespace (``1. Absatz``) becomes
        ``<ORD>``.

    Only dots are replaced; every other character, including the whitespace
    after an ordinal, is left untouched so that protect + restore returns the
    input unchanged.
    """
    protected = text
    for abbrev in ALL_ABBREVIATIONS:
        pattern = re.compile(r'\b(' + re.escape(abbrev) + r')\.', re.IGNORECASE)
        # Use a function replacement to preserve the case of the matched text.
        protected = pattern.sub(
            lambda m: m.group(1).replace('.', '<DOT>') + '<ABBR>',
            protected,
        )
    # Protect decimals (3.14) and ordinals (1. Absatz).
    protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
    # Lookahead instead of consuming the whitespace: the previous form
    # re-emitted a literal space and thereby turned "1.\n" into "1. ".
    protected = re.sub(r'(\d+)\.(?=\s)', r'\1<ORD>', protected)
    return protected
def _restore_abbreviations(text: str) -> str:
    """Turn the masking placeholders back into literal dots.

    Every occurrence of ``<DOT>``, ``<ABBR>``, ``<DECIMAL>`` and ``<ORD>`` in
    *text* is replaced by ``.``; all other characters are returned unchanged.
    """
    restored = text
    for placeholder in ('<DOT>', '<ABBR>', '<DECIMAL>', '<ORD>'):
        restored = restored.replace(placeholder, '.')
    return restored
def _split_sentences(text: str) -> List[str]:
    """Break *text* into sentences without splitting at DE/EN abbreviations.

    The text is first masked with ``_protect_abbreviations``. It is then split
    at whitespace that follows ``.``, ``!`` or ``?`` and precedes an upper-case
    letter, or at a line break following such punctuation. Each piece is
    unmasked and stripped; empty pieces are dropped.
    """
    boundary = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])\s*\n'
    pieces = re.split(boundary, _protect_abbreviations(text))
    cleaned = (_restore_abbreviations(piece).strip() for piece in pieces)
    return [sentence for sentence in cleaned if sentence]
def _extract_section_header(line: str) -> Optional[str]:
    """Return the stripped *line* if it looks like a header, else ``None``.

    A line counts as a header when, after stripping surrounding whitespace,
    it starts with a legal section marker (``_LEGAL_SECTION_RE``) or has a
    generic heading shape (``_HEADING_RE``). The whole stripped line is
    returned, not just the matched part.
    """
    candidate = line.strip()
    for header_re in (_LEGAL_SECTION_RE, _HEADING_RE):
        if header_re.match(candidate):
            return candidate
    return None
def _wrap_at_words(text: str, width: int) -> List[str]:
    """Cut *text* into pieces of at most *width* characters.

    Cuts are placed at the last space that keeps the piece within *width*;
    if a piece contains no space, it is cut at exactly *width* characters.
    *width* is clamped to at least 1 so the loop always makes progress.
    """
    width = max(1, width)
    pieces = []
    remaining = text.strip()
    while len(remaining) > width:
        cut = remaining.rfind(' ', 0, width + 1)
        if cut <= 0:
            cut = width  # no usable space: fall back to a hard cut
        pieces.append(remaining[:cut].rstrip())
        remaining = remaining[cut:].lstrip()
    if remaining:
        pieces.append(remaining)
    return pieces


def _flush_chunk(chunks: List[str], chunk: str, bare_prefix: str) -> None:
    """Append the stripped *chunk* to *chunks* unless it is empty or prefix-only."""
    content = chunk.strip()
    if content and content != bare_prefix:
        chunks.append(content)


def chunk_text_legal(text: str, chunk_size: int, overlap: int) -> List[str]:
    """
    Legal-document-aware chunking.

    Strategy:
    1. Split on legal section boundaries (§, Art., Section, Chapter, etc.)
    2. Within each section, split on paragraph boundaries (blank line)
    3. Within each over-long paragraph, split on sentence boundaries;
       sentences that still do not fit are wrapped at word boundaries
       (hard cut only when a piece contains no space)
    4. Prepend the section header as a "[header] " prefix to every chunk
    5. Prepend up to ``overlap`` trailing characters of the previous chunk

    Args:
        text: Document text.
        chunk_size: Maximum length of a chunk (prefix included) BEFORE the
            overlap is added; the overlap is prepended on top of that.
        overlap: Number of characters taken from the end of the previous
            chunk; 0 disables overlap.

    Returns:
        Non-empty, stripped chunks in document order. Text no longer than
        ``chunk_size`` is returned as a single stripped chunk; empty or
        whitespace-only text yields an empty list.
    """
    if not text or len(text) <= chunk_size:
        stripped = text.strip() if text else ""
        return [stripped] if stripped else []

    # --- Phase 1: Split into sections by legal headers ---
    sections = []  # list of (header, content); the header line stays in content
    current_header = None
    current_lines = []
    for line in text.split('\n'):
        header = _extract_section_header(line)
        if header:
            if current_lines:
                sections.append((current_header, '\n'.join(current_lines)))
            current_header = header
            current_lines = [line]
        else:
            current_lines.append(line)
    if current_lines:
        sections.append((current_header, '\n'.join(current_lines)))

    # --- Phase 2: Within each section, split on paragraphs, then sentences ---
    # The prefix is "[" + header + "] " (header + 3 chars). Cap the header at
    # 120 chars AND at half of chunk_size so content always has room; with a
    # prefix >= chunk_size the old force-split loop never terminated.
    header_budget = min(120, max(0, chunk_size // 2 - 3))
    raw_chunks = []
    for section_header, section_text in sections:
        prefix = ""
        if section_header and header_budget > 0:
            prefix = f"[{section_header[:header_budget]}] "
        bare_prefix = prefix.strip()
        body_width = max(1, chunk_size - len(prefix))
        current_chunk = prefix

        for para in re.split(r'\n\s*\n', section_text):
            para = para.strip()
            if not para:
                continue

            # Account for the real separator length so the limit is exact.
            sep = '\n\n' if current_chunk and not current_chunk.endswith(' ') else ''
            if len(current_chunk) + len(sep) + len(para) <= chunk_size:
                current_chunk += sep + para
                continue

            # Paragraph doesn't fit: emit what we have and start over.
            _flush_chunk(raw_chunks, current_chunk, bare_prefix)
            if len(prefix) + len(para) <= chunk_size:
                current_chunk = prefix + para
                continue

            # Paragraph too long even for a fresh chunk: go sentence by sentence.
            current_chunk = prefix
            for sentence in _split_sentences(para):
                if len(prefix) + len(sentence) > chunk_size:
                    # Single sentence exceeds a chunk: wrap it on its own.
                    _flush_chunk(raw_chunks, current_chunk, bare_prefix)
                    raw_chunks.extend(
                        (prefix + piece).strip()
                        for piece in _wrap_at_words(sentence, body_width)
                    )
                    current_chunk = prefix
                    continue

                sep = ' ' if current_chunk and not current_chunk.endswith(' ') else ''
                if len(current_chunk) + len(sep) + len(sentence) > chunk_size:
                    _flush_chunk(raw_chunks, current_chunk, bare_prefix)
                    current_chunk = prefix + sentence
                else:
                    current_chunk += sep + sentence

        # Flush remaining content for this section
        _flush_chunk(raw_chunks, current_chunk, bare_prefix)

    if not raw_chunks:
        return [text.strip()] if text.strip() else []

    # --- Phase 3: Add overlap ---
    final_chunks = []
    for i, chunk in enumerate(raw_chunks):
        if i > 0 and overlap > 0:
            # Tail of the previous raw chunk (may include that chunk's prefix
            # when the chunk is shorter than the overlap).
            overlap_text = raw_chunks[i - 1][-overlap:]
            # Drop a leading partial word, if a space is available to cut at.
            space_idx = overlap_text.find(' ')
            if space_idx > 0:
                overlap_text = overlap_text[space_idx + 1:]
            if overlap_text:
                chunk = overlap_text + ' ' + chunk
        final_chunks.append(chunk.strip())
    return [c for c in final_chunks if c]
def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]:
"""Recursive character-based chunking."""
import re
"""Recursive character-based chunking (legacy, use legal_recursive for legal docs)."""
if not text or len(text) <= chunk_size:
return [text] if text else []
@@ -315,36 +552,23 @@ def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]:
def chunk_text_semantic(text: str, chunk_size: int, overlap_sentences: int = 1) -> List[str]:
"""Semantic sentence-aware chunking."""
import re
if not text:
return []
if len(text) <= chunk_size:
return [text.strip()]
# Split into sentences (simplified for German)
text = re.sub(r'\s+', ' ', text).strip()
# Protect abbreviations
protected = text
for abbrev in GERMAN_ABBREVIATIONS:
pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE)
protected = pattern.sub(abbrev.replace('.', '<DOT>') + '<ABBR>', protected)
# Protect decimals and ordinals
protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
protected = re.sub(r'(\d+)\.(\s)', r'\1<ORD>\2', protected)
protected = _protect_abbreviations(text)
# Split on sentence endings
sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ])|(?<=[.!?])$'
sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])$'
raw_sentences = re.split(sentence_pattern, protected)
# Restore protected characters
sentences = []
for s in raw_sentences:
s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.').replace('<ORD>', '.')
s = s.strip()
s = _restore_abbreviations(s).strip()
if s:
sentences.append(s)
@@ -638,7 +862,16 @@ async def rerank_documents(request: RerankRequest):
@app.post("/chunk", response_model=ChunkResponse)
async def chunk_text(request: ChunkRequest):
"""Chunk text into smaller pieces."""
"""Chunk text into smaller pieces.
Strategies:
- "recursive" (default): Legal-document-aware chunking with §/Art./Section
boundary detection, section context headers, paragraph-level splitting,
and sentence-level splitting respecting DE + EN abbreviations.
- "semantic": Sentence-aware chunking with overlap by sentence count.
The old plain recursive chunker has been retired and is no longer available.
"""
if not request.text:
return ChunkResponse(chunks=[], count=0, strategy=request.strategy)
@@ -647,7 +880,9 @@ async def chunk_text(request: ChunkRequest):
overlap_sentences = max(1, request.overlap // 100)
chunks = chunk_text_semantic(request.text, request.chunk_size, overlap_sentences)
else:
chunks = chunk_text_recursive(request.text, request.chunk_size, request.overlap)
# All strategies (recursive, legal_recursive, etc.) use the legal-aware chunker.
# The old plain recursive chunker is no longer exposed via the API.
chunks = chunk_text_legal(request.text, request.chunk_size, request.overlap)
return ChunkResponse(
chunks=chunks,

View File

@@ -0,0 +1,288 @@
"""
Tests for the legal-aware chunking pipeline.
Covers:
- Legal section header detection (§, Art., Section, Chapter, Annex)
- Section context prefix in every chunk
- Paragraph boundary splitting
- Sentence splitting with DE and EN abbreviation protection
- Overlap between chunks
- Fallback for non-legal text
- Long sentence force-splitting
"""
import pytest
from main import (
chunk_text_legal,
chunk_text_recursive,
chunk_text_semantic,
_extract_section_header,
_split_sentences,
_detect_language,
_protect_abbreviations,
_restore_abbreviations,
)
# =========================================================================
# Section header detection
# =========================================================================
class TestSectionHeaderDetection:
    """Lines that must (or must not) be recognised as section headers."""

    @staticmethod
    def _is_header(line):
        # A line is a header when the extractor returns anything but None.
        return _extract_section_header(line) is not None

    def test_german_paragraph(self):
        assert self._is_header("§ 25 Informationspflichten")

    def test_german_paragraph_with_letter(self):
        assert self._is_header("§ 5a Elektronischer Geschaeftsverkehr")

    def test_german_artikel(self):
        assert self._is_header("Artikel 5 Grundsaetze")

    def test_english_article(self):
        assert self._is_header("Article 12 Transparency")

    def test_article_abbreviated(self):
        assert self._is_header("Art. 3 Definitions")

    def test_english_section(self):
        assert self._is_header("Section 4.2 Risk Assessment")

    def test_german_abschnitt(self):
        assert self._is_header("Abschnitt 3 Pflichten")

    def test_chapter(self):
        assert self._is_header("Chapter 5 Obligations")

    def test_german_kapitel(self):
        assert self._is_header("Kapitel 2 Anwendungsbereich")

    def test_annex_roman(self):
        assert self._is_header("Annex XII Technical Documentation")

    def test_german_anhang(self):
        assert self._is_header("Anhang III Hochrisiko-KI")

    def test_part(self):
        assert self._is_header("Part III Requirements")

    def test_markdown_heading(self):
        assert self._is_header("## 3.1 Overview")

    def test_normal_text_not_header(self):
        assert not self._is_header("This is a normal sentence.")

    def test_short_caps_not_header(self):
        assert not self._is_header("OK")
# =========================================================================
# Language detection
# =========================================================================
class TestLanguageDetection:
    """The marker-word heuristic tells German from English sample sentences."""

    def test_german_text(self):
        sample = (
            "Die Verordnung ist für alle Mitgliedstaaten verbindlich und gilt "
            "nach dem Grundsatz der unmittelbaren Anwendbarkeit."
        )
        detected = _detect_language(sample)
        assert detected == 'de'

    def test_english_text(self):
        sample = (
            "This regulation shall be binding in its entirety and directly "
            "applicable in all Member States."
        )
        detected = _detect_language(sample)
        assert detected == 'en'
# =========================================================================
# Abbreviation protection
# =========================================================================
class TestAbbreviationProtection:
    """Masking hides every non-terminal dot and restoring is lossless."""

    @staticmethod
    def _without_placeholders(protected):
        # Remove all placeholder tokens so any remaining '.' is an unmasked dot.
        for token in ("<DOT>", "<DECIMAL>", "<ORD>", "<ABBR>"):
            protected = protected.replace(token, "")
        return protected

    def test_german_abbreviations(self):
        text = "gem. § 5 Abs. 1 bzw. § 6 Abs. 2 z.B. die Pflicht"
        protected = _protect_abbreviations(text)
        # Every dot in this sample belongs to an abbreviation.
        assert "." not in self._without_placeholders(protected)
        # Inner and trailing dots are masked and the original case is kept.
        assert "z<DOT>B<ABBR>" in protected
        assert "gem<ABBR>" in protected
        # Round trip must reproduce the input exactly (previously only a
        # trivially-true check on "z.B." was made).
        assert _restore_abbreviations(protected) == text

    def test_english_abbreviations(self):
        text = "e.g. section 4.2, i.e. the requirements in vol. 1 ref. NIST SP 800-30."
        protected = _protect_abbreviations(text)
        # "e.g" and "i.e" should be protected, including their inner dot.
        assert "e<DOT>g<ABBR>" in protected
        assert "i<DOT>e<ABBR>" in protected
        restored = _restore_abbreviations(protected)
        assert "e.g." in restored
        assert restored == text

    def test_decimals_protected(self):
        text = "Version 3.14 of the specification requires 2.5 GB."
        protected = _protect_abbreviations(text)
        assert "<DECIMAL>" in protected
        restored = _restore_abbreviations(protected)
        assert "3.14" in restored
        assert restored == text
# =========================================================================
# Sentence splitting
# =========================================================================
class TestSentenceSplitting:
    """Sentence boundaries are found, abbreviation dots are not boundaries."""

    def test_simple_german(self):
        parts = _split_sentences("Erster Satz. Zweiter Satz. Dritter Satz.")
        assert len(parts) >= 2

    def test_simple_english(self):
        parts = _split_sentences("First sentence. Second sentence. Third sentence.")
        assert len(parts) >= 2

    def test_german_abbreviation_not_split(self):
        source = (
            "Gem. Art. 5 Abs. 1 DSGVO ist die Verarbeitung rechtmaessig. "
            "Der Verantwortliche muss dies nachweisen."
        )
        parts = _split_sentences(source)
        # "Gem.", "Art." and "Abs." must stay in the same sentence as "DSGVO".
        intact = [p for p in parts if "Gem" in p and "DSGVO" in p]
        assert intact

    def test_english_abbreviation_not_split(self):
        source = (
            "See e.g. Section 4.2 for details. "
            "The standard also references vol. 1 of the NIST SP series."
        )
        parts = _split_sentences(source)
        intact = [p for p in parts if "e.g" in p and "Section" in p]
        assert intact

    def test_exclamation_and_question(self):
        parts = _split_sentences("Is this valid? Yes it is! Continue processing.")
        assert len(parts) >= 2
# =========================================================================
# Legal chunking
# =========================================================================
class TestChunkTextLegal:
    """End-to-end behaviour of the legal-aware chunker."""

    def test_small_text_single_chunk(self):
        text = "Short text."
        chunks = chunk_text_legal(text, chunk_size=1024, overlap=128)
        assert len(chunks) == 1
        assert chunks[0] == "Short text."

    def test_section_header_as_prefix(self):
        # One header followed by a long body. The body does not mention the
        # section itself, so only the context prefix can satisfy the check.
        text = (
            "§ 25 Informationspflichten\n\n"
            + "Der Betreiber muss den Nutzer informieren. " * 20
        )
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        assert len(chunks) > 1
        # Every chunk should start with the section prefix
        for chunk in chunks:
            assert chunk.startswith("[§ 25 Informationspflichten]")

    def test_article_prefix_english(self):
        text = "Article 12 Transparency\n\n" + "The provider shall ensure transparency of AI systems. " * 30
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "Article 12" in chunk

    def test_multiple_sections(self):
        text = (
            "§ 1 Anwendungsbereich\n\nDieses Gesetz gilt fuer alle Betreiber.\n\n"
            "§ 2 Begriffsbestimmungen\n\nIm Sinne dieses Gesetzes ist Betreiber, wer eine Anlage betreibt.\n\n"
            "§ 3 Pflichten\n\nDer Betreiber hat die Pflicht, die Anlage sicher zu betreiben."
        )
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        # Should have chunks from different sections
        section_headers = set()
        for chunk in chunks:
            if "[§ 1" in chunk:
                section_headers.add("§ 1")
            if "[§ 2" in chunk:
                section_headers.add("§ 2")
            if "[§ 3" in chunk:
                section_headers.add("§ 3")
        assert len(section_headers) >= 2

    def test_paragraph_boundaries_respected(self):
        para1 = ("First paragraph with enough text to matter. " * 5).strip()
        para2 = ("Second paragraph also with content. " * 5).strip()
        text = para1 + "\n\n" + para2
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        # Both paragraphs fit a chunk on their own but not together, so the
        # split must fall exactly on the paragraph boundary.
        assert chunks == [para1, para2]

    def test_overlap_present(self):
        text = "Sentence one about topic A. " * 10 + "\n\n" + "Sentence two about topic B. " * 10
        chunks = chunk_text_legal(text, chunk_size=200, overlap=50)
        # Must not be conditional: with a single chunk the overlap check
        # below would otherwise be skipped and the test would pass vacuously.
        assert len(chunks) > 1
        # Second chunk should contain some text from end of first chunk
        end_of_first = chunks[0][-30:]
        overlap_words = set(end_of_first.split())
        second_start_words = set(chunks[1][:80].split())
        assert len(overlap_words & second_start_words) > 0

    def test_nist_style_sections(self):
        # Explicit parentheses: adjacent string literals are concatenated
        # BEFORE "* 10" is applied, which used to repeat the headers too.
        text = (
            "Section 2.1 Risk Framing\n\n"
            + (
                "Risk framing establishes the context for risk-based decisions. "
                "Organizations must define their risk tolerance. "
            ) * 10
            + "\n\n"
            + "Section 2.2 Risk Assessment\n\n"
            + "Risk assessment identifies threats and vulnerabilities. " * 10
        )
        chunks = chunk_text_legal(text, chunk_size=400, overlap=0)
        has_21 = any("Section 2.1" in c for c in chunks)
        has_22 = any("Section 2.2" in c for c in chunks)
        assert has_21 and has_22

    def test_markdown_heading_as_context(self):
        text = (
            "## 3.1 Overview\n\n"
            + "This section provides an overview of the specification. " * 15
        )
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "3.1 Overview" in chunk

    def test_empty_text(self):
        assert chunk_text_legal("", 1024, 128) == []

    def test_whitespace_only(self):
        assert chunk_text_legal(" \n\n ", 1024, 128) == []

    def test_long_sentence_force_split(self):
        long_sentence = "A" * 2000
        chunks = chunk_text_legal(long_sentence, chunk_size=500, overlap=0)
        assert len(chunks) >= 4
        for chunk in chunks:
            assert len(chunk) <= 500 + 20  # small margin for prefix
# =========================================================================
# Legacy recursive chunking still works
# =========================================================================
class TestChunkTextRecursive:
    """Smoke tests for the legacy recursive chunker."""

    def test_basic_split(self):
        repeated = "Hello world. " * 200
        pieces = chunk_text_recursive(repeated, chunk_size=500, overlap=50)
        assert len(pieces) > 1
        # Allow some margin on top of chunk_size for the overlap.
        assert all(len(piece) <= 600 for piece in pieces)

    def test_small_text(self):
        pieces = chunk_text_recursive("Short.", chunk_size=1024, overlap=128)
        assert pieces == ["Short."]
# =========================================================================
# Semantic chunking still works
# =========================================================================
class TestChunkTextSemantic:
    """Smoke tests for the sentence-based semantic chunker."""

    def test_basic_split(self):
        five_sentences = (
            "First sentence. Second sentence. Third sentence. "
            "Fourth sentence. Fifth sentence."
        )
        pieces = chunk_text_semantic(five_sentences, chunk_size=50, overlap_sentences=1)
        assert len(pieces) >= 2

    def test_small_text(self):
        pieces = chunk_text_semantic("Short.", chunk_size=1024, overlap_sentences=1)
        assert pieces == ["Short."]