From 322e2d9cb3a2945e83418975093c6a4d5875cf02 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sun, 22 Mar 2026 09:18:23 +0100 Subject: [PATCH] feat(embedding): implement legal-aware chunking pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace plain recursive chunker with legal-aware chunking that: - Detects legal section headers (§, Art., Section, Chapter, Annex) - Adds section context prefix to every chunk - Splits on paragraph boundaries then sentence boundaries - Protects DE + EN abbreviations (80+ patterns) from false splits - Supports language detection for locale-specific processing - Force-splits overlong sentences at word boundaries The old plain_recursive API option is removed — all non-semantic strategies now route through chunk_text_legal(). Includes 40 tests covering header detection, abbreviation protection, sentence splitting, and legal chunking behavior. Co-Authored-By: Claude Opus 4.6 --- embedding-service/main.py | 279 +++++++++++++++++++++++++--- embedding-service/test_chunking.py | 288 +++++++++++++++++++++++++++++ 2 files changed, 545 insertions(+), 22 deletions(-) create mode 100644 embedding-service/test_chunking.py diff --git a/embedding-service/main.py b/embedding-service/main.py index 8b033ab..bc1d557 100644 --- a/embedding-service/main.py +++ b/embedding-service/main.py @@ -251,14 +251,251 @@ async def rerank_cohere(query: str, documents: List[str], top_k: int = 5) -> Lis GERMAN_ABBREVIATIONS = { 'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max', 'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw', - 'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl' + 'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'abschn', + 'anh', 'anl', 'aufl', 'bd', 'bes', 'bzgl', 'dgl', 'einschl', 'entspr', + 'erg', 'erl', 'gem', 'grds', 'hrsg', 'insb', 'ivm', 'kap', 'lit', + 'nachf', 'rdnr', 'rn', 'rz', 'ua', 'uvm', 'vorst', 'ziff' } +# English abbreviations that don't end 
# English abbreviations that don't end sentences
ENGLISH_ABBREVIATIONS = {
    'e.g', 'i.e', 'etc', 'vs', 'al', 'approx', 'avg', 'dept', 'dr', 'ed',
    'est', 'fig', 'govt', 'inc', 'jr', 'ltd', 'max', 'min', 'mr', 'mrs',
    'ms', 'no', 'prof', 'pt', 'ref', 'rev', 'sec', 'sgt', 'sr', 'st',
    'vol', 'cf', 'ch', 'cl', 'col', 'corp', 'cpl', 'def', 'dist', 'div',
    'gen', 'hon', 'illus', 'intl', 'natl', 'org', 'para', 'pp', 'repr',
    'resp', 'supp', 'tech', 'temp', 'treas', 'univ'
}

# Combined abbreviations for both languages
ALL_ABBREVIATIONS = GERMAN_ABBREVIATIONS | ENGLISH_ABBREVIATIONS

import re

# Regex pattern for legal section headers (§, Art., Article, Section, etc.)
_LEGAL_SECTION_RE = re.compile(
    r'^(?:'
    r'§\s*\d+'                      # § 25, § 5a
    r'|Art(?:ikel|icle|\.)\s*\d+'   # Artikel 5, Article 12, Art. 3
    r'|Section\s+\d+'               # Section 4.2
    r'|Abschnitt\s+\d+'             # Abschnitt III
    r'|Kapitel\s+\d+'               # Kapitel 2
    r'|Chapter\s+\d+'               # Chapter 3
    r'|Anhang\s+[IVXLC\d]+'         # Anhang III
    r'|Annex\s+[IVXLC\d]+'          # Annex XII
    r'|TEIL\s+[IVXLC\d]+'           # TEIL II
    r'|Part\s+[IVXLC\d]+'           # Part III
    r'|Recital\s+\d+'               # Recital 42
    r'|Erwaegungsgrund\s+\d+'       # Erwaegungsgrund 26
    r')',
    re.IGNORECASE | re.MULTILINE
)

# Regex for any heading-like line (Markdown ## or ALL-CAPS line)
_HEADING_RE = re.compile(
    r'^(?:'
    r'#{1,6}\s+.+'                  # Markdown headings
    r'|[A-ZÄÖÜ][A-ZÄÖÜ\s\-]{5,}$'   # ALL-CAPS lines (>5 chars)
    r')',
    re.MULTILINE
)

# Placeholder characters used to temporarily hide dots that do NOT end a
# sentence while the sentence splitter runs. Taken from the Unicode Private
# Use Area (U+E000..), so they can never occur in real input text.
_PH_ABBREV = '\ue000'   # dot belonging to an abbreviation (z.B., e.g., vol.)
_PH_DECIMAL = '\ue001'  # decimal point (3.14)
_PH_ORDINAL = '\ue002'  # ordinal dot followed by whitespace (1. Absatz)

# Per-abbreviation patterns, compiled once at import time rather than on
# every _protect_abbreviations() call. Sorted for a deterministic order.
_ABBREV_DOT_PATTERNS = [
    re.compile(r'\b(' + re.escape(abbrev) + r')\.', re.IGNORECASE)
    for abbrev in sorted(ALL_ABBREVIATIONS)
]


def _detect_language(text: str) -> str:
    """Return 'de' or 'en' via a cheap marker-word counting heuristic.

    Only the first 5000 characters are inspected; a tie resolves to 'en'.
    """
    # Pad the sample with spaces so a marker word at the very start or end
    # of the sample is counted too (' {w} ' would otherwise miss it).
    sample = f" {text[:5000].lower()} "
    de_markers = sum(1 for w in ('der', 'die', 'das', 'und', 'ist', 'für', 'von',
                                 'werden', 'nach', 'gemäß', 'sowie', 'durch')
                     if f' {w} ' in sample)
    en_markers = sum(1 for w in ('the', 'and', 'for', 'that', 'with', 'shall',
                                 'must', 'should', 'which', 'from', 'this')
                     if f' {w} ' in sample)
    return 'de' if de_markers > en_markers else 'en'


def _protect_abbreviations(text: str) -> str:
    """Replace non-sentence-ending dots with placeholder characters.

    Protects abbreviation dots (DE + EN), decimal points and ordinal
    numbers so the sentence splitter does not break on them. Exactly
    inverted by _restore_abbreviations().
    """
    protected = text
    for pattern in _ABBREV_DOT_PATTERNS:
        # Keep the matched casing; swap every dot of the abbreviation
        # (inner ones like 'z.B' included) plus the trailing dot for the
        # placeholder.
        protected = pattern.sub(
            lambda m: m.group(1).replace('.', _PH_ABBREV) + _PH_ABBREV,
            protected)
    # Decimals (3.14) and ordinals (1. Absatz); the ordinal rule keeps the
    # matched whitespace character instead of collapsing it to a space.
    protected = re.sub(r'(\d)\.(\d)', r'\1' + _PH_DECIMAL + r'\2', protected)
    protected = re.sub(r'(\d+)\.(\s)', r'\1' + _PH_ORDINAL + r'\2', protected)
    return protected


def _restore_abbreviations(text: str) -> str:
    """Invert _protect_abbreviations(): map every placeholder back to '.'."""
    for placeholder in (_PH_ABBREV, _PH_DECIMAL, _PH_ORDINAL):
        text = text.replace(placeholder, '.')
    return text


def _split_sentences(text: str) -> List[str]:
    """Split text into sentences, respecting abbreviations in DE and EN."""
    protected = _protect_abbreviations(text)
    # Break after sentence-ending punctuation followed by an uppercase
    # letter, or by a newline.
    boundary = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])\s*\n'
    pieces = re.split(boundary, protected)
    sentences = []
    for piece in pieces:
        restored = _restore_abbreviations(piece).strip()
        if restored:
            sentences.append(restored)
    return sentences


def _extract_section_header(line: str) -> Optional[str]:
    """Return the stripped line if it looks like a legal/section heading, else None."""
    stripped = line.strip()
    if _LEGAL_SECTION_RE.match(stripped) or _HEADING_RE.match(stripped):
        return stripped
    return None


def chunk_text_legal(text: str, chunk_size: int, overlap: int) -> List[str]:
    """
    Legal-document-aware chunking.

    Strategy:
      1. Split on legal section boundaries (§, Art., Section, Chapter, ...)
      2. Within each section, split on paragraph boundaries (blank lines)
      3. Within each paragraph, split on sentence boundaries
      4. Prepend the section header as context prefix to every chunk
      5. Add character-level overlap from the previous chunk

    Works for both German (DSGVO, BGB, AI Act DE) and English (NIST, SLSA,
    CRA EN) texts.
    """
    if not text or len(text) <= chunk_size:
        return [text.strip()] if text and text.strip() else []

    # --- Phase 1: split into (header, content) sections by legal headers ---
    sections = []
    current_header = None
    current_lines = []
    for line in text.split('\n'):
        header = _extract_section_header(line)
        if header:
            if current_lines:
                sections.append((current_header, '\n'.join(current_lines)))
            current_header = header
            current_lines = [line]
        else:
            current_lines.append(line)
    if current_lines:
        sections.append((current_header, '\n'.join(current_lines)))

    # --- Phase 2: within each section, split on paragraphs, then sentences ---
    raw_chunks = []

    for section_header, section_text in sections:
        # Context prefix (header truncated to 120 chars to leave room for
        # actual content).
        prefix = f"[{section_header[:120]}] " if section_header else ""

        paragraphs = re.split(r'\n\s*\n', section_text)
        current_chunk = prefix
        current_length = len(prefix)

        for para in paragraphs:
            para = para.strip()
            if not para:
                continue

            # Paragraph fits into the space left in the current chunk.
            if current_length + len(para) + 1 <= chunk_size:
                if current_chunk and not current_chunk.endswith(' '):
                    current_chunk += '\n\n'
                current_chunk += para
                current_length = len(current_chunk)
                continue

            # It does not fit: flush whatever has accumulated (but never a
            # chunk consisting of the prefix alone).
            if current_chunk.strip() and current_chunk.strip() != prefix.strip():
                raw_chunks.append(current_chunk.strip())

            # Whole paragraph fits into a fresh chunk.
            if len(prefix) + len(para) <= chunk_size:
                current_chunk = prefix + para
                current_length = len(current_chunk)
                continue

            # Oversized paragraph: fall back to sentence-level packing.
            sentences = _split_sentences(para)
            current_chunk = prefix
            current_length = len(prefix)

            for sentence in sentences:
                sentence_len = len(sentence)

                # A single sentence longer than chunk_size: hard-split it.
                if len(prefix) + sentence_len > chunk_size:
                    if current_chunk.strip() and current_chunk.strip() != prefix.strip():
                        raw_chunks.append(current_chunk.strip())
                    # Guard take >= 1 so a pathological prefix that is as
                    # long as chunk_size cannot cause an infinite loop.
                    take = max(1, chunk_size - len(prefix))
                    remaining = sentence
                    while remaining:
                        raw_chunks.append((prefix + remaining[:take]).strip())
                        remaining = remaining[take:]
                    current_chunk = prefix
                    current_length = len(prefix)
                    continue

                if current_length + sentence_len + 1 > chunk_size:
                    # Sentence does not fit: flush and start a new chunk.
                    if current_chunk.strip() and current_chunk.strip() != prefix.strip():
                        raw_chunks.append(current_chunk.strip())
                    current_chunk = prefix + sentence
                else:
                    if current_chunk and not current_chunk.endswith(' '):
                        current_chunk += ' '
                    current_chunk += sentence
                current_length = len(current_chunk)

        # Flush remaining content for this section.
        if current_chunk.strip() and current_chunk.strip() != prefix.strip():
            raw_chunks.append(current_chunk.strip())

    if not raw_chunks:
        return [text.strip()] if text.strip() else []

    # --- Phase 3: add character overlap from the previous chunk ---
    final_chunks = []
    for i, chunk in enumerate(raw_chunks):
        if i > 0 and overlap > 0:
            prev = raw_chunks[i - 1]
            overlap_text = prev[-min(overlap, len(prev)):]
            # Drop a leading word fragment so the overlap starts on a word
            # boundary.
            space_idx = overlap_text.find(' ')
            if space_idx > 0:
                overlap_text = overlap_text[space_idx + 1:]
            if overlap_text:
                chunk = overlap_text + ' ' + chunk
        final_chunks.append(chunk.strip())

    return [c for c in final_chunks if c]
if text else [] @@ -315,36 +552,23 @@ def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]: def chunk_text_semantic(text: str, chunk_size: int, overlap_sentences: int = 1) -> List[str]: """Semantic sentence-aware chunking.""" - import re - if not text: return [] if len(text) <= chunk_size: return [text.strip()] - # Split into sentences (simplified for German) text = re.sub(r'\s+', ' ', text).strip() - # Protect abbreviations - protected = text - for abbrev in GERMAN_ABBREVIATIONS: - pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE) - protected = pattern.sub(abbrev.replace('.', '') + '', protected) - - # Protect decimals and ordinals - protected = re.sub(r'(\d)\.(\d)', r'\1\2', protected) - protected = re.sub(r'(\d+)\.(\s)', r'\1\2', protected) + protected = _protect_abbreviations(text) # Split on sentence endings - sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ])|(?<=[.!?])$' + sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])$' raw_sentences = re.split(sentence_pattern, protected) - # Restore protected characters sentences = [] for s in raw_sentences: - s = s.replace('', '.').replace('', '.').replace('', '.').replace('', '.') - s = s.strip() + s = _restore_abbreviations(s).strip() if s: sentences.append(s) @@ -638,7 +862,16 @@ async def rerank_documents(request: RerankRequest): @app.post("/chunk", response_model=ChunkResponse) async def chunk_text(request: ChunkRequest): - """Chunk text into smaller pieces.""" + """Chunk text into smaller pieces. + + Strategies: + - "recursive" (default): Legal-document-aware chunking with §/Art./Section + boundary detection, section context headers, paragraph-level splitting, + and sentence-level splitting respecting DE + EN abbreviations. + - "semantic": Sentence-aware chunking with overlap by sentence count. + + The old plain recursive chunker has been retired and is no longer available. 
"""
Tests for the legal-aware chunking pipeline.

Covers:
- Legal section header detection (§, Art., Section, Chapter, Annex)
- Section context prefix in every chunk
- Paragraph boundary splitting
- Sentence splitting with DE and EN abbreviation protection
- Overlap between chunks
- Fallback for non-legal text
- Long sentence force-splitting
"""

import pytest
from main import (
    chunk_text_legal,
    chunk_text_recursive,
    chunk_text_semantic,
    _extract_section_header,
    _split_sentences,
    _detect_language,
    _protect_abbreviations,
    _restore_abbreviations,
)


# =========================================================================
# Section header detection
# =========================================================================

class TestSectionHeaderDetection:

    def test_german_paragraph(self):
        assert _extract_section_header("§ 25 Informationspflichten") is not None

    def test_german_paragraph_with_letter(self):
        assert _extract_section_header("§ 5a Elektronischer Geschaeftsverkehr") is not None

    def test_german_artikel(self):
        assert _extract_section_header("Artikel 5 Grundsaetze") is not None

    def test_english_article(self):
        assert _extract_section_header("Article 12 Transparency") is not None

    def test_article_abbreviated(self):
        assert _extract_section_header("Art. 3 Definitions") is not None

    def test_english_section(self):
        assert _extract_section_header("Section 4.2 Risk Assessment") is not None

    def test_german_abschnitt(self):
        assert _extract_section_header("Abschnitt 3 Pflichten") is not None

    def test_chapter(self):
        assert _extract_section_header("Chapter 5 Obligations") is not None

    def test_german_kapitel(self):
        assert _extract_section_header("Kapitel 2 Anwendungsbereich") is not None

    def test_annex_roman(self):
        assert _extract_section_header("Annex XII Technical Documentation") is not None

    def test_german_anhang(self):
        assert _extract_section_header("Anhang III Hochrisiko-KI") is not None

    def test_part(self):
        assert _extract_section_header("Part III Requirements") is not None

    def test_markdown_heading(self):
        assert _extract_section_header("## 3.1 Overview") is not None

    def test_normal_text_not_header(self):
        assert _extract_section_header("This is a normal sentence.") is None

    def test_short_caps_not_header(self):
        assert _extract_section_header("OK") is None


# =========================================================================
# Language detection
# =========================================================================

class TestLanguageDetection:

    def test_german_text(self):
        text = "Die Verordnung ist für alle Mitgliedstaaten verbindlich und gilt nach dem Grundsatz der unmittelbaren Anwendbarkeit."
        assert _detect_language(text) == 'de'

    def test_english_text(self):
        text = "This regulation shall be binding in its entirety and directly applicable in all Member States."
        assert _detect_language(text) == 'en'


# =========================================================================
# Abbreviation protection
# =========================================================================

class TestAbbreviationProtection:
    # NOTE: these tests are deliberately placeholder-agnostic — they assert
    # on the observable contract (dots disappear, round trip restores the
    # input) rather than on the concrete placeholder characters used.

    def test_german_abbreviations(self):
        text = "gem. § 5 Abs. 1 bzw. § 6 Abs. 2 z.B. die Pflicht"
        protected = _protect_abbreviations(text)
        # Every dot in the input belongs to an abbreviation, so none may
        # survive protection.
        assert "." not in protected
        restored = _restore_abbreviations(protected)
        # Protect/restore must round-trip exactly, preserving case.
        assert restored == text
        assert "gem." in restored
        assert "z.B." in restored

    def test_english_abbreviations(self):
        text = "e.g. section 4.2, i.e. the requirements in vol. 1 ref. NIST SP 800-30."
        protected = _protect_abbreviations(text)
        # "e.g." and "i.e." must be protected (no dots left in them).
        assert "e.g." not in protected
        restored = _restore_abbreviations(protected)
        assert "e.g." in restored
        assert "i.e." in restored

    def test_decimals_protected(self):
        text = "Version 3.14 of the specification requires 2.5 GB."
        protected = _protect_abbreviations(text)
        # The decimal dot must be replaced by a placeholder ...
        assert "3.14" not in protected
        # ... and restored afterwards.
        restored = _restore_abbreviations(protected)
        assert "3.14" in restored
        assert "2.5" in restored


# =========================================================================
# Sentence splitting
# =========================================================================

class TestSentenceSplitting:

    def test_simple_german(self):
        text = "Erster Satz. Zweiter Satz. Dritter Satz."
        sentences = _split_sentences(text)
        assert len(sentences) >= 2

    def test_simple_english(self):
        text = "First sentence. Second sentence. Third sentence."
        sentences = _split_sentences(text)
        assert len(sentences) >= 2

    def test_german_abbreviation_not_split(self):
        text = "Gem. Art. 5 Abs. 1 DSGVO ist die Verarbeitung rechtmaessig. Der Verantwortliche muss dies nachweisen."
        sentences = _split_sentences(text)
        # Should NOT split at "Gem." or "Art." or "Abs."
        assert any("Gem" in s and "DSGVO" in s for s in sentences)

    def test_english_abbreviation_not_split(self):
        text = "See e.g. Section 4.2 for details. The standard also references vol. 1 of the NIST SP series."
        sentences = _split_sentences(text)
        assert any("e.g" in s and "Section" in s for s in sentences)

    def test_exclamation_and_question(self):
        text = "Is this valid? Yes it is! Continue processing."
        sentences = _split_sentences(text)
        assert len(sentences) >= 2


# =========================================================================
# Legal chunking
# =========================================================================

class TestChunkTextLegal:

    def test_small_text_single_chunk(self):
        text = "Short text."
        chunks = chunk_text_legal(text, chunk_size=1024, overlap=128)
        assert len(chunks) == 1
        assert chunks[0] == "Short text."

    def test_section_header_as_prefix(self):
        text = "§ 25 Informationspflichten\n\nDer Betreiber muss den Nutzer informieren. " * 20
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        assert len(chunks) > 1
        # Every chunk should carry the section context.
        for chunk in chunks:
            assert "[§ 25" in chunk or "§ 25" in chunk

    def test_article_prefix_english(self):
        text = "Article 12 Transparency\n\n" + "The provider shall ensure transparency of AI systems. " * 30
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "Article 12" in chunk

    def test_multiple_sections(self):
        text = (
            "§ 1 Anwendungsbereich\n\nDieses Gesetz gilt fuer alle Betreiber.\n\n"
            "§ 2 Begriffsbestimmungen\n\nIm Sinne dieses Gesetzes ist Betreiber, wer eine Anlage betreibt.\n\n"
            "§ 3 Pflichten\n\nDer Betreiber hat die Pflicht, die Anlage sicher zu betreiben."
        )
        chunks = chunk_text_legal(text, chunk_size=200, overlap=0)
        # Chunks from different sections must carry different headers.
        section_headers = set()
        for chunk in chunks:
            if "[§ 1" in chunk:
                section_headers.add("§ 1")
            if "[§ 2" in chunk:
                section_headers.add("§ 2")
            if "[§ 3" in chunk:
                section_headers.add("§ 3")
        assert len(section_headers) >= 2

    def test_paragraph_boundaries_respected(self):
        para1 = "First paragraph with enough text to matter. " * 5
        para2 = "Second paragraph also with content. " * 5
        text = para1.strip() + "\n\n" + para2.strip()
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        # Paragraphs should not be merged mid-sentence across chunk boundary.
        assert len(chunks) >= 2

    def test_overlap_present(self):
        text = "Sentence one about topic A. " * 10 + "\n\n" + "Sentence two about topic B. " * 10
        chunks = chunk_text_legal(text, chunk_size=200, overlap=50)
        if len(chunks) > 1:
            # The second chunk should repeat some text from the end of the
            # first chunk.
            end_of_first = chunks[0][-30:]
            overlap_words = set(end_of_first.split())
            second_start_words = set(chunks[1][:80].split())
            assert len(overlap_words & second_start_words) > 0

    def test_nist_style_sections(self):
        text = (
            "Section 2.1 Risk Framing\n\n"
            "Risk framing establishes the context for risk-based decisions. "
            "Organizations must define their risk tolerance. " * 10 + "\n\n"
            "Section 2.2 Risk Assessment\n\n"
            "Risk assessment identifies threats and vulnerabilities. " * 10
        )
        chunks = chunk_text_legal(text, chunk_size=400, overlap=0)
        has_21 = any("Section 2.1" in c for c in chunks)
        has_22 = any("Section 2.2" in c for c in chunks)
        assert has_21 and has_22

    def test_markdown_heading_as_context(self):
        text = (
            "## 3.1 Overview\n\n"
            + "This section provides an overview of the specification. " * 15
        )
        chunks = chunk_text_legal(text, chunk_size=300, overlap=0)
        assert len(chunks) > 1
        for chunk in chunks:
            assert "3.1 Overview" in chunk

    def test_empty_text(self):
        assert chunk_text_legal("", 1024, 128) == []

    def test_whitespace_only(self):
        assert chunk_text_legal("   \n\n   ", 1024, 128) == []

    def test_long_sentence_force_split(self):
        long_sentence = "A" * 2000
        chunks = chunk_text_legal(long_sentence, chunk_size=500, overlap=0)
        assert len(chunks) >= 4
        for chunk in chunks:
            assert len(chunk) <= 500 + 20  # small margin for prefix


# =========================================================================
# Legacy recursive chunking still works
# =========================================================================

class TestChunkTextRecursive:

    def test_basic_split(self):
        text = "Hello world. " * 200
        chunks = chunk_text_recursive(text, chunk_size=500, overlap=50)
        assert len(chunks) > 1
        for chunk in chunks:
            assert len(chunk) <= 600  # some margin for overlap

    def test_small_text(self):
        chunks = chunk_text_recursive("Short.", chunk_size=1024, overlap=128)
        assert chunks == ["Short."]


# =========================================================================
# Semantic chunking still works
# =========================================================================

class TestChunkTextSemantic:

    def test_basic_split(self):
        text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence."
        chunks = chunk_text_semantic(text, chunk_size=50, overlap_sentences=1)
        assert len(chunks) >= 2

    def test_small_text(self):
        chunks = chunk_text_semantic("Short.", chunk_size=1024, overlap_sentences=1)
        assert chunks == ["Short."]