[split-required] Split 700-870 LOC files across all services

backend-lehrer (11 files):
- llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6)
- messenger_api.py (840 → 5), print_generator.py (824 → 5)
- unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4)
- llm_gateway/routes/edu_search_seeds.py (710 → 4)

klausur-service (12 files):
- ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4)
- legal_corpus_api.py (790 → 4), page_crop.py (758 → 3)
- mail/ai_service.py (747 → 4), github_crawler.py (767 → 3)
- trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4)
- dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4)

website (6 pages):
- audit-checklist (867 → 8), content (806 → 6)
- screen-flow (790 → 4), scraper (789 → 5)
- zeugnisse (776 → 5), modules (745 → 4)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Benjamin Admin committed 2026-04-25 08:01:18 +02:00
commit 34da9f4cda (parent b6983ab1dc)
106 changed files with 16500 additions and 16947 deletions

View File

@@ -0,0 +1,200 @@
"""
Compliance Extraction & Generation.
Functions for extracting checkpoints from legal text chunks,
generating controls, and creating remediation measures.
"""
import re
import hashlib
import logging
from typing import Dict, List, Optional
from compliance_models import Checkpoint, Control, Measure
logger = logging.getLogger(__name__)
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
"""
Extract checkpoints/requirements from a chunk of text.
Uses pattern matching to find requirement-like statements.
"""
checkpoints = []
regulation_code = payload.get("regulation_code", "UNKNOWN")
regulation_name = payload.get("regulation_name", "Unknown")
source_url = payload.get("source_url", "")
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
# Patterns for different requirement types
patterns = [
# BSI-TR patterns
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
# Article patterns (GDPR, AI Act, etc.)
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'),
# Numbered requirements
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
# "Der Verantwortliche muss" patterns
(r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
# "Es ist erforderlich" patterns
(r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
]
for pattern, pattern_type in patterns:
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
for match in matches:
if pattern_type == 'bsi_requirement':
req_id = match.group(1)
description = match.group(2).strip()
title = req_id
elif pattern_type == 'article':
article_num = match.group(1)
paragraph = match.group(2) or ""
title_text = match.group(3).strip()
req_id = f"{regulation_code}-Art{article_num}"
if paragraph:
req_id += f"-{paragraph}"
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
description = title_text
elif pattern_type == 'numbered':
num = match.group(1)
description = match.group(2).strip()
req_id = f"{regulation_code}-{num}"
title = f"Anforderung {num}"
else:
# Generic requirement
description = match.group(0).strip()
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
title = description[:50] + "..." if len(description) > 50 else description
# Skip very short matches
if len(description) < 20:
continue
checkpoint = Checkpoint(
id=req_id,
regulation_code=regulation_code,
regulation_name=regulation_name,
article=title if 'Art' in title else None,
title=title,
description=description[:500],
original_text=description,
chunk_id=chunk_id,
source_url=source_url
)
checkpoints.append(checkpoint)
return checkpoints
def generate_control_for_checkpoints(
checkpoints: List[Checkpoint],
domain_counts: Dict[str, int],
) -> Optional[Control]:
"""
Generate a control that covers the given checkpoints.
This is a simplified version - in production this would use the AI assistant.
"""
if not checkpoints:
return None
# Group by regulation
regulation = checkpoints[0].regulation_code
# Determine domain based on content
all_text = " ".join([cp.description for cp in checkpoints]).lower()
domain = "gov" # Default
if any(kw in all_text for kw in ["verschlüssel", "krypto", "encrypt", "hash"]):
domain = "crypto"
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
domain = "iam"
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
domain = "priv"
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
domain = "sdlc"
elif any(kw in all_text for kw in ["überwach", "monitor", "log", "audit"]):
domain = "aud"
elif any(kw in all_text for kw in ["ki", "künstlich", "ai", "machine learning", "model"]):
domain = "ai"
elif any(kw in all_text for kw in ["betrieb", "operation", "verfügbar", "backup"]):
domain = "ops"
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
domain = "cra"
# Generate control ID
domain_count = domain_counts.get(domain, 0) + 1
control_id = f"{domain.upper()}-{domain_count:03d}"
# Create title from first checkpoint
title = checkpoints[0].title
if len(title) > 100:
title = title[:97] + "..."
# Create description
description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
# Pass criteria
pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
# Implementation guidance
guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
# Determine if automated
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
control = Control(
id=control_id,
domain=domain,
title=title,
description=description,
checkpoints=[cp.id for cp in checkpoints],
pass_criteria=pass_criteria,
implementation_guidance=guidance,
is_automated=is_automated,
automation_tool="CI/CD Pipeline" if is_automated else None,
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
)
return control
def generate_measure_for_control(control: Control) -> Measure:
"""Generate a remediation measure for a control."""
measure_id = f"M-{control.id}"
# Determine deadline based on priority
deadline_days = {
"critical": 30,
"high": 60,
"medium": 90,
"low": 180
}.get(control.priority, 90)
# Determine responsible team
responsible = {
"priv": "Datenschutzbeauftragter",
"iam": "IT-Security Team",
"sdlc": "Entwicklungsteam",
"crypto": "IT-Security Team",
"ops": "Operations Team",
"aud": "Compliance Team",
"ai": "AI/ML Team",
"cra": "IT-Security Team",
"gov": "Management"
}.get(control.domain, "Compliance Team")
measure = Measure(
id=measure_id,
control_id=control.id,
title=f"Umsetzung: {control.title[:50]}",
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
responsible=responsible,
deadline_days=deadline_days,
status="pending"
)
return measure
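
As a quick illustration of the extraction flow end to end, here is a minimal, hypothetical driver — the chunk text and payload below are made up for the example and are not fixtures from the repo:

# Hypothetical smoke test for the split-out extraction module.
from compliance_extraction import (
    extract_checkpoints_from_chunk,
    generate_control_for_checkpoints,
    generate_measure_for_control,
)

# Made-up legal chunk; matches the 'numbered' and 'obligation' patterns.
chunk = (
    "(1) Der Verantwortliche muss geeignete technische und organisatorische "
    "Maßnahmen zur Verschlüsselung personenbezogener Daten treffen."
)
payload = {
    "regulation_code": "DSGVO",  # illustrative payload values
    "regulation_name": "Datenschutz-Grundverordnung",
    "source_url": "https://example.org/dsgvo",
}

checkpoints = extract_checkpoints_from_chunk(chunk, payload)
control = generate_control_for_checkpoints(checkpoints, domain_counts={})
if control:
    measure = generate_measure_for_control(control)
    # "verschlüssel" routes this to the crypto domain -> CRYPTO-001
    print(control.id, control.priority, measure.responsible, measure.deadline_days)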

View File

@@ -0,0 +1,49 @@
"""
Compliance Pipeline Data Models.
Dataclasses for checkpoints, controls, and measures.
"""
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class Checkpoint:
"""A requirement/checkpoint extracted from legal text."""
id: str
regulation_code: str
regulation_name: str
article: Optional[str]
title: str
description: str
original_text: str
chunk_id: str
source_url: str
@dataclass
class Control:
"""A control derived from checkpoints."""
id: str
domain: str
title: str
description: str
checkpoints: List[str] # List of checkpoint IDs
pass_criteria: str
implementation_guidance: str
is_automated: bool
automation_tool: Optional[str]
priority: str
@dataclass
class Measure:
"""A remediation measure for a control."""
id: str
control_id: str
title: str
description: str
responsible: str
deadline_days: int
status: str
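
Since these are plain dataclasses, they round-trip through dataclasses.asdict without extra glue — the same mechanism the pipeline's save_results relies on. A tiny sketch with invented values:

# Minimal serialization sketch; all field values are invented.
import json
from dataclasses import asdict
from compliance_models import Checkpoint

cp = Checkpoint(
    id="DSGVO-Art32-1",
    regulation_code="DSGVO",
    regulation_name="Datenschutz-Grundverordnung",
    article="Art. 32 Abs. 1",
    title="Art. 32 Abs. 1",
    description="Sicherheit der Verarbeitung",
    original_text="Sicherheit der Verarbeitung ...",
    chunk_id="ab12cd34",
    source_url="https://example.org/dsgvo",
)
print(json.dumps(asdict(cp), indent=2, ensure_ascii=False))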

View File

@@ -0,0 +1,441 @@
"""
Compliance Pipeline Execution.
Pipeline phases (ingestion, extraction, control generation, measures)
and orchestration logic.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import asdict
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
class CompliancePipeline:
"""Handles the full compliance pipeline."""
def __init__(self):
# Support both QDRANT_URL and QDRANT_HOST/PORT
qdrant_url = os.getenv("QDRANT_URL", "")
if qdrant_url:
from urllib.parse import urlparse
parsed = urlparse(qdrant_url)
qdrant_host = parsed.hostname or "qdrant"
qdrant_port = parsed.port or 6333
else:
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
qdrant_port = 6333
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
self.checkpoints: List[Checkpoint] = []
self.controls: List[Control] = []
self.measures: List[Measure] = []
self.stats = {
"chunks_processed": 0,
"checkpoints_extracted": 0,
"controls_created": 0,
"measures_defined": 0,
"by_regulation": {},
"by_domain": {},
}
# Initialize checkpoint manager
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
"""Phase 1: Ingest documents (incremental - only missing ones)."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
ingestion = LegalCorpusIngestion()
try:
# Check existing chunks per regulation
existing_chunks = {}
try:
for regulation in REGULATIONS:
count_result = self.qdrant.count(
collection_name=LEGAL_CORPUS_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
)
)
existing_chunks[regulation.code] = count_result.count
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
except Exception as e:
logger.warning(f"Could not check existing chunks: {e}")
# Determine which regulations need ingestion
regulations_to_ingest = []
for regulation in REGULATIONS:
existing = existing_chunks.get(regulation.code, 0)
if force_reindex or existing == 0:
regulations_to_ingest.append(regulation)
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
else:
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
self.stats["by_regulation"][regulation.code] = existing
if not regulations_to_ingest:
logger.info("All regulations already indexed. Skipping ingestion phase.")
total_chunks = sum(existing_chunks.values())
self.stats["chunks_processed"] = total_chunks
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("skipped", True)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
# Ingest only missing regulations
total_chunks = sum(existing_chunks.values())
for i, regulation in enumerate(regulations_to_ingest, 1):
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
try:
count = await ingestion.ingest_regulation(regulation)
total_chunks += count
self.stats["by_regulation"][regulation.code] = count
logger.info(f" -> {count} chunks")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
except Exception as e:
logger.error(f" -> FAILED: {e}")
self.stats["by_regulation"][regulation.code] = 0
self.stats["chunks_processed"] = total_chunks
logger.info(f"\nTotal chunks in collection: {total_chunks}")
# Validate ingestion results
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
expected = EXPECTED_VALUES.get("ingestion", {})
self.checkpoint_mgr.validate(
"total_chunks",
expected=expected.get("total_chunks", 8000),
actual=total_chunks,
min_value=expected.get("min_chunks", 7000)
)
reg_expected = expected.get("regulations", {})
for reg_code, reg_exp in reg_expected.items():
actual = self.stats["by_regulation"].get(reg_code, 0)
self.checkpoint_mgr.validate(
f"chunks_{reg_code}",
expected=reg_exp.get("expected", 0),
actual=actual,
min_value=reg_exp.get("min", 0)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
finally:
await ingestion.close()
async def run_extraction_phase(self) -> int:
"""Phase 2: Extract checkpoints from chunks."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
try:
offset = None
total_checkpoints = 0
while True:
result = self.qdrant.scroll(
collection_name=LEGAL_CORPUS_COLLECTION,
limit=100,
offset=offset,
with_payload=True,
with_vectors=False
)
points, next_offset = result
if not points:
break
for point in points:
payload = point.payload
text = payload.get("text", "")
cps = extract_checkpoints_from_chunk(text, payload)
self.checkpoints.extend(cps)
total_checkpoints += len(cps)
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
if next_offset is None:
break
offset = next_offset
self.stats["checkpoints_extracted"] = len(self.checkpoints)
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
by_reg = {}
for cp in self.checkpoints:
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
for reg, count in sorted(by_reg.items()):
logger.info(f" {reg}: {count} checkpoints")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
expected = EXPECTED_VALUES.get("extraction", {})
self.checkpoint_mgr.validate(
"total_checkpoints",
expected=expected.get("total_checkpoints", 3500),
actual=len(self.checkpoints),
min_value=expected.get("min_checkpoints", 3000)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.checkpoints)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_control_generation_phase(self) -> int:
"""Phase 3: Generate controls from checkpoints."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 3: CONTROL GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
try:
# Group checkpoints by regulation
by_regulation: Dict[str, List[Checkpoint]] = {}
for cp in self.checkpoints:
reg = cp.regulation_code
if reg not in by_regulation:
by_regulation[reg] = []
by_regulation[reg].append(cp)
# Generate controls per regulation (group every 3-5 checkpoints)
for regulation, checkpoints in by_regulation.items():
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
batch_size = 4
for i in range(0, len(checkpoints), batch_size):
batch = checkpoints[i:i + batch_size]
control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
if control:
self.controls.append(control)
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
self.stats["controls_created"] = len(self.controls)
logger.info(f"\nTotal controls created: {len(self.controls)}")
for domain, count in sorted(self.stats["by_domain"].items()):
logger.info(f" {domain}: {count} controls")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
expected = EXPECTED_VALUES.get("controls", {})
self.checkpoint_mgr.validate(
"total_controls",
expected=expected.get("total_controls", 900),
actual=len(self.controls),
min_value=expected.get("min_controls", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.controls)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_measure_generation_phase(self) -> int:
"""Phase 4: Generate measures for controls."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 4: MEASURE GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
try:
for control in self.controls:
measure = generate_measure_for_control(control)
self.measures.append(measure)
self.stats["measures_defined"] = len(self.measures)
logger.info(f"\nTotal measures defined: {len(self.measures)}")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
expected = EXPECTED_VALUES.get("measures", {})
self.checkpoint_mgr.validate(
"total_measures",
expected=expected.get("total_measures", 900),
actual=len(self.measures),
min_value=expected.get("min_measures", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.measures)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
def save_results(self, output_dir: str = "/tmp/compliance_output"):
"""Save results to JSON files."""
logger.info("\n" + "=" * 60)
logger.info("SAVING RESULTS")
logger.info("=" * 60)
os.makedirs(output_dir, exist_ok=True)
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
with open(checkpoints_file, "w") as f:
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
controls_file = os.path.join(output_dir, "controls.json")
with open(controls_file, "w") as f:
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
measures_file = os.path.join(output_dir, "measures.json")
with open(measures_file, "w") as f:
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
stats_file = os.path.join(output_dir, "statistics.json")
self.stats["generated_at"] = datetime.now().isoformat()
with open(stats_file, "w") as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
logger.info(f"Saved statistics to {stats_file}")
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
"""Run the complete pipeline.
Args:
force_reindex: If True, re-ingest all documents even if they exist
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
"""
start_time = time.time()
logger.info("=" * 60)
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
logger.info(f"Started at: {datetime.now().isoformat()}")
logger.info(f"Force reindex: {force_reindex}")
logger.info(f"Skip ingestion: {skip_ingestion}")
if self.checkpoint_mgr:
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
logger.info("=" * 60)
try:
if skip_ingestion:
logger.info("Skipping ingestion phase as requested...")
try:
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
self.stats["chunks_processed"] = collection_info.points_count
except Exception:
self.stats["chunks_processed"] = 0
else:
await self.run_ingestion_phase(force_reindex=force_reindex)
await self.run_extraction_phase()
await self.run_control_generation_phase()
await self.run_measure_generation_phase()
self.save_results()
elapsed = time.time() - start_time
logger.info("\n" + "=" * 60)
logger.info("PIPELINE COMPLETE")
logger.info("=" * 60)
logger.info(f"Duration: {elapsed:.1f} seconds")
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
logger.info(f"Controls created: {self.stats['controls_created']}")
logger.info(f"Measures defined: {self.stats['measures_defined']}")
logger.info(f"\nResults saved to: /tmp/compliance_output/")
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.complete_pipeline({
"duration_seconds": elapsed,
"chunks_processed": self.stats['chunks_processed'],
"checkpoints_extracted": self.stats['checkpoints_extracted'],
"controls_created": self.stats['controls_created'],
"measures_defined": self.stats['measures_defined'],
"by_regulation": self.stats['by_regulation'],
"by_domain": self.stats['by_domain'],
})
except Exception as e:
logger.error(f"Pipeline failed: {e}")
if self.checkpoint_mgr:
self.checkpoint_mgr.state.status = "failed"
self.checkpoint_mgr._save()
raise
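
Assuming a conventional entry point (none is shown in this hunk), running the pipeline from inside the klausur-service container might look like the sketch below — the module name full_compliance_pipeline is taken from the commit message and may differ after the split:

# Hypothetical driver; module/entry-point names are assumptions.
import asyncio
import logging

from full_compliance_pipeline import CompliancePipeline

logging.basicConfig(level=logging.INFO)

pipeline = CompliancePipeline()
# Incremental run: keep existing Qdrant chunks, re-run extraction onwards.
asyncio.run(pipeline.run_full_pipeline(skip_ingestion=True))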

View File

@@ -1,7 +1,10 @@
"""
DSFA RAG API Endpoints.
DSFA RAG API Endpoints — Barrel Re-export.
Provides REST API for searching DSFA corpus with full source attribution.
Split into submodules:
- dsfa_rag_models.py — Pydantic request/response models
- dsfa_rag_embedding.py — Embedding service integration & text extraction
- dsfa_rag_routes.py — Route handlers (search, sources, ingest, stats)
Endpoints:
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
@@ -11,705 +14,54 @@ Endpoints:
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
"""
import os
import uuid
import logging
from typing import List, Optional
from dataclasses import dataclass, asdict
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# Embedding service configuration
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
# Import from ingestion module
from dsfa_corpus_ingestion import (
DSFACorpusStore,
DSFAQdrantService,
DSFASearchResult,
LICENSE_REGISTRY,
DSFA_SOURCES,
generate_attribution_notice,
get_license_label,
DSFA_COLLECTION,
chunk_document
# Models
from dsfa_rag_models import (
DSFASourceResponse,
DSFAChunkResponse,
DSFASearchResultResponse,
DSFASearchResponse,
DSFASourceStatsResponse,
DSFACorpusStatsResponse,
IngestRequest,
IngestResponse,
LicenseInfo,
)
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
# =============================================================================
# Pydantic Models
# =============================================================================
class DSFASourceResponse(BaseModel):
"""Response model for DSFA source."""
id: str
source_code: str
name: str
full_name: Optional[str] = None
organization: Optional[str] = None
source_url: Optional[str] = None
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
attribution_text: str
document_type: Optional[str] = None
language: str = "de"
class DSFAChunkResponse(BaseModel):
"""Response model for a single chunk with attribution."""
chunk_id: str
content: str
section_title: Optional[str] = None
page_number: Optional[int] = None
category: Optional[str] = None
# Document info
document_id: str
document_title: Optional[str] = None
# Attribution (always included)
source_id: str
source_code: str
source_name: str
attribution_text: str
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
source_url: Optional[str] = None
document_type: Optional[str] = None
class DSFASearchResultResponse(BaseModel):
"""Response model for search result."""
chunk_id: str
content: str
score: float
# Attribution
source_code: str
source_name: str
attribution_text: str
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
source_url: Optional[str] = None
# Metadata
document_type: Optional[str] = None
category: Optional[str] = None
section_title: Optional[str] = None
page_number: Optional[int] = None
class DSFASearchResponse(BaseModel):
"""Response model for search endpoint."""
query: str
results: List[DSFASearchResultResponse]
total_results: int
# Aggregated licenses for footer
licenses_used: List[str]
attribution_notice: str
class DSFASourceStatsResponse(BaseModel):
"""Response model for source statistics."""
source_id: str
source_code: str
name: str
organization: Optional[str] = None
license_code: str
document_type: Optional[str] = None
document_count: int
chunk_count: int
last_indexed_at: Optional[str] = None
class DSFACorpusStatsResponse(BaseModel):
"""Response model for corpus statistics."""
sources: List[DSFASourceStatsResponse]
total_sources: int
total_documents: int
total_chunks: int
qdrant_collection: str
qdrant_points_count: int
qdrant_status: str
class IngestRequest(BaseModel):
"""Request model for ingestion."""
document_url: Optional[str] = None
document_text: Optional[str] = None
title: Optional[str] = None
class IngestResponse(BaseModel):
"""Response model for ingestion."""
source_code: str
document_id: Optional[str] = None
chunks_created: int
message: str
class LicenseInfo(BaseModel):
"""License information."""
code: str
name: str
url: Optional[str] = None
attribution_required: bool
modification_allowed: bool
commercial_use: bool
# =============================================================================
# Dependency Injection
# =============================================================================
# Database pool (will be set from main.py)
_db_pool = None
def set_db_pool(pool):
"""Set the database pool for API endpoints."""
global _db_pool
_db_pool = pool
async def get_store() -> DSFACorpusStore:
"""Get DSFA corpus store."""
if _db_pool is None:
raise HTTPException(status_code=503, detail="Database not initialized")
return DSFACorpusStore(_db_pool)
async def get_qdrant() -> DSFAQdrantService:
"""Get Qdrant service."""
return DSFAQdrantService()
# =============================================================================
# Embedding Service Integration
# =============================================================================
async def get_embedding(text: str) -> List[float]:
"""
Get embedding for text using the embedding-service.
Uses BGE-M3 model which produces 1024-dimensional vectors.
"""
async with httpx.AsyncClient(timeout=60.0) as client:
try:
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed-single",
json={"text": text}
)
response.raise_for_status()
data = response.json()
return data.get("embedding", [])
except httpx.HTTPError as e:
logger.error(f"Embedding service error: {e}")
# Fallback to hash-based pseudo-embedding for development
return _generate_fallback_embedding(text)
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
"""
Get embeddings for multiple texts in batch.
"""
async with httpx.AsyncClient(timeout=120.0) as client:
try:
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed",
json={"texts": texts}
)
response.raise_for_status()
data = response.json()
return data.get("embeddings", [])
except httpx.HTTPError as e:
logger.error(f"Embedding service batch error: {e}")
# Fallback
return [_generate_fallback_embedding(t) for t in texts]
async def extract_text_from_url(url: str) -> str:
"""
Extract text from a document URL (PDF, HTML, etc.).
"""
async with httpx.AsyncClient(timeout=120.0) as client:
try:
# First try to use the embedding-service's extract-pdf endpoint
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
json={"url": url}
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except httpx.HTTPError as e:
logger.error(f"PDF extraction error for {url}: {e}")
# Fallback: try to fetch HTML content directly
try:
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "html" in content_type:
# Simple HTML text extraction
import re
html = response.text
# Remove scripts and styles
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
# Remove tags
text = re.sub(r'<[^>]+>', ' ', html)
# Clean whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
else:
return ""
except Exception as fetch_err:
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
return ""
def _generate_fallback_embedding(text: str) -> List[float]:
"""
Generate deterministic pseudo-embedding from text hash.
Used as fallback when embedding service is unavailable.
"""
import hashlib
import struct
hash_bytes = hashlib.sha256(text.encode()).digest()
embedding = []
for i in range(0, min(len(hash_bytes), 128), 4):
val = struct.unpack('f', hash_bytes[i:i+4])[0]
embedding.append(val % 1.0)
# Pad to 1024 dimensions
while len(embedding) < 1024:
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
return embedding[:1024]
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
query: str = Query(..., min_length=3, description="Search query"),
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
include_attribution: bool = Query(True, description="Include attribution in results"),
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Search DSFA corpus with full attribution.
Returns matching chunks with source/license information for compliance.
"""
# Get query embedding
query_embedding = await get_embedding(query)
# Search Qdrant
raw_results = await qdrant.search(
query_embedding=query_embedding,
source_codes=source_codes,
document_types=document_types,
categories=categories,
limit=limit
)
# Transform results
results = []
licenses_used = set()
for r in raw_results:
license_code = r.get("license_code", "")
license_info = LICENSE_REGISTRY.get(license_code, {})
result = DSFASearchResultResponse(
chunk_id=r.get("chunk_id", ""),
content=r.get("content", ""),
score=r.get("score", 0.0),
source_code=r.get("source_code", ""),
source_name=r.get("source_name", ""),
attribution_text=r.get("attribution_text", ""),
license_code=license_code,
license_name=license_info.get("name", license_code),
license_url=license_info.get("url"),
attribution_required=r.get("attribution_required", True),
source_url=r.get("source_url"),
document_type=r.get("document_type"),
category=r.get("category"),
section_title=r.get("section_title"),
page_number=r.get("page_number")
)
results.append(result)
licenses_used.add(license_code)
# Generate attribution notice
search_results = [
DSFASearchResult(
chunk_id=r.chunk_id,
content=r.content,
score=r.score,
source_code=r.source_code,
source_name=r.source_name,
attribution_text=r.attribution_text,
license_code=r.license_code,
license_url=r.license_url,
attribution_required=r.attribution_required,
source_url=r.source_url,
document_type=r.document_type or "",
category=r.category or "",
section_title=r.section_title,
page_number=r.page_number
)
for r in results
]
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
return DSFASearchResponse(
query=query,
results=results,
total_results=len(results),
licenses_used=list(licenses_used),
attribution_notice=attribution_notice
)
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
document_type: Optional[str] = Query(None, description="Filter by document type"),
license_code: Optional[str] = Query(None, description="Filter by license"),
store: DSFACorpusStore = Depends(get_store)
):
"""List all registered DSFA sources with license info."""
sources = await store.list_sources()
result = []
for s in sources:
# Apply filters
if document_type and s.get("document_type") != document_type:
continue
if license_code and s.get("license_code") != license_code:
continue
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
result.append(DSFASourceResponse(
id=str(s["id"]),
source_code=s["source_code"],
name=s["name"],
full_name=s.get("full_name"),
organization=s.get("organization"),
source_url=s.get("source_url"),
license_code=s.get("license_code", ""),
license_name=license_info.get("name", s.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=s.get("attribution_required", True),
attribution_text=s.get("attribution_text", ""),
document_type=s.get("document_type"),
language=s.get("language", "de")
))
return result
@router.get("/sources/available")
async def list_available_sources():
"""List all available source definitions (from DSFA_SOURCES constant)."""
return [
{
"source_code": s["source_code"],
"name": s["name"],
"organization": s.get("organization"),
"license_code": s["license_code"],
"document_type": s.get("document_type")
}
for s in DSFA_SOURCES
]
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
source_code: str,
store: DSFACorpusStore = Depends(get_store)
):
"""Get details for a specific source."""
source = await store.get_source_by_code(source_code)
if not source:
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
return DSFASourceResponse(
id=str(source["id"]),
source_code=source["source_code"],
name=source["name"],
full_name=source.get("full_name"),
organization=source.get("organization"),
source_url=source.get("source_url"),
license_code=source.get("license_code", ""),
license_name=license_info.get("name", source.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=source.get("attribution_required", True),
attribution_text=source.get("attribution_text", ""),
document_type=source.get("document_type"),
language=source.get("language", "de")
)
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
source_code: str,
request: IngestRequest,
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Trigger ingestion for a specific source.
Can provide document via URL or direct text.
"""
# Get source
source = await store.get_source_by_code(source_code)
if not source:
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
# Need either URL or text
if not request.document_text and not request.document_url:
raise HTTPException(
status_code=400,
detail="Either document_text or document_url must be provided"
)
# Ensure Qdrant collection exists
await qdrant.ensure_collection()
# Get text content
text_content = request.document_text
if request.document_url and not text_content:
# Download and extract text from URL
logger.info(f"Extracting text from URL: {request.document_url}")
text_content = await extract_text_from_url(request.document_url)
if not text_content:
raise HTTPException(
status_code=400,
detail=f"Could not extract text from URL: {request.document_url}"
)
if not text_content or len(text_content.strip()) < 50:
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
# Create document record
doc_title = request.title or f"Document for {source_code}"
document_id = await store.create_document(
source_id=str(source["id"]),
title=doc_title,
file_type="text",
metadata={"ingested_via": "api", "source_code": source_code}
)
# Chunk the document
chunks = chunk_document(text_content, source_code)
if not chunks:
return IngestResponse(
source_code=source_code,
document_id=document_id,
chunks_created=0,
message="Document created but no chunks generated"
)
# Generate embeddings in batch for efficiency
chunk_texts = [chunk["content"] for chunk in chunks]
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
embeddings = await get_embeddings_batch(chunk_texts)
# Create chunk records in PostgreSQL and prepare for Qdrant
chunk_records = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
# Create chunk in PostgreSQL
chunk_id = await store.create_chunk(
document_id=document_id,
source_id=str(source["id"]),
content=chunk["content"],
chunk_index=i,
section_title=chunk.get("section_title"),
page_number=chunk.get("page_number"),
category=chunk.get("category")
)
chunk_records.append({
"chunk_id": chunk_id,
"document_id": document_id,
"source_id": str(source["id"]),
"content": chunk["content"],
"section_title": chunk.get("section_title"),
"source_code": source_code,
"source_name": source["name"],
"attribution_text": source["attribution_text"],
"license_code": source["license_code"],
"attribution_required": source.get("attribution_required", True),
"document_type": source.get("document_type", ""),
"category": chunk.get("category", ""),
"language": source.get("language", "de"),
"page_number": chunk.get("page_number")
})
# Index in Qdrant
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
# Update document record
await store.update_document_indexed(document_id, len(chunks))
return IngestResponse(
source_code=source_code,
document_id=document_id,
chunks_created=indexed_count,
message=f"Successfully ingested {indexed_count} chunks from document"
)
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
chunk_id: str,
store: DSFACorpusStore = Depends(get_store)
):
"""Get single chunk with full source attribution."""
chunk = await store.get_chunk_with_attribution(chunk_id)
if not chunk:
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
return DSFAChunkResponse(
chunk_id=str(chunk["chunk_id"]),
content=chunk.get("content", ""),
section_title=chunk.get("section_title"),
page_number=chunk.get("page_number"),
category=chunk.get("category"),
document_id=str(chunk.get("document_id", "")),
document_title=chunk.get("document_title"),
source_id=str(chunk.get("source_id", "")),
source_code=chunk.get("source_code", ""),
source_name=chunk.get("source_name", ""),
attribution_text=chunk.get("attribution_text", ""),
license_code=chunk.get("license_code", ""),
license_name=license_info.get("name", chunk.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=chunk.get("attribution_required", True),
source_url=chunk.get("source_url"),
document_type=chunk.get("document_type")
)
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""Get corpus statistics for dashboard."""
# Get PostgreSQL stats
source_stats = await store.get_source_stats()
total_docs = 0
total_chunks = 0
stats_response = []
for s in source_stats:
doc_count = s.get("document_count", 0) or 0
chunk_count = s.get("chunk_count", 0) or 0
total_docs += doc_count
total_chunks += chunk_count
last_indexed = s.get("last_indexed_at")
stats_response.append(DSFASourceStatsResponse(
source_id=str(s.get("source_id", "")),
source_code=s.get("source_code", ""),
name=s.get("name", ""),
organization=s.get("organization"),
license_code=s.get("license_code", ""),
document_type=s.get("document_type"),
document_count=doc_count,
chunk_count=chunk_count,
last_indexed_at=last_indexed.isoformat() if last_indexed else None
))
# Get Qdrant stats
qdrant_stats = await qdrant.get_stats()
return DSFACorpusStatsResponse(
sources=stats_response,
total_sources=len(source_stats),
total_documents=total_docs,
total_chunks=total_chunks,
qdrant_collection=DSFA_COLLECTION,
qdrant_points_count=qdrant_stats.get("points_count", 0),
qdrant_status=qdrant_stats.get("status", "unknown")
)
@router.get("/licenses")
async def list_licenses():
"""List all supported licenses with their terms."""
return [
LicenseInfo(
code=code,
name=info.get("name", code),
url=info.get("url"),
attribution_required=info.get("attribution_required", True),
modification_allowed=info.get("modification_allowed", True),
commercial_use=info.get("commercial_use", True)
)
for code, info in LICENSE_REGISTRY.items()
]
@router.post("/init")
async def initialize_dsfa_corpus(
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Initialize DSFA corpus.
- Creates Qdrant collection
- Registers all predefined sources
"""
# Ensure Qdrant collection exists
qdrant_ok = await qdrant.ensure_collection()
# Register all sources
registered = 0
for source in DSFA_SOURCES:
try:
await store.register_source(source)
registered += 1
except Exception as e:
print(f"Error registering source {source['source_code']}: {e}")
return {
"qdrant_collection_created": qdrant_ok,
"sources_registered": registered,
"total_sources": len(DSFA_SOURCES)
}
# Embedding utilities
from dsfa_rag_embedding import (
get_embedding,
get_embeddings_batch,
extract_text_from_url,
EMBEDDING_SERVICE_URL,
)
# Routes (router + set_db_pool)
from dsfa_rag_routes import (
router,
set_db_pool,
get_store,
get_qdrant,
)
__all__ = [
# Router
"router",
"set_db_pool",
"get_store",
"get_qdrant",
# Models
"DSFASourceResponse",
"DSFAChunkResponse",
"DSFASearchResultResponse",
"DSFASearchResponse",
"DSFASourceStatsResponse",
"DSFACorpusStatsResponse",
"IngestRequest",
"IngestResponse",
"LicenseInfo",
# Embedding
"get_embedding",
"get_embeddings_batch",
"extract_text_from_url",
"EMBEDDING_SERVICE_URL",
]
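
Because the barrel re-exports the full old surface, callers that did `from dsfa_rag_api import router, set_db_pool` keep working unchanged. A sketch of typical wiring — the app setup and DSN are illustrative, not from this commit:

# Illustrative FastAPI wiring through the barrel; DSN is a placeholder.
import asyncpg
from fastapi import FastAPI

import dsfa_rag_api  # barrel: router, set_db_pool, models, embedding helpers

app = FastAPI()
app.include_router(dsfa_rag_api.router)

@app.on_event("startup")
async def startup() -> None:
    pool = await asyncpg.create_pool(dsn="postgresql://user:pass@db/dsfa")
    dsfa_rag_api.set_db_pool(pool)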

View File

@@ -0,0 +1,116 @@
"""
DSFA RAG Embedding Service Integration.
Handles embedding generation, text extraction, and fallback logic.
"""
import os
import hashlib
import logging
import struct
import re
from typing import List
import httpx
logger = logging.getLogger(__name__)
# Embedding service configuration
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
async def get_embedding(text: str) -> List[float]:
"""
Get embedding for text using the embedding-service.
Uses BGE-M3 model which produces 1024-dimensional vectors.
"""
async with httpx.AsyncClient(timeout=60.0) as client:
try:
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed-single",
json={"text": text}
)
response.raise_for_status()
data = response.json()
return data.get("embedding", [])
except httpx.HTTPError as e:
logger.error(f"Embedding service error: {e}")
# Fallback to hash-based pseudo-embedding for development
return _generate_fallback_embedding(text)
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
"""
Get embeddings for multiple texts in batch.
"""
async with httpx.AsyncClient(timeout=120.0) as client:
try:
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed",
json={"texts": texts}
)
response.raise_for_status()
data = response.json()
return data.get("embeddings", [])
except httpx.HTTPError as e:
logger.error(f"Embedding service batch error: {e}")
# Fallback
return [_generate_fallback_embedding(t) for t in texts]
async def extract_text_from_url(url: str) -> str:
"""
Extract text from a document URL (PDF, HTML, etc.).
"""
async with httpx.AsyncClient(timeout=120.0) as client:
try:
# First try to use the embedding-service's extract-pdf endpoint
response = await client.post(
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
json={"url": url}
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except httpx.HTTPError as e:
logger.error(f"PDF extraction error for {url}: {e}")
# Fallback: try to fetch HTML content directly
try:
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "html" in content_type:
# Simple HTML text extraction
html = response.text
# Remove scripts and styles
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
# Remove tags
text = re.sub(r'<[^>]+>', ' ', html)
# Clean whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
else:
return ""
except Exception as fetch_err:
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
return ""
def _generate_fallback_embedding(text: str) -> List[float]:
"""
Generate deterministic pseudo-embedding from text hash.
Used as fallback when embedding service is unavailable.
"""
hash_bytes = hashlib.sha256(text.encode()).digest()
embedding = []
for i in range(0, min(len(hash_bytes), 128), 4):
val = struct.unpack('f', hash_bytes[i:i+4])[0]
embedding.append(val % 1.0)
# Pad to 1024 dimensions
while len(embedding) < 1024:
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
return embedding[:1024]
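
To exercise the module in isolation: if the embedding-service at EMBEDDING_SERVICE_URL is unreachable, the calls silently fall back to the hash-based pseudo-embedding, so the shapes below hold either way (query strings are invented):

# Standalone check of the embedding helpers (service URL from the env).
import asyncio

from dsfa_rag_embedding import get_embedding, get_embeddings_batch

async def main() -> None:
    vec = await get_embedding("Schwellwertanalyse nach Art. 35 DSGVO")
    vecs = await get_embeddings_batch(["Risikobewertung", "Abhilfemaßnahmen"])
    print(len(vec), len(vecs))  # expect 1024 and 2

asyncio.run(main())

One caveat with the fallback path: struct.unpack('f', ...) over raw hash bytes can yield NaN or inf, and val % 1.0 propagates NaN, so the pseudo-embeddings are deterministic placeholders rather than vectors with meaningful geometry.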

View File

@@ -0,0 +1,137 @@
"""
DSFA RAG Pydantic Models.
Request/Response models for the DSFA RAG API.
"""
from typing import List, Optional
from pydantic import BaseModel, Field
# =============================================================================
# Response Models
# =============================================================================
class DSFASourceResponse(BaseModel):
"""Response model for DSFA source."""
id: str
source_code: str
name: str
full_name: Optional[str] = None
organization: Optional[str] = None
source_url: Optional[str] = None
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
attribution_text: str
document_type: Optional[str] = None
language: str = "de"
class DSFAChunkResponse(BaseModel):
"""Response model for a single chunk with attribution."""
chunk_id: str
content: str
section_title: Optional[str] = None
page_number: Optional[int] = None
category: Optional[str] = None
# Document info
document_id: str
document_title: Optional[str] = None
# Attribution (always included)
source_id: str
source_code: str
source_name: str
attribution_text: str
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
source_url: Optional[str] = None
document_type: Optional[str] = None
class DSFASearchResultResponse(BaseModel):
"""Response model for search result."""
chunk_id: str
content: str
score: float
# Attribution
source_code: str
source_name: str
attribution_text: str
license_code: str
license_name: str
license_url: Optional[str] = None
attribution_required: bool
source_url: Optional[str] = None
# Metadata
document_type: Optional[str] = None
category: Optional[str] = None
section_title: Optional[str] = None
page_number: Optional[int] = None
class DSFASearchResponse(BaseModel):
"""Response model for search endpoint."""
query: str
results: List[DSFASearchResultResponse]
total_results: int
# Aggregated licenses for footer
licenses_used: List[str]
attribution_notice: str
class DSFASourceStatsResponse(BaseModel):
"""Response model for source statistics."""
source_id: str
source_code: str
name: str
organization: Optional[str] = None
license_code: str
document_type: Optional[str] = None
document_count: int
chunk_count: int
last_indexed_at: Optional[str] = None
class DSFACorpusStatsResponse(BaseModel):
"""Response model for corpus statistics."""
sources: List[DSFASourceStatsResponse]
total_sources: int
total_documents: int
total_chunks: int
qdrant_collection: str
qdrant_points_count: int
qdrant_status: str
class IngestRequest(BaseModel):
"""Request model for ingestion."""
document_url: Optional[str] = None
document_text: Optional[str] = None
title: Optional[str] = None
class IngestResponse(BaseModel):
"""Response model for ingestion."""
source_code: str
document_id: Optional[str] = None
chunks_created: int
message: str
class LicenseInfo(BaseModel):
"""License information."""
code: str
name: str
url: Optional[str] = None
attribution_required: bool
modification_allowed: bool
commercial_use: bool
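
These are plain Pydantic models, so request payloads validate directly from keyword arguments or dicts; a small sketch with invented values (the serialization method name depends on the installed Pydantic major version):

# Invented values; .json() is Pydantic v1 style (model_dump_json() on v2).
from dsfa_rag_models import IngestRequest, IngestResponse

req = IngestRequest(document_url="https://example.org/dsfa-leitfaden.pdf")
assert req.document_text is None  # optional fields default to None

resp = IngestResponse(
    source_code="EXAMPLE",
    document_id="123e4567",
    chunks_created=42,
    message="Successfully ingested 42 chunks from document",
)
print(resp.json())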

View File

@@ -0,0 +1,461 @@
"""
DSFA RAG API Route Handlers.
Endpoint implementations for search, sources, ingestion, stats, and init.
"""
import logging
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Query, Depends
from dsfa_corpus_ingestion import (
DSFACorpusStore,
DSFAQdrantService,
DSFASearchResult,
LICENSE_REGISTRY,
DSFA_SOURCES,
generate_attribution_notice,
get_license_label,
DSFA_COLLECTION,
chunk_document,
)
from dsfa_rag_models import (
DSFASourceResponse,
DSFAChunkResponse,
DSFASearchResultResponse,
DSFASearchResponse,
DSFASourceStatsResponse,
DSFACorpusStatsResponse,
IngestRequest,
IngestResponse,
LicenseInfo,
)
from dsfa_rag_embedding import (
get_embedding,
get_embeddings_batch,
extract_text_from_url,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
# =============================================================================
# Dependency Injection
# =============================================================================
_db_pool = None
def set_db_pool(pool):
"""Set the database pool for API endpoints."""
global _db_pool
_db_pool = pool
async def get_store() -> DSFACorpusStore:
"""Get DSFA corpus store."""
if _db_pool is None:
raise HTTPException(status_code=503, detail="Database not initialized")
return DSFACorpusStore(_db_pool)
async def get_qdrant() -> DSFAQdrantService:
"""Get Qdrant service."""
return DSFAQdrantService()
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
query: str = Query(..., min_length=3, description="Search query"),
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
include_attribution: bool = Query(True, description="Include attribution in results"),
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Search DSFA corpus with full attribution.
Returns matching chunks with source/license information for compliance.
"""
query_embedding = await get_embedding(query)
raw_results = await qdrant.search(
query_embedding=query_embedding,
source_codes=source_codes,
document_types=document_types,
categories=categories,
limit=limit
)
results = []
licenses_used = set()
for r in raw_results:
license_code = r.get("license_code", "")
license_info = LICENSE_REGISTRY.get(license_code, {})
result = DSFASearchResultResponse(
chunk_id=r.get("chunk_id", ""),
content=r.get("content", ""),
score=r.get("score", 0.0),
source_code=r.get("source_code", ""),
source_name=r.get("source_name", ""),
attribution_text=r.get("attribution_text", ""),
license_code=license_code,
license_name=license_info.get("name", license_code),
license_url=license_info.get("url"),
attribution_required=r.get("attribution_required", True),
source_url=r.get("source_url"),
document_type=r.get("document_type"),
category=r.get("category"),
section_title=r.get("section_title"),
page_number=r.get("page_number")
)
results.append(result)
licenses_used.add(license_code)
# Generate attribution notice
search_results = [
DSFASearchResult(
chunk_id=r.chunk_id,
content=r.content,
score=r.score,
source_code=r.source_code,
source_name=r.source_name,
attribution_text=r.attribution_text,
license_code=r.license_code,
license_url=r.license_url,
attribution_required=r.attribution_required,
source_url=r.source_url,
document_type=r.document_type or "",
category=r.category or "",
section_title=r.section_title,
page_number=r.page_number
)
for r in results
]
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
return DSFASearchResponse(
query=query,
results=results,
total_results=len(results),
licenses_used=list(licenses_used),
attribution_notice=attribution_notice
)
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
document_type: Optional[str] = Query(None, description="Filter by document type"),
license_code: Optional[str] = Query(None, description="Filter by license"),
store: DSFACorpusStore = Depends(get_store)
):
"""List all registered DSFA sources with license info."""
sources = await store.list_sources()
result = []
for s in sources:
if document_type and s.get("document_type") != document_type:
continue
if license_code and s.get("license_code") != license_code:
continue
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
result.append(DSFASourceResponse(
id=str(s["id"]),
source_code=s["source_code"],
name=s["name"],
full_name=s.get("full_name"),
organization=s.get("organization"),
source_url=s.get("source_url"),
license_code=s.get("license_code", ""),
license_name=license_info.get("name", s.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=s.get("attribution_required", True),
attribution_text=s.get("attribution_text", ""),
document_type=s.get("document_type"),
language=s.get("language", "de")
))
return result
@router.get("/sources/available")
async def list_available_sources():
"""List all available source definitions (from DSFA_SOURCES constant)."""
return [
{
"source_code": s["source_code"],
"name": s["name"],
"organization": s.get("organization"),
"license_code": s["license_code"],
"document_type": s.get("document_type")
}
for s in DSFA_SOURCES
]
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
source_code: str,
store: DSFACorpusStore = Depends(get_store)
):
"""Get details for a specific source."""
source = await store.get_source_by_code(source_code)
if not source:
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
return DSFASourceResponse(
id=str(source["id"]),
source_code=source["source_code"],
name=source["name"],
full_name=source.get("full_name"),
organization=source.get("organization"),
source_url=source.get("source_url"),
license_code=source.get("license_code", ""),
license_name=license_info.get("name", source.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=source.get("attribution_required", True),
attribution_text=source.get("attribution_text", ""),
document_type=source.get("document_type"),
language=source.get("language", "de")
)
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
source_code: str,
request: IngestRequest,
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Trigger ingestion for a specific source.
Can provide document via URL or direct text.
"""
source = await store.get_source_by_code(source_code)
if not source:
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
if not request.document_text and not request.document_url:
raise HTTPException(
status_code=400,
detail="Either document_text or document_url must be provided"
)
await qdrant.ensure_collection()
text_content = request.document_text
if request.document_url and not text_content:
logger.info(f"Extracting text from URL: {request.document_url}")
text_content = await extract_text_from_url(request.document_url)
if not text_content:
raise HTTPException(
status_code=400,
detail=f"Could not extract text from URL: {request.document_url}"
)
if not text_content or len(text_content.strip()) < 50:
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
doc_title = request.title or f"Document for {source_code}"
document_id = await store.create_document(
source_id=str(source["id"]),
title=doc_title,
file_type="text",
metadata={"ingested_via": "api", "source_code": source_code}
)
chunks = chunk_document(text_content, source_code)
if not chunks:
return IngestResponse(
source_code=source_code,
document_id=document_id,
chunks_created=0,
message="Document created but no chunks generated"
)
chunk_texts = [chunk["content"] for chunk in chunks]
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
embeddings = await get_embeddings_batch(chunk_texts)
chunk_records = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
chunk_id = await store.create_chunk(
document_id=document_id,
source_id=str(source["id"]),
content=chunk["content"],
chunk_index=i,
section_title=chunk.get("section_title"),
page_number=chunk.get("page_number"),
category=chunk.get("category")
)
chunk_records.append({
"chunk_id": chunk_id,
"document_id": document_id,
"source_id": str(source["id"]),
"content": chunk["content"],
"section_title": chunk.get("section_title"),
"source_code": source_code,
"source_name": source["name"],
"attribution_text": source["attribution_text"],
"license_code": source["license_code"],
"attribution_required": source.get("attribution_required", True),
"document_type": source.get("document_type", ""),
"category": chunk.get("category", ""),
"language": source.get("language", "de"),
"page_number": chunk.get("page_number")
})
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
await store.update_document_indexed(document_id, len(chunks))
return IngestResponse(
source_code=source_code,
document_id=document_id,
chunks_created=indexed_count,
message=f"Successfully ingested {indexed_count} chunks from document"
)
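# Client-side usage sketch (hypothetical host and router prefix; the route path
# and the request fields document_text/document_url/title come from the handler
# above, while the source_code is illustrative):
#
#   import asyncio, httpx
#
#   async def ingest_sample():
#       async with httpx.AsyncClient(base_url="http://localhost:8000/dsfa") as client:
#           resp = await client.post(
#               "/sources/dsk_kurzpapiere/ingest",
#               json={
#                   "document_text": "Beispieltext " * 10,  # handler requires >= 50 chars
#                   "title": "DSK Kurzpapier (Beispiel)",
#               },
#               timeout=120.0,
#           )
#           resp.raise_for_status()
#           print(resp.json()["chunks_created"])
#
#   asyncio.run(ingest_sample())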
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
chunk_id: str,
store: DSFACorpusStore = Depends(get_store)
):
"""Get single chunk with full source attribution."""
chunk = await store.get_chunk_with_attribution(chunk_id)
if not chunk:
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
return DSFAChunkResponse(
chunk_id=str(chunk["chunk_id"]),
content=chunk.get("content", ""),
section_title=chunk.get("section_title"),
page_number=chunk.get("page_number"),
category=chunk.get("category"),
document_id=str(chunk.get("document_id", "")),
document_title=chunk.get("document_title"),
source_id=str(chunk.get("source_id", "")),
source_code=chunk.get("source_code", ""),
source_name=chunk.get("source_name", ""),
attribution_text=chunk.get("attribution_text", ""),
license_code=chunk.get("license_code", ""),
license_name=license_info.get("name", chunk.get("license_code", "")),
license_url=license_info.get("url"),
attribution_required=chunk.get("attribution_required", True),
source_url=chunk.get("source_url"),
document_type=chunk.get("document_type")
)
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""Get corpus statistics for dashboard."""
source_stats = await store.get_source_stats()
total_docs = 0
total_chunks = 0
stats_response = []
for s in source_stats:
doc_count = s.get("document_count", 0) or 0
chunk_count = s.get("chunk_count", 0) or 0
total_docs += doc_count
total_chunks += chunk_count
last_indexed = s.get("last_indexed_at")
stats_response.append(DSFASourceStatsResponse(
source_id=str(s.get("source_id", "")),
source_code=s.get("source_code", ""),
name=s.get("name", ""),
organization=s.get("organization"),
license_code=s.get("license_code", ""),
document_type=s.get("document_type"),
document_count=doc_count,
chunk_count=chunk_count,
last_indexed_at=last_indexed.isoformat() if last_indexed else None
))
qdrant_stats = await qdrant.get_stats()
return DSFACorpusStatsResponse(
sources=stats_response,
total_sources=len(source_stats),
total_documents=total_docs,
total_chunks=total_chunks,
qdrant_collection=DSFA_COLLECTION,
qdrant_points_count=qdrant_stats.get("points_count", 0),
qdrant_status=qdrant_stats.get("status", "unknown")
)
@router.get("/licenses")
async def list_licenses():
"""List all supported licenses with their terms."""
return [
LicenseInfo(
code=code,
name=info.get("name", code),
url=info.get("url"),
attribution_required=info.get("attribution_required", True),
modification_allowed=info.get("modification_allowed", True),
commercial_use=info.get("commercial_use", True)
)
for code, info in LICENSE_REGISTRY.items()
]
@router.post("/init")
async def initialize_dsfa_corpus(
store: DSFACorpusStore = Depends(get_store),
qdrant: DSFAQdrantService = Depends(get_qdrant)
):
"""
Initialize DSFA corpus.
- Creates Qdrant collection
- Registers all predefined sources
"""
qdrant_ok = await qdrant.ensure_collection()
registered = 0
for source in DSFA_SOURCES:
try:
await store.register_source(source)
registered += 1
except Exception as e:
print(f"Error registering source {source['source_code']}: {e}")
return {
"qdrant_collection_created": qdrant_ok,
"sources_registered": registered,
"total_sources": len(DSFA_SOURCES)
}

View File

@@ -1,31 +1,19 @@
#!/usr/bin/env python3
"""
Full Compliance Pipeline for Legal Corpus.
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
This script runs the complete pipeline:
1. Re-ingest all legal documents with improved chunking
2. Extract requirements/checkpoints from chunks
3. Generate controls using AI
4. Define remediation measures
5. Update statistics
Split into submodules:
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
- compliance_extraction.py — Pattern extraction & control/measure generation
- compliance_pipeline.py — Pipeline phases & orchestrator
Run on Mac Mini:
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
Checkpoints are saved to /tmp/pipeline_checkpoints.json and can be viewed in admin-v2.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
import re
import hashlib
# Configure logging
logging.basicConfig(
@@ -36,671 +24,25 @@ logging.basicConfig(
logging.FileHandler('/tmp/compliance_pipeline.log')
]
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
# Support both QDRANT_URL and QDRANT_HOST
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
@dataclass
class Checkpoint:
"""A requirement/checkpoint extracted from legal text."""
id: str
regulation_code: str
regulation_name: str
article: Optional[str]
title: str
description: str
original_text: str
chunk_id: str
source_url: str
@dataclass
class Control:
"""A control derived from checkpoints."""
id: str
domain: str
title: str
description: str
checkpoints: List[str] # List of checkpoint IDs
pass_criteria: str
implementation_guidance: str
is_automated: bool
automation_tool: Optional[str]
priority: str
@dataclass
class Measure:
"""A remediation measure for a control."""
id: str
control_id: str
title: str
description: str
responsible: str
deadline_days: int
status: str
class CompliancePipeline:
"""Handles the full compliance pipeline."""
def __init__(self):
# Support both QDRANT_URL and QDRANT_HOST/PORT
qdrant_url = os.getenv("QDRANT_URL", "")
if qdrant_url:
from urllib.parse import urlparse
parsed = urlparse(qdrant_url)
qdrant_host = parsed.hostname or "qdrant"
qdrant_port = parsed.port or 6333
else:
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
qdrant_port = 6333
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
self.checkpoints: List[Checkpoint] = []
self.controls: List[Control] = []
self.measures: List[Measure] = []
self.stats = {
"chunks_processed": 0,
"checkpoints_extracted": 0,
"controls_created": 0,
"measures_defined": 0,
"by_regulation": {},
"by_domain": {},
}
# Initialize checkpoint manager
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
def extract_checkpoints_from_chunk(self, chunk_text: str, payload: Dict) -> List[Checkpoint]:
"""
Extract checkpoints/requirements from a chunk of text.
Uses pattern matching to find requirement-like statements.
"""
checkpoints = []
regulation_code = payload.get("regulation_code", "UNKNOWN")
regulation_name = payload.get("regulation_name", "Unknown")
source_url = payload.get("source_url", "")
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
# Patterns for different requirement types
patterns = [
# BSI-TR patterns
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
# Article patterns (GDPR, AI Act, etc.)
        (r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'),
# Numbered requirements
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
# "Der Verantwortliche muss" patterns
(r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
# "Es ist erforderlich" patterns
(r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
]
for pattern, pattern_type in patterns:
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
for match in matches:
if pattern_type == 'bsi_requirement':
req_id = match.group(1)
description = match.group(2).strip()
title = req_id
elif pattern_type == 'article':
article_num = match.group(1)
paragraph = match.group(2) or ""
title_text = match.group(3).strip()
req_id = f"{regulation_code}-Art{article_num}"
if paragraph:
req_id += f"-{paragraph}"
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
description = title_text
elif pattern_type == 'numbered':
num = match.group(1)
description = match.group(2).strip()
req_id = f"{regulation_code}-{num}"
title = f"Anforderung {num}"
else:
# Generic requirement
description = match.group(0).strip()
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
title = description[:50] + "..." if len(description) > 50 else description
# Skip very short matches
if len(description) < 20:
continue
checkpoint = Checkpoint(
id=req_id,
regulation_code=regulation_code,
regulation_name=regulation_name,
article=title if 'Art' in title else None,
title=title,
description=description[:500],
original_text=description,
chunk_id=chunk_id,
source_url=source_url
)
checkpoints.append(checkpoint)
return checkpoints
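    # Worked example (illustrative chunk text, not taken from the corpus): with
    # payload = {"regulation_code": "DSGVO", ...}, the sentence
    #   "Der Verantwortliche muss geeignete technische Maßnahmen treffen."
    # triggers the 'obligation' pattern; the generic else-branch then builds
    # roughly Checkpoint(id="DSGVO-<chunk_hash>-0", title="<first 50 chars>...",
    # description=<matched sentence>), while matches under 20 characters are
    # dropped by the length guard above.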
def generate_control_for_checkpoints(self, checkpoints: List[Checkpoint]) -> Optional[Control]:
"""
Generate a control that covers the given checkpoints.
This is a simplified version - in production this would use the AI assistant.
"""
if not checkpoints:
return None
# Group by regulation
regulation = checkpoints[0].regulation_code
# Determine domain based on content
all_text = " ".join([cp.description for cp in checkpoints]).lower()
domain = "gov" # Default
if any(kw in all_text for kw in ["verschlüssel", "krypto", "encrypt", "hash"]):
domain = "crypto"
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
domain = "iam"
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
domain = "priv"
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
domain = "sdlc"
elif any(kw in all_text for kw in ["überwach", "monitor", "log", "audit"]):
domain = "aud"
elif any(kw in all_text for kw in ["ki", "künstlich", "ai", "machine learning", "model"]):
domain = "ai"
elif any(kw in all_text for kw in ["betrieb", "operation", "verfügbar", "backup"]):
domain = "ops"
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
domain = "cra"
# Generate control ID
domain_counts = self.stats.get("by_domain", {})
domain_count = domain_counts.get(domain, 0) + 1
control_id = f"{domain.upper()}-{domain_count:03d}"
# Create title from first checkpoint
title = checkpoints[0].title
if len(title) > 100:
title = title[:97] + "..."
# Create description
description = f"Control für {regulation}: " + checkpoints[0].description[:200]
# Pass criteria
pass_criteria = f"Alle {len(checkpoints)} zugehörigen Anforderungen sind erfüllt und dokumentiert."
# Implementation guidance
guidance = f"Implementiere Maßnahmen zur Erfüllung der Anforderungen aus {regulation}. "
guidance += f"Dokumentiere die Umsetzung und führe regelmäßige Reviews durch."
# Determine if automated
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
control = Control(
id=control_id,
domain=domain,
title=title,
description=description,
checkpoints=[cp.id for cp in checkpoints],
pass_criteria=pass_criteria,
implementation_guidance=guidance,
is_automated=is_automated,
automation_tool="CI/CD Pipeline" if is_automated else None,
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
)
return control
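    # Example of the keyword routing above (illustrative batch): if the combined
    # checkpoint text mentions "Verschlüsselung", the control lands in domain
    # "crypto"; the first crypto control is numbered "CRYPTO-001", and priority
    # becomes "high" whenever the text also contains "muss" or "erforderlich".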
def generate_measure_for_control(self, control: Control) -> Measure:
"""Generate a remediation measure for a control."""
measure_id = f"M-{control.id}"
# Determine deadline based on priority
deadline_days = {
"critical": 30,
"high": 60,
"medium": 90,
"low": 180
}.get(control.priority, 90)
# Determine responsible team
responsible = {
"priv": "Datenschutzbeauftragter",
"iam": "IT-Security Team",
"sdlc": "Entwicklungsteam",
"crypto": "IT-Security Team",
"ops": "Operations Team",
"aud": "Compliance Team",
"ai": "AI/ML Team",
"cra": "IT-Security Team",
"gov": "Management"
}.get(control.domain, "Compliance Team")
measure = Measure(
id=measure_id,
control_id=control.id,
title=f"Umsetzung: {control.title[:50]}",
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
responsible=responsible,
deadline_days=deadline_days,
status="pending"
)
return measure
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
"""Phase 1: Ingest documents (incremental - only missing ones)."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
ingestion = LegalCorpusIngestion()
try:
# Check existing chunks per regulation
existing_chunks = {}
try:
for regulation in REGULATIONS:
count_result = self.qdrant.count(
collection_name=LEGAL_CORPUS_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
)
)
existing_chunks[regulation.code] = count_result.count
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
except Exception as e:
logger.warning(f"Could not check existing chunks: {e}")
# Collection might not exist, that's OK
# Determine which regulations need ingestion
regulations_to_ingest = []
for regulation in REGULATIONS:
existing = existing_chunks.get(regulation.code, 0)
if force_reindex or existing == 0:
regulations_to_ingest.append(regulation)
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
else:
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
self.stats["by_regulation"][regulation.code] = existing
if not regulations_to_ingest:
logger.info("All regulations already indexed. Skipping ingestion phase.")
total_chunks = sum(existing_chunks.values())
self.stats["chunks_processed"] = total_chunks
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("skipped", True)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
# Ingest only missing regulations
total_chunks = sum(existing_chunks.values())
for i, regulation in enumerate(regulations_to_ingest, 1):
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
try:
count = await ingestion.ingest_regulation(regulation)
total_chunks += count
self.stats["by_regulation"][regulation.code] = count
logger.info(f" -> {count} chunks")
# Add metric for this regulation
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
except Exception as e:
logger.error(f" -> FAILED: {e}")
self.stats["by_regulation"][regulation.code] = 0
self.stats["chunks_processed"] = total_chunks
logger.info(f"\nTotal chunks in collection: {total_chunks}")
# Validate ingestion results
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
# Validate total chunks
expected = EXPECTED_VALUES.get("ingestion", {})
self.checkpoint_mgr.validate(
"total_chunks",
expected=expected.get("total_chunks", 8000),
actual=total_chunks,
min_value=expected.get("min_chunks", 7000)
)
# Validate key regulations
reg_expected = expected.get("regulations", {})
for reg_code, reg_exp in reg_expected.items():
actual = self.stats["by_regulation"].get(reg_code, 0)
self.checkpoint_mgr.validate(
f"chunks_{reg_code}",
expected=reg_exp.get("expected", 0),
actual=actual,
min_value=reg_exp.get("min", 0)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
finally:
await ingestion.close()
async def run_extraction_phase(self) -> int:
"""Phase 2: Extract checkpoints from chunks."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
try:
# Scroll through all chunks
offset = None
total_checkpoints = 0
while True:
result = self.qdrant.scroll(
collection_name=LEGAL_CORPUS_COLLECTION,
limit=100,
offset=offset,
with_payload=True,
with_vectors=False
)
points, next_offset = result
if not points:
break
for point in points:
payload = point.payload
text = payload.get("text", "")
checkpoints = self.extract_checkpoints_from_chunk(text, payload)
self.checkpoints.extend(checkpoints)
total_checkpoints += len(checkpoints)
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
if next_offset is None:
break
offset = next_offset
self.stats["checkpoints_extracted"] = len(self.checkpoints)
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
# Log per regulation
by_reg = {}
for cp in self.checkpoints:
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
for reg, count in sorted(by_reg.items()):
logger.info(f" {reg}: {count} checkpoints")
# Validate extraction results
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
expected = EXPECTED_VALUES.get("extraction", {})
self.checkpoint_mgr.validate(
"total_checkpoints",
expected=expected.get("total_checkpoints", 3500),
actual=len(self.checkpoints),
min_value=expected.get("min_checkpoints", 3000)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.checkpoints)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_control_generation_phase(self) -> int:
"""Phase 3: Generate controls from checkpoints."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 3: CONTROL GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
try:
# Group checkpoints by regulation
by_regulation: Dict[str, List[Checkpoint]] = {}
for cp in self.checkpoints:
reg = cp.regulation_code
if reg not in by_regulation:
by_regulation[reg] = []
by_regulation[reg].append(cp)
# Generate controls per regulation (group every 3-5 checkpoints)
for regulation, checkpoints in by_regulation.items():
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
# Group checkpoints into batches of 3-5
batch_size = 4
for i in range(0, len(checkpoints), batch_size):
batch = checkpoints[i:i + batch_size]
control = self.generate_control_for_checkpoints(batch)
if control:
self.controls.append(control)
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
self.stats["controls_created"] = len(self.controls)
logger.info(f"\nTotal controls created: {len(self.controls)}")
# Log per domain
for domain, count in sorted(self.stats["by_domain"].items()):
logger.info(f" {domain}: {count} controls")
# Validate control generation
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
expected = EXPECTED_VALUES.get("controls", {})
self.checkpoint_mgr.validate(
"total_controls",
expected=expected.get("total_controls", 900),
actual=len(self.controls),
min_value=expected.get("min_controls", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.controls)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_measure_generation_phase(self) -> int:
"""Phase 4: Generate measures for controls."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 4: MEASURE GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
try:
for control in self.controls:
measure = self.generate_measure_for_control(control)
self.measures.append(measure)
self.stats["measures_defined"] = len(self.measures)
logger.info(f"\nTotal measures defined: {len(self.measures)}")
# Validate measure generation
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
expected = EXPECTED_VALUES.get("measures", {})
self.checkpoint_mgr.validate(
"total_measures",
expected=expected.get("total_measures", 900),
actual=len(self.measures),
min_value=expected.get("min_measures", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.measures)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
def save_results(self, output_dir: str = "/tmp/compliance_output"):
"""Save results to JSON files."""
logger.info("\n" + "=" * 60)
logger.info("SAVING RESULTS")
logger.info("=" * 60)
os.makedirs(output_dir, exist_ok=True)
# Save checkpoints
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
with open(checkpoints_file, "w") as f:
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
# Save controls
controls_file = os.path.join(output_dir, "controls.json")
with open(controls_file, "w") as f:
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
# Save measures
measures_file = os.path.join(output_dir, "measures.json")
with open(measures_file, "w") as f:
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
# Save statistics
stats_file = os.path.join(output_dir, "statistics.json")
self.stats["generated_at"] = datetime.now().isoformat()
with open(stats_file, "w") as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
logger.info(f"Saved statistics to {stats_file}")
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
"""Run the complete pipeline.
Args:
force_reindex: If True, re-ingest all documents even if they exist
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
"""
start_time = time.time()
logger.info("=" * 60)
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
logger.info(f"Started at: {datetime.now().isoformat()}")
logger.info(f"Force reindex: {force_reindex}")
logger.info(f"Skip ingestion: {skip_ingestion}")
if self.checkpoint_mgr:
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
logger.info("=" * 60)
try:
# Phase 1: Ingestion (skip if requested or run incrementally)
if skip_ingestion:
logger.info("Skipping ingestion phase as requested...")
# Still get the chunk count
try:
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
self.stats["chunks_processed"] = collection_info.points_count
except Exception:
self.stats["chunks_processed"] = 0
else:
await self.run_ingestion_phase(force_reindex=force_reindex)
# Phase 2: Extraction
await self.run_extraction_phase()
# Phase 3: Control Generation
await self.run_control_generation_phase()
# Phase 4: Measure Generation
await self.run_measure_generation_phase()
# Save results
self.save_results()
# Final summary
elapsed = time.time() - start_time
logger.info("\n" + "=" * 60)
logger.info("PIPELINE COMPLETE")
logger.info("=" * 60)
logger.info(f"Duration: {elapsed:.1f} seconds")
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
logger.info(f"Controls created: {self.stats['controls_created']}")
logger.info(f"Measures defined: {self.stats['measures_defined']}")
logger.info(f"\nResults saved to: /tmp/compliance_output/")
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
logger.info("=" * 60)
# Complete pipeline checkpoint
if self.checkpoint_mgr:
self.checkpoint_mgr.complete_pipeline({
"duration_seconds": elapsed,
"chunks_processed": self.stats['chunks_processed'],
"checkpoints_extracted": self.stats['checkpoints_extracted'],
"controls_created": self.stats['controls_created'],
"measures_defined": self.stats['measures_defined'],
"by_regulation": self.stats['by_regulation'],
"by_domain": self.stats['by_domain'],
})
except Exception as e:
logger.error(f"Pipeline failed: {e}")
if self.checkpoint_mgr:
self.checkpoint_mgr.state.status = "failed"
self.checkpoint_mgr._save()
raise
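# Minimal driver sketch (mirrors the nohup invocation in the module docstring;
# skip_ingestion=True reuses the chunks already indexed in Qdrant):
#
#   asyncio.run(CompliancePipeline().run_full_pipeline(skip_ingestion=True))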
# Re-export all public symbols
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
from compliance_pipeline import CompliancePipeline
__all__ = [
"Checkpoint",
"Control",
"Measure",
"extract_checkpoints_from_chunk",
"generate_control_for_checkpoints",
"generate_measure_for_control",
"CompliancePipeline",
]
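# With the barrel in place, legacy imports resolve to the same objects as the
# new submodules (a quick sanity sketch, not part of the pipeline):
#
#   from full_compliance_pipeline import CompliancePipeline as Old
#   from compliance_pipeline import CompliancePipeline as New
#   assert Old is New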
async def main():

View File

@@ -1,767 +1,35 @@
"""
GitHub Repository Crawler for Legal Templates.
GitHub Repository Crawler — Barrel Re-export
Crawls GitHub and GitLab repositories to extract legal template documents
(Markdown, HTML, JSON, etc.) for ingestion into the RAG system.
Split into:
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
Features:
- Clone repositories via Git or download as ZIP
- Parse Markdown, HTML, JSON, and plain text files
- Extract structured content with metadata
- Track git commit hashes for reproducibility
- Handle rate limiting and errors gracefully
All public names are re-exported here for backward compatibility.
"""
import asyncio
import hashlib
import json
import logging
import os
import re
import shutil
import tempfile
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import LicenseType, SourceConfig, LICENSES
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
GITHUB_API_URL = "https://api.github.com"
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") # Optional for higher rate limits
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0 # Delay between requests to avoid rate limiting
@dataclass
class ExtractedDocument:
"""A document extracted from a repository."""
text: str
title: str
file_path: str
file_type: str # "markdown", "html", "json", "text"
source_url: str
source_commit: Optional[str] = None
source_hash: str = "" # SHA256 of original content
sections: List[Dict[str, Any]] = field(default_factory=list)
placeholders: List[str] = field(default_factory=list)
language: str = "en"
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not self.source_hash and self.text:
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
class MarkdownParser:
"""Parse Markdown files into structured content."""
# Common placeholder patterns
PLACEHOLDER_PATTERNS = [
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
r'\{([a-z_]+)\}', # {company_name}
r'\{\{([a-z_]+)\}\}', # {{company_name}}
r'__([A-Z_]+)__', # __COMPANY_NAME__
r'<([A-Z_]+)>', # <COMPANY_NAME>
]
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse markdown content into an ExtractedDocument."""
# Extract title from first heading or filename
title = cls._extract_title(content, filename)
# Extract sections
sections = cls._extract_sections(content)
# Find placeholders
placeholders = cls._find_placeholders(content)
# Detect language
language = cls._detect_language(content)
# Clean content for indexing
clean_text = cls._clean_for_indexing(content)
return ExtractedDocument(
text=clean_text,
title=title,
file_path=filename,
file_type="markdown",
source_url="", # Will be set by caller
sections=sections,
placeholders=placeholders,
language=language,
)
@classmethod
def _extract_title(cls, content: str, filename: str) -> str:
"""Extract title from markdown heading or filename."""
# Look for first h1 heading
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if h1_match:
return h1_match.group(1).strip()
# Look for YAML frontmatter title
frontmatter_match = re.search(
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
content, re.DOTALL
)
if frontmatter_match:
return frontmatter_match.group(1).strip()
# Fall back to filename
if filename:
name = Path(filename).stem
# Convert kebab-case or snake_case to title case
return name.replace('-', ' ').replace('_', ' ').title()
return "Untitled"
@classmethod
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
"""Extract sections from markdown content."""
sections = []
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
            # Save previous section, capturing its body from the text between headings
            section_body = content[current_section["start"]:match.start()].strip()
            if current_section["heading"] or section_body:
                current_section["content"] = section_body
                sections.append(current_section.copy())
# Start new section
level = len(match.group(1))
heading = match.group(2).strip()
current_section = {
"heading": heading,
"level": level,
"content": "",
"start": match.end(),
}
# Add final section
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = content[current_section["start"]:].strip()
sections.append(current_section)
return sections
@classmethod
def _find_placeholders(cls, content: str) -> List[str]:
"""Find placeholder patterns in content."""
placeholders = set()
for pattern in cls.PLACEHOLDER_PATTERNS:
for match in re.finditer(pattern, content):
placeholder = match.group(0)
placeholders.add(placeholder)
return sorted(list(placeholders))
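    # e.g. _find_placeholders("Contact [COMPANY_NAME] at {email}") returns
    # ["[COMPANY_NAME]", "{email}"]: group(0) keeps the delimiters, and the
    # set/sorted pass deduplicates repeated placeholders.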
@classmethod
def _detect_language(cls, content: str) -> str:
"""Detect language from content."""
# Look for German-specific words
german_indicators = [
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
]
lower_content = content.lower()
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
if german_count >= 3:
return "de"
return "en"
@classmethod
def _clean_for_indexing(cls, content: str) -> str:
"""Clean markdown content for text indexing."""
# Remove YAML frontmatter
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
# Remove HTML comments
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
# Remove inline HTML tags but keep content
content = re.sub(r'<[^>]+>', '', content)
# Convert markdown formatting
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) # Bold
content = re.sub(r'\*(.+?)\*', r'\1', content) # Italic
content = re.sub(r'`(.+?)`', r'\1', content) # Inline code
content = re.sub(r'~~(.+?)~~', r'\1', content) # Strikethrough
# Remove link syntax but keep text
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
# Remove image syntax
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
# Clean up whitespace
content = re.sub(r'\n{3,}', '\n\n', content)
content = re.sub(r' +', ' ', content)
return content.strip()
class HTMLParser:
"""Parse HTML files into structured content."""
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse HTML content into an ExtractedDocument."""
# Extract title
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
title = title_match.group(1) if title_match else Path(filename).stem
# Convert to text
text = cls._html_to_text(content)
# Find placeholders
placeholders = MarkdownParser._find_placeholders(text)
# Detect language
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
return ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="html",
source_url="",
placeholders=placeholders,
language=language,
)
@classmethod
def _html_to_text(cls, html: str) -> str:
"""Convert HTML to clean text."""
# Remove script and style tags
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
# Remove comments
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
# Replace common entities
html = html.replace('&nbsp;', ' ')
html = html.replace('&amp;', '&')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&apos;', "'")
# Add line breaks for block elements
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
# Remove remaining tags
html = re.sub(r'<[^>]+>', '', html)
# Clean whitespace
html = re.sub(r'[ \t]+', ' ', html)
html = re.sub(r'\n[ \t]+', '\n', html)
html = re.sub(r'[ \t]+\n', '\n', html)
html = re.sub(r'\n{3,}', '\n\n', html)
return html.strip()
class JSONParser:
"""Parse JSON files containing legal template data."""
@classmethod
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
"""Parse JSON content into ExtractedDocuments."""
try:
data = json.loads(content)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON from {filename}: {e}")
return []
documents = []
if isinstance(data, dict):
# Handle different JSON structures
documents.extend(cls._parse_dict(data, filename))
elif isinstance(data, list):
for i, item in enumerate(data):
if isinstance(item, dict):
docs = cls._parse_dict(item, f"{filename}[{i}]")
documents.extend(docs)
return documents
@classmethod
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
"""Parse a dictionary into documents."""
documents = []
# Look for text content in common keys
text_keys = ['text', 'content', 'body', 'description', 'value']
title_keys = ['title', 'name', 'heading', 'label', 'key']
# Try to find main text content
text = ""
for key in text_keys:
if key in data and isinstance(data[key], str):
text = data[key]
break
if not text:
# Check for nested structures (like webflorist format)
for key, value in data.items():
if isinstance(value, dict):
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
documents.extend(nested_docs)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
documents.extend(nested_docs)
elif isinstance(item, str) and len(item) > 50:
# Treat long strings as content
documents.append(ExtractedDocument(
text=item,
title=f"{key} {i+1}",
file_path=filename,
file_type="json",
source_url="",
language=MarkdownParser._detect_language(item),
))
return documents
# Found text content
title = ""
for key in title_keys:
if key in data and isinstance(data[key], str):
title = data[key]
break
if not title:
title = Path(filename).stem
# Extract metadata
metadata = {}
for key, value in data.items():
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
metadata[key] = value
placeholders = MarkdownParser._find_placeholders(text)
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
documents.append(ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="json",
source_url="",
placeholders=placeholders,
language=language,
metadata=metadata,
))
return documents
class GitHubCrawler:
"""Crawl GitHub repositories for legal templates."""
def __init__(self, token: Optional[str] = None):
self.token = token or GITHUB_TOKEN
self.headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "LegalTemplatesCrawler/1.0",
}
if self.token:
self.headers["Authorization"] = f"token {self.token}"
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
headers=self.headers,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
"""Parse repository URL into owner, repo, and host."""
parsed = urlparse(url)
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL: {url}")
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
if 'gitlab' in parsed.netloc:
host = 'gitlab'
else:
host = 'github'
return owner, repo, host
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch of a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("default_branch", "main")
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
"""Get the latest commit SHA for a branch."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("sha", "")
async def list_files(
self,
owner: str,
repo: str,
path: str = "",
branch: str = "main",
patterns: List[str] = None,
exclude_patterns: List[str] = None,
) -> List[Dict[str, Any]]:
"""List files in a repository matching the given patterns."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
patterns = patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = exclude_patterns or []
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
files = []
for item in data.get("tree", []):
if item["type"] != "blob":
continue
file_path = item["path"]
# Check exclude patterns
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
if excluded:
continue
# Check include patterns
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
if not matched:
continue
# Skip large files
if item.get("size", 0) > MAX_FILE_SIZE:
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
continue
files.append({
"path": file_path,
"sha": item["sha"],
"size": item.get("size", 0),
"url": item.get("url", ""),
})
return files
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
"""Get the content of a file from a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
# Use raw content URL for simplicity
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
response = await self.http_client.get(url)
response.raise_for_status()
return response.text
async def crawl_repository(
self,
source: SourceConfig,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a repository and yield extracted documents."""
if not source.repo_url:
logger.warning(f"No repo URL for source: {source.name}")
return
try:
owner, repo, host = self._parse_repo_url(source.repo_url)
except ValueError as e:
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
return
if host == "gitlab":
logger.info(f"GitLab repos not yet supported: {source.name}")
return
logger.info(f"Crawling repository: {owner}/{repo}")
try:
# Get default branch and latest commit
branch = await self.get_default_branch(owner, repo)
commit_sha = await self.get_latest_commit(owner, repo, branch)
await asyncio.sleep(RATE_LIMIT_DELAY)
# List files matching patterns
files = await self.list_files(
owner, repo,
branch=branch,
patterns=source.file_patterns,
exclude_patterns=source.exclude_patterns,
)
logger.info(f"Found {len(files)} matching files in {source.name}")
for file_info in files:
await asyncio.sleep(RATE_LIMIT_DELAY)
try:
content = await self.get_file_content(
owner, repo, file_info["path"], branch
)
# Parse based on file type
file_path = file_info["path"]
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
if file_path.endswith('.md'):
doc = MarkdownParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.html') or file_path.endswith('.htm'):
doc = HTMLParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.json'):
docs = JSONParser.parse(content, file_path)
for doc in docs:
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.txt'):
# Plain text file
yield ExtractedDocument(
text=content,
title=Path(file_path).stem,
file_path=file_path,
file_type="text",
source_url=source_url,
source_commit=commit_sha,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
except httpx.HTTPError as e:
logger.warning(f"Failed to fetch {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
except httpx.HTTPError as e:
logger.error(f"HTTP error crawling {source.name}: {e}")
except Exception as e:
logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
"""Download and extract repository archives."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=120.0,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
"""Download repository as ZIP and extract to temp directory."""
if not self.http_client:
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
parsed = urlparse(repo_url)
path_parts = parsed.path.strip('/').split('/')
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
logger.info(f"Downloading ZIP from {zip_url}")
response = await self.http_client.get(zip_url)
response.raise_for_status()
# Save to temp file
temp_dir = Path(tempfile.mkdtemp())
zip_path = temp_dir / f"{repo}.zip"
with open(zip_path, 'wb') as f:
f.write(response.content)
# Extract ZIP
extract_dir = temp_dir / repo
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
# The extracted directory is usually named repo-branch
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
if extracted_dirs:
return extracted_dirs[0]
return extract_dir
async def crawl_local_directory(
self,
directory: Path,
source: SourceConfig,
base_url: str,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a local directory for documents."""
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = source.exclude_patterns or []
for pattern in patterns:
for file_path in directory.rglob(pattern.replace("**/", "")):
if not file_path.is_file():
continue
rel_path = str(file_path.relative_to(directory))
# Check exclude patterns
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
if excluded:
continue
# Skip large files
if file_path.stat().st_size > MAX_FILE_SIZE:
continue
try:
content = file_path.read_text(encoding='utf-8')
except UnicodeDecodeError:
try:
content = file_path.read_text(encoding='latin-1')
except Exception:
continue
source_url = f"{base_url}/{rel_path}"
if file_path.suffix == '.md':
doc = MarkdownParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix in ['.html', '.htm']:
doc = HTMLParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix == '.json':
docs = JSONParser.parse(content, rel_path)
for doc in docs:
doc.source_url = source_url
yield doc
elif file_path.suffix == '.txt':
yield ExtractedDocument(
text=content,
title=file_path.stem,
file_path=rel_path,
file_type="text",
source_url=source_url,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
def cleanup(self, directory: Path):
"""Clean up temporary directory."""
if directory.exists():
shutil.rmtree(directory, ignore_errors=True)
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
"""Crawl a source configuration and return all extracted documents."""
documents = []
if source.repo_url:
async with GitHubCrawler() as crawler:
async for doc in crawler.crawl_repository(source):
documents.append(doc)
return documents
# CLI for testing
async def main():
"""Test crawler with a sample source."""
from template_sources import TEMPLATE_SOURCES
# Test with github-site-policy
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
async with GitHubCrawler() as crawler:
count = 0
async for doc in crawler.crawl_repository(source):
count += 1
print(f"\n{'='*60}")
print(f"Title: {doc.title}")
print(f"Path: {doc.file_path}")
print(f"Type: {doc.file_type}")
print(f"Language: {doc.language}")
print(f"URL: {doc.source_url}")
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
print(f"Text preview: {doc.text[:200]}...")
print(f"\n\nTotal documents: {count}")
# Parsers
from github_crawler_parsers import ( # noqa: F401
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
# Crawler and downloader
from github_crawler_core import ( # noqa: F401
GITHUB_API_URL,
GITLAB_API_URL,
GITHUB_TOKEN,
MAX_FILE_SIZE,
REQUEST_TIMEOUT,
RATE_LIMIT_DELAY,
GitHubCrawler,
RepositoryDownloader,
crawl_source,
main,
)
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -0,0 +1,411 @@
"""
GitHub Crawler - Core Crawler and Downloader
GitHubCrawler for API-based repository crawling and RepositoryDownloader
for ZIP-based local extraction.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import asyncio
import logging
import os
import shutil
import tempfile
import zipfile
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from template_sources import SourceConfig
from github_crawler_parsers import (
ExtractedDocument,
MarkdownParser,
HTMLParser,
JSONParser,
)
logger = logging.getLogger(__name__)
# Configuration
GITHUB_API_URL = "https://api.github.com"
GITLAB_API_URL = "https://gitlab.com/api/v4"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
REQUEST_TIMEOUT = 60.0
RATE_LIMIT_DELAY = 1.0
class GitHubCrawler:
"""Crawl GitHub repositories for legal templates."""
def __init__(self, token: Optional[str] = None):
self.token = token or GITHUB_TOKEN
self.headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "LegalTemplatesCrawler/1.0",
}
if self.token:
self.headers["Authorization"] = f"token {self.token}"
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=REQUEST_TIMEOUT,
headers=self.headers,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
"""Parse repository URL into owner, repo, and host."""
parsed = urlparse(url)
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL: {url}")
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
if 'gitlab' in parsed.netloc:
host = 'gitlab'
else:
host = 'github'
return owner, repo, host
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch of a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("default_branch", "main")
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
"""Get the latest commit SHA for a branch."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
return data.get("sha", "")
async def list_files(
self,
owner: str,
repo: str,
path: str = "",
branch: str = "main",
        patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""List files in a repository matching the given patterns."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
patterns = patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = exclude_patterns or []
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
response = await self.http_client.get(url)
response.raise_for_status()
data = response.json()
files = []
for item in data.get("tree", []):
if item["type"] != "blob":
continue
file_path = item["path"]
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
if excluded:
continue
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
if not matched:
continue
if item.get("size", 0) > MAX_FILE_SIZE:
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
continue
files.append({
"path": file_path,
"sha": item["sha"],
"size": item.get("size", 0),
"url": item.get("url", ""),
})
return files
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
"""Get the content of a file from a repository."""
if not self.http_client:
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
response = await self.http_client.get(url)
response.raise_for_status()
return response.text
async def crawl_repository(
self,
source: SourceConfig,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a repository and yield extracted documents."""
if not source.repo_url:
logger.warning(f"No repo URL for source: {source.name}")
return
try:
owner, repo, host = self._parse_repo_url(source.repo_url)
except ValueError as e:
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
return
if host == "gitlab":
logger.info(f"GitLab repos not yet supported: {source.name}")
return
logger.info(f"Crawling repository: {owner}/{repo}")
try:
branch = await self.get_default_branch(owner, repo)
commit_sha = await self.get_latest_commit(owner, repo, branch)
await asyncio.sleep(RATE_LIMIT_DELAY)
files = await self.list_files(
owner, repo,
branch=branch,
patterns=source.file_patterns,
exclude_patterns=source.exclude_patterns,
)
logger.info(f"Found {len(files)} matching files in {source.name}")
for file_info in files:
await asyncio.sleep(RATE_LIMIT_DELAY)
try:
content = await self.get_file_content(
owner, repo, file_info["path"], branch
)
file_path = file_info["path"]
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
if file_path.endswith('.md'):
doc = MarkdownParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.html') or file_path.endswith('.htm'):
doc = HTMLParser.parse(content, file_path)
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.json'):
docs = JSONParser.parse(content, file_path)
for doc in docs:
doc.source_url = source_url
doc.source_commit = commit_sha
yield doc
elif file_path.endswith('.txt'):
yield ExtractedDocument(
text=content,
title=Path(file_path).stem,
file_path=file_path,
file_type="text",
source_url=source_url,
source_commit=commit_sha,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
except httpx.HTTPError as e:
logger.warning(f"Failed to fetch {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
continue
except httpx.HTTPError as e:
logger.error(f"HTTP error crawling {source.name}: {e}")
except Exception as e:
logger.error(f"Error crawling {source.name}: {e}")
class RepositoryDownloader:
"""Download and extract repository archives."""
def __init__(self):
self.http_client: Optional[httpx.AsyncClient] = None
async def __aenter__(self):
self.http_client = httpx.AsyncClient(
timeout=120.0,
follow_redirects=True,
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.http_client:
await self.http_client.aclose()
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
"""Download repository as ZIP and extract to temp directory."""
if not self.http_client:
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
parsed = urlparse(repo_url)
path_parts = parsed.path.strip('/').split('/')
owner = path_parts[0]
repo = path_parts[1].replace('.git', '')
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
logger.info(f"Downloading ZIP from {zip_url}")
response = await self.http_client.get(zip_url)
response.raise_for_status()
temp_dir = Path(tempfile.mkdtemp())
zip_path = temp_dir / f"{repo}.zip"
with open(zip_path, 'wb') as f:
f.write(response.content)
extract_dir = temp_dir / repo
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(temp_dir)
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
if extracted_dirs:
return extracted_dirs[0]
return extract_dir
async def crawl_local_directory(
self,
directory: Path,
source: SourceConfig,
base_url: str,
) -> AsyncGenerator[ExtractedDocument, None]:
"""Crawl a local directory for documents."""
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
exclude_patterns = source.exclude_patterns or []
for pattern in patterns:
for file_path in directory.rglob(pattern.replace("**/", "")):
if not file_path.is_file():
continue
rel_path = str(file_path.relative_to(directory))
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
if excluded:
continue
if file_path.stat().st_size > MAX_FILE_SIZE:
continue
try:
content = file_path.read_text(encoding='utf-8')
except UnicodeDecodeError:
try:
content = file_path.read_text(encoding='latin-1')
except Exception:
continue
source_url = f"{base_url}/{rel_path}"
if file_path.suffix == '.md':
doc = MarkdownParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix in ['.html', '.htm']:
doc = HTMLParser.parse(content, rel_path)
doc.source_url = source_url
yield doc
elif file_path.suffix == '.json':
docs = JSONParser.parse(content, rel_path)
for doc in docs:
doc.source_url = source_url
yield doc
elif file_path.suffix == '.txt':
yield ExtractedDocument(
text=content,
title=file_path.stem,
file_path=rel_path,
file_type="text",
source_url=source_url,
language=MarkdownParser._detect_language(content),
placeholders=MarkdownParser._find_placeholders(content),
)
def cleanup(self, directory: Path):
"""Clean up temporary directory."""
if directory.exists():
shutil.rmtree(directory, ignore_errors=True)
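# Hedged usage sketch (illustrative, not part of the original module): the
# download-then-crawl flow for an archive-based source. `example_source` is an
# assumption; only the RepositoryDownloader methods above are real.
async def _example_zip_crawl(example_source: SourceConfig) -> None:
    async with RepositoryDownloader() as downloader:
        repo_dir = await downloader.download_zip(example_source.repo_url, branch="main")
        try:
            async for doc in downloader.crawl_local_directory(
                repo_dir, example_source, base_url=example_source.repo_url
            ):
                print(doc.title, doc.file_path)
        finally:
            downloader.cleanup(repo_dir)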
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
"""Crawl a source configuration and return all extracted documents."""
documents = []
if source.repo_url:
async with GitHubCrawler() as crawler:
async for doc in crawler.crawl_repository(source):
documents.append(doc)
return documents
# CLI for testing
async def main():
"""Test crawler with a sample source."""
from template_sources import TEMPLATE_SOURCES
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
async with GitHubCrawler() as crawler:
count = 0
async for doc in crawler.crawl_repository(source):
count += 1
print(f"\n{'='*60}")
print(f"Title: {doc.title}")
print(f"Path: {doc.file_path}")
print(f"Type: {doc.file_type}")
print(f"Language: {doc.language}")
print(f"URL: {doc.source_url}")
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
print(f"Text preview: {doc.text[:200]}...")
print(f"\n\nTotal documents: {count}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,303 @@
"""
GitHub Crawler - Document Parsers
Markdown, HTML, and JSON parsers for extracting structured content
from legal template documents.
Extracted from github_crawler.py to keep files under 500 LOC.
"""
import hashlib
import json
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class ExtractedDocument:
"""A document extracted from a repository."""
text: str
title: str
file_path: str
file_type: str # "markdown", "html", "json", "text"
source_url: str
source_commit: Optional[str] = None
source_hash: str = "" # SHA256 of original content
sections: List[Dict[str, Any]] = field(default_factory=list)
placeholders: List[str] = field(default_factory=list)
language: str = "en"
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
if not self.source_hash and self.text:
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
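# Hedged sketch (not in the original file): __post_init__ derives source_hash
# from the text when none is supplied, so identical content yields the same hash.
def _example_hash_autofill() -> None:
    doc = ExtractedDocument(
        text="Beispieltext", title="Demo", file_path="demo.md",
        file_type="markdown", source_url="",
    )
    assert len(doc.source_hash) == 64  # hex digest of SHA-256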
class MarkdownParser:
"""Parse Markdown files into structured content."""
# Common placeholder patterns
PLACEHOLDER_PATTERNS = [
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
r'\{([a-z_]+)\}', # {company_name}
r'\{\{([a-z_]+)\}\}', # {{company_name}}
r'__([A-Z_]+)__', # __COMPANY_NAME__
r'<([A-Z_]+)>', # <COMPANY_NAME>
]
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse markdown content into an ExtractedDocument."""
title = cls._extract_title(content, filename)
sections = cls._extract_sections(content)
placeholders = cls._find_placeholders(content)
language = cls._detect_language(content)
clean_text = cls._clean_for_indexing(content)
return ExtractedDocument(
text=clean_text,
title=title,
file_path=filename,
file_type="markdown",
source_url="",
sections=sections,
placeholders=placeholders,
language=language,
)
@classmethod
def _extract_title(cls, content: str, filename: str) -> str:
"""Extract title from markdown heading or filename."""
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if h1_match:
return h1_match.group(1).strip()
frontmatter_match = re.search(
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
content, re.DOTALL
)
if frontmatter_match:
return frontmatter_match.group(1).strip()
if filename:
name = Path(filename).stem
return name.replace('-', ' ').replace('_', ' ').title()
return "Untitled"
@classmethod
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
"""Extract sections from markdown content."""
sections = []
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = current_section["content"].strip()
sections.append(current_section.copy())
level = len(match.group(1))
heading = match.group(2).strip()
current_section = {
"heading": heading,
"level": level,
"content": "",
"start": match.end(),
}
if current_section["heading"] or current_section["content"].strip():
current_section["content"] = content[current_section["start"]:].strip()
sections.append(current_section)
return sections
@classmethod
def _find_placeholders(cls, content: str) -> List[str]:
"""Find placeholder patterns in content."""
placeholders = set()
for pattern in cls.PLACEHOLDER_PATTERNS:
for match in re.finditer(pattern, content):
placeholder = match.group(0)
placeholders.add(placeholder)
return sorted(list(placeholders))
@classmethod
def _detect_language(cls, content: str) -> str:
"""Detect language from content."""
german_indicators = [
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
]
lower_content = content.lower()
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
if german_count >= 3:
return "de"
return "en"
@classmethod
def _clean_for_indexing(cls, content: str) -> str:
"""Clean markdown content for text indexing."""
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
content = re.sub(r'<[^>]+>', '', content)
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
content = re.sub(r'\*(.+?)\*', r'\1', content)
content = re.sub(r'`(.+?)`', r'\1', content)
content = re.sub(r'~~(.+?)~~', r'\1', content)
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
content = re.sub(r'\n{3,}', '\n\n', content)
content = re.sub(r' +', ' ', content)
return content.strip()
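# Hedged sketch (sample input invented): MarkdownParser.parse on a tiny template,
# showing title extraction, placeholder detection, and the language heuristic.
def _example_markdown_parse() -> None:
    sample = (
        "# Datenschutz\n\n"
        "Der Verantwortliche verarbeitet personenbezogene Daten und speichert sie.\n\n"
        "## Kontakt\n\n[COMPANY_NAME], {contact_email}\n"
    )
    doc = MarkdownParser.parse(sample, "privacy-template.md")
    assert doc.title == "Datenschutz"
    assert "[COMPANY_NAME]" in doc.placeholders and "{contact_email}" in doc.placeholders
    assert doc.language == "de"  # >= 3 German indicator hits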
class HTMLParser:
"""Parse HTML files into structured content."""
@classmethod
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
"""Parse HTML content into an ExtractedDocument."""
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
title = title_match.group(1) if title_match else Path(filename).stem
text = cls._html_to_text(content)
placeholders = MarkdownParser._find_placeholders(text)
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
return ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="html",
source_url="",
placeholders=placeholders,
language=language,
)
@classmethod
def _html_to_text(cls, html: str) -> str:
"""Convert HTML to clean text."""
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
# Convert block-level tags to line breaks before stripping the rest.
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
html = re.sub(r'<[^>]+>', '', html)
# Decode entities only after tags are gone, and decode &amp; last, so escaped
# text is neither stripped as markup nor double-unescaped.
html = html.replace('&nbsp;', ' ')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&apos;', "'")
html = html.replace('&amp;', '&')
html = re.sub(r'[ \t]+', ' ', html)
html = re.sub(r'\n[ \t]+', '\n', html)
html = re.sub(r'[ \t]+\n', '\n', html)
html = re.sub(r'\n{3,}', '\n\n', html)
return html.strip()
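# Hedged sketch (sample input invented): HTML parsing with a lang attribute and
# entity decoding; the title falls back to the filename stem when <title> is absent.
def _example_html_parse() -> None:
    html = '<html lang="de"><body><h1>Impressum</h1><p>Fragen &amp; Antworten</p></body></html>'
    doc = HTMLParser.parse(html, "impressum.html")
    assert doc.language == "de"            # taken from the lang attribute
    assert "Fragen & Antworten" in doc.text
    assert doc.title == "impressum"        # Path("impressum.html").stem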
class JSONParser:
"""Parse JSON files containing legal template data."""
@classmethod
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
"""Parse JSON content into ExtractedDocuments."""
try:
data = json.loads(content)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON from {filename}: {e}")
return []
documents = []
if isinstance(data, dict):
documents.extend(cls._parse_dict(data, filename))
elif isinstance(data, list):
for i, item in enumerate(data):
if isinstance(item, dict):
docs = cls._parse_dict(item, f"{filename}[{i}]")
documents.extend(docs)
return documents
@classmethod
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
"""Parse a dictionary into documents."""
documents = []
text_keys = ['text', 'content', 'body', 'description', 'value']
title_keys = ['title', 'name', 'heading', 'label', 'key']
text = ""
for key in text_keys:
if key in data and isinstance(data[key], str):
text = data[key]
break
if not text:
for key, value in data.items():
if isinstance(value, dict):
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
documents.extend(nested_docs)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
documents.extend(nested_docs)
elif isinstance(item, str) and len(item) > 50:
documents.append(ExtractedDocument(
text=item,
title=f"{key} {i+1}",
file_path=filename,
file_type="json",
source_url="",
language=MarkdownParser._detect_language(item),
))
return documents
title = ""
for key in title_keys:
if key in data and isinstance(data[key], str):
title = data[key]
break
if not title:
title = Path(filename).stem
metadata = {}
for key, value in data.items():
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
metadata[key] = value
placeholders = MarkdownParser._find_placeholders(text)
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
documents.append(ExtractedDocument(
text=text,
title=title,
file_path=filename,
file_type="json",
source_url="",
placeholders=placeholders,
language=language,
metadata=metadata,
))
return documents
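# Hedged sketch (sample data invented): JSONParser on a list of template records;
# non-text scalar fields land in metadata.
def _example_json_parse() -> None:
    raw = '[{"title": "AVV", "text": "Der Auftragsverarbeiter verarbeitet personenbezogene Daten und Logs.", "version": 2}]'
    docs = JSONParser.parse(raw, "templates.json")
    assert len(docs) == 1
    assert docs[0].title == "AVV"
    assert docs[0].metadata.get("version") == 2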

View File

@@ -1,790 +1,30 @@
"""
Legal Corpus API - Endpoints for RAG page in admin-v2
Legal Corpus API — Barrel Re-export
Provides endpoints for:
- GET /api/v1/admin/legal-corpus/status - Collection status with chunk counts
- GET /api/v1/admin/legal-corpus/search - Semantic search
- POST /api/v1/admin/legal-corpus/ingest - Trigger ingestion
- GET /api/v1/admin/legal-corpus/ingestion-status - Ingestion status
- POST /api/v1/admin/legal-corpus/upload - Upload document
- POST /api/v1/admin/legal-corpus/add-link - Add link for ingestion
- POST /api/v1/admin/pipeline/start - Start compliance pipeline
Split into:
- legal_corpus_routes.py — Corpus endpoints (status, search, ingest, upload)
- legal_corpus_pipeline.py — Pipeline endpoints (checkpoints, start, status)
All public names are re-exported here for backward compatibility.
"""
import os
import asyncio
import httpx
import uuid
import shutil
from datetime import datetime
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
# Configuration
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
COLLECTION_NAME = "bp_legal_corpus"
# All regulations for status endpoint
REGULATIONS = [
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
]
# Ingestion state (in-memory for now)
ingestion_state = {
"running": False,
"completed": False,
"current_regulation": None,
"processed": 0,
"total": len(REGULATIONS),
"error": None,
}
class SearchRequest(BaseModel):
query: str
regulations: Optional[List[str]] = None
top_k: int = 5
class IngestRequest(BaseModel):
force: bool = False
regulations: Optional[List[str]] = None
class AddLinkRequest(BaseModel):
url: str
title: str
code: str # Regulation code (e.g. "CUSTOM-1")
document_type: str = "custom" # custom, eu_regulation, eu_directive, de_law, bsi_standard
class StartPipelineRequest(BaseModel):
force_reindex: bool = False
skip_ingestion: bool = False
# Store for custom documents (in-memory for now, should be persisted)
custom_documents: List[Dict[str, Any]] = []
async def get_qdrant_client():
"""Get async HTTP client for Qdrant."""
return httpx.AsyncClient(timeout=30.0)
@router.get("/status")
async def get_legal_corpus_status():
"""
Get status of the legal corpus collection including chunk counts per regulation.
"""
async with httpx.AsyncClient(timeout=30.0) as client:
try:
# Get collection info
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
if collection_res.status_code != 200:
return {
"collection": COLLECTION_NAME,
"totalPoints": 0,
"vectorSize": 1024,
"status": "not_found",
"regulations": {},
}
collection_data = collection_res.json()
result = collection_data.get("result", {})
# Get chunk counts per regulation
regulation_counts = {}
for reg in REGULATIONS:
count_res = await client.post(
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
json={
"filter": {
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
}
},
)
if count_res.status_code == 200:
count_data = count_res.json()
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
else:
regulation_counts[reg["code"]] = 0
return {
"collection": COLLECTION_NAME,
"totalPoints": result.get("points_count", 0),
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
"status": result.get("status", "unknown"),
"regulations": regulation_counts,
}
except httpx.RequestError as e:
logger.error(f"Failed to get Qdrant status: {e}")
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
@router.get("/search")
async def search_legal_corpus(
query: str = Query(..., description="Search query"),
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
):
"""
Semantic search in legal corpus using BGE-M3 embeddings.
"""
async with httpx.AsyncClient(timeout=60.0) as client:
try:
# Generate embedding for query
embed_res = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed",
json={"texts": [query]},
)
if embed_res.status_code != 200:
raise HTTPException(status_code=500, detail="Embedding service error")
embed_data = embed_res.json()
query_vector = embed_data["embeddings"][0]
# Build Qdrant search request
search_request = {
"vector": query_vector,
"limit": top_k,
"with_payload": True,
}
# Add regulation filter if specified
if regulations:
reg_codes = [r.strip() for r in regulations.split(",")]
search_request["filter"] = {
"should": [
{"key": "regulation_code", "match": {"value": code}}
for code in reg_codes
]
}
# Search Qdrant
search_res = await client.post(
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
json=search_request,
)
if search_res.status_code != 200:
raise HTTPException(status_code=500, detail="Search failed")
search_data = search_res.json()
results = []
for point in search_data.get("result", []):
payload = point.get("payload", {})
results.append({
"text": payload.get("text", ""),
"regulation_code": payload.get("regulation_code", ""),
"regulation_name": payload.get("regulation_name", ""),
"article": payload.get("article"),
"paragraph": payload.get("paragraph"),
"source_url": payload.get("source_url", ""),
"score": point.get("score", 0),
})
return {"results": results, "query": query, "count": len(results)}
except httpx.RequestError as e:
logger.error(f"Search failed: {e}")
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
@router.post("/ingest")
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
"""
Trigger legal corpus ingestion in background.
"""
global ingestion_state
if ingestion_state["running"]:
raise HTTPException(status_code=409, detail="Ingestion already running")
# Reset state
ingestion_state = {
"running": True,
"completed": False,
"current_regulation": None,
"processed": 0,
"total": len(REGULATIONS),
"error": None,
}
# Start ingestion in background
background_tasks.add_task(run_ingestion, request.force, request.regulations)
return {
"status": "started",
"job_id": "manual-trigger",
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
}
async def run_ingestion(force: bool, regulations: Optional[List[str]]):
"""Background task for running ingestion."""
global ingestion_state
try:
# Import ingestion module
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
# Filter regulations if specified
regs_to_process = regulations or [r["code"] for r in REGULATIONS]
for i, reg_code in enumerate(regs_to_process):
ingestion_state["current_regulation"] = reg_code
ingestion_state["processed"] = i
try:
await ingestion.ingest_single(reg_code, force=force)
except Exception as e:
logger.error(f"Failed to ingest {reg_code}: {e}")
ingestion_state["completed"] = True
ingestion_state["processed"] = len(regs_to_process)
except Exception as e:
logger.error(f"Ingestion failed: {e}")
ingestion_state["error"] = str(e)
finally:
ingestion_state["running"] = False
@router.get("/ingestion-status")
async def get_ingestion_status():
"""
Get current ingestion status.
"""
return ingestion_state
@router.get("/regulations")
async def get_regulations():
"""
Get list of all supported regulations.
"""
return {"regulations": REGULATIONS}
@router.get("/custom-documents")
async def get_custom_documents():
"""
Get list of custom documents added by user.
"""
return {"documents": custom_documents}
@router.post("/upload")
async def upload_document(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
title: str = Form(...),
code: str = Form(...),
document_type: str = Form("custom"),
):
"""
Upload a document (PDF) for ingestion into the legal corpus.
The document will be saved and queued for processing.
"""
global custom_documents
# Validate file type
if not file.filename.endswith(('.pdf', '.PDF')):
raise HTTPException(status_code=400, detail="Only PDF files are supported")
# Create upload directory if needed
upload_dir = "/tmp/legal_corpus_uploads"
os.makedirs(upload_dir, exist_ok=True)
# Save file with unique name
doc_id = str(uuid.uuid4())[:8]
safe_filename = f"{doc_id}_{file.filename}"
file_path = os.path.join(upload_dir, safe_filename)
try:
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
# Create document record
doc_record = {
"id": doc_id,
"code": code,
"title": title,
"filename": file.filename,
"file_path": file_path,
"document_type": document_type,
"uploaded_at": datetime.now().isoformat(),
"status": "uploaded",
"chunk_count": 0,
}
custom_documents.append(doc_record)
# Queue for background ingestion
background_tasks.add_task(ingest_uploaded_document, doc_record)
return {
"status": "uploaded",
"document_id": doc_id,
"message": f"Document '{title}' uploaded and queued for ingestion",
"document": doc_record,
}
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
"""Background task to ingest an uploaded document."""
global custom_documents
try:
doc_record["status"] = "processing"
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
# Read PDF and extract text
import fitz # PyMuPDF
doc = fitz.open(doc_record["file_path"])
full_text = ""
for page in doc:
full_text += page.get_text()
doc.close()
if not full_text.strip():
doc_record["status"] = "error"
doc_record["error"] = "No text could be extracted from PDF"
return
# Chunk the text
chunks = ingestion.chunk_text(full_text, doc_record["code"])
# Add metadata
for chunk in chunks:
chunk["regulation_code"] = doc_record["code"]
chunk["regulation_name"] = doc_record["title"]
chunk["document_type"] = doc_record["document_type"]
chunk["source_url"] = f"upload://{doc_record['filename']}"
# Generate embeddings and upsert to Qdrant
if chunks:
await ingestion.embed_and_upsert(chunks)
doc_record["chunk_count"] = len(chunks)
doc_record["status"] = "indexed"
logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
else:
doc_record["status"] = "error"
doc_record["error"] = "No chunks generated from document"
except Exception as e:
logger.error(f"Failed to ingest uploaded document: {e}")
doc_record["status"] = "error"
doc_record["error"] = str(e)
@router.post("/add-link")
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
"""
Add a URL/link for ingestion into the legal corpus.
The content will be fetched, extracted, and indexed.
"""
global custom_documents
# Create document record
doc_id = str(uuid.uuid4())[:8]
doc_record = {
"id": doc_id,
"code": request.code,
"title": request.title,
"url": request.url,
"document_type": request.document_type,
"uploaded_at": datetime.now().isoformat(),
"status": "queued",
"chunk_count": 0,
}
custom_documents.append(doc_record)
# Queue for background ingestion
background_tasks.add_task(ingest_link_document, doc_record)
return {
"status": "queued",
"document_id": doc_id,
"message": f"Link '{request.title}' queued for ingestion",
"document": doc_record,
}
async def ingest_link_document(doc_record: Dict[str, Any]):
"""Background task to ingest content from a URL."""
global custom_documents
try:
doc_record["status"] = "fetching"
async with httpx.AsyncClient(timeout=60.0) as client:
# Fetch the URL
response = await client.get(doc_record["url"], follow_redirects=True)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "application/pdf" in content_type:
# Save PDF and process
import tempfile
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(response.content)
pdf_path = f.name
import fitz
pdf_doc = fitz.open(pdf_path)
full_text = ""
for page in pdf_doc:
full_text += page.get_text()
pdf_doc.close()
os.unlink(pdf_path)
elif "text/html" in content_type:
# Extract text from HTML
from bs4 import BeautifulSoup
soup = BeautifulSoup(response.text, "html.parser")
# Remove script and style elements
for script in soup(["script", "style", "nav", "footer", "header"]):
script.decompose()
full_text = soup.get_text(separator="\n", strip=True)
else:
# Try to use as plain text
full_text = response.text
if not full_text.strip():
doc_record["status"] = "error"
doc_record["error"] = "No text could be extracted from URL"
return
doc_record["status"] = "processing"
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
# Chunk the text
chunks = ingestion.chunk_text(full_text, doc_record["code"])
# Add metadata
for chunk in chunks:
chunk["regulation_code"] = doc_record["code"]
chunk["regulation_name"] = doc_record["title"]
chunk["document_type"] = doc_record["document_type"]
chunk["source_url"] = doc_record["url"]
# Generate embeddings and upsert to Qdrant
if chunks:
await ingestion.embed_and_upsert(chunks)
doc_record["chunk_count"] = len(chunks)
doc_record["status"] = "indexed"
logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
else:
doc_record["status"] = "error"
doc_record["error"] = "No chunks generated from content"
except httpx.HTTPError as e:
logger.error(f"Failed to fetch URL: {e}")
doc_record["status"] = "error"
doc_record["error"] = f"Failed to fetch URL: {str(e)}"
except Exception as e:
logger.error(f"Failed to ingest URL content: {e}")
doc_record["status"] = "error"
doc_record["error"] = str(e)
@router.delete("/custom-documents/{doc_id}")
async def delete_custom_document(doc_id: str):
"""
Delete a custom document from the list.
Note: This does not remove the chunks from Qdrant yet.
"""
global custom_documents
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
# TODO: Also remove chunks from Qdrant by filtering on code
return {"status": "deleted", "document_id": doc_id}
# ========== Pipeline Checkpoints ==========
# Create a separate router for pipeline-related endpoints
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
@pipeline_router.get("/checkpoints")
async def get_pipeline_checkpoints():
"""
Get current pipeline checkpoint state.
Returns the current state of the compliance pipeline including:
- Pipeline ID and overall status
- Start and completion times
- All checkpoints with their validations and metrics
- Summary data
"""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state is None:
return {
"status": "no_data",
"message": "No pipeline run data available yet.",
"pipeline_id": None,
"checkpoints": [],
"summary": {}
}
# Enrich with validation summary
validation_summary = {
"passed": 0,
"warning": 0,
"failed": 0,
"total": 0
}
for checkpoint in state.get("checkpoints", []):
for validation in checkpoint.get("validations", []):
validation_summary["total"] += 1
status = validation.get("status", "not_run")
if status in validation_summary:
validation_summary[status] += 1
state["validation_summary"] = validation_summary
return state
@pipeline_router.get("/checkpoints/history")
async def get_pipeline_history():
"""
Get list of previous pipeline runs (if stored).
For now, returns only current run.
"""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state is None:
return {"runs": []}
return {
"runs": [{
"pipeline_id": state.get("pipeline_id"),
"status": state.get("status"),
"started_at": state.get("started_at"),
"completed_at": state.get("completed_at"),
}]
}
# Pipeline state for start/stop
pipeline_process_state = {
"running": False,
"pid": None,
"started_at": None,
}
@pipeline_router.post("/start")
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
"""
Start the compliance pipeline in the background.
This runs the full_compliance_pipeline.py script which:
1. Ingests all legal documents (unless skip_ingestion=True)
2. Extracts requirements and controls
3. Generates compliance measures
4. Creates checkpoint data for monitoring
"""
global pipeline_process_state
# Check if already running
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state and state.get("status") == "running":
raise HTTPException(
status_code=409,
detail="Pipeline is already running"
)
if pipeline_process_state["running"]:
raise HTTPException(
status_code=409,
detail="Pipeline start already in progress"
)
pipeline_process_state["running"] = True
pipeline_process_state["started_at"] = datetime.now().isoformat()
# Start pipeline in background
background_tasks.add_task(
run_pipeline_background,
request.force_reindex,
request.skip_ingestion
)
return {
"status": "starting",
"message": "Compliance pipeline is starting in background",
"started_at": pipeline_process_state["started_at"],
}
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
"""Background task to run the compliance pipeline."""
global pipeline_process_state
try:
import subprocess
import sys
# Build command
cmd = [sys.executable, "full_compliance_pipeline.py"]
if force_reindex:
cmd.append("--force-reindex")
if skip_ingestion:
cmd.append("--skip-ingestion")
# Run as subprocess
logger.info(f"Starting pipeline: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=os.path.dirname(os.path.abspath(__file__)),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
pipeline_process_state["pid"] = process.pid
# Wait for completion (non-blocking via asyncio)
import asyncio
while process.poll() is None:
await asyncio.sleep(5)
return_code = process.returncode
if return_code != 0:
output = process.stdout.read() if process.stdout else ""
logger.error(f"Pipeline failed with code {return_code}: {output}")
else:
logger.info("Pipeline completed successfully")
except Exception as e:
logger.error(f"Failed to run pipeline: {e}")
finally:
pipeline_process_state["running"] = False
pipeline_process_state["pid"] = None
@pipeline_router.get("/status")
async def get_pipeline_status():
"""
Get current pipeline running status.
"""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
checkpoint_status = state.get("status") if state else "no_data"
return {
"process_running": pipeline_process_state["running"],
"process_pid": pipeline_process_state["pid"],
"process_started_at": pipeline_process_state["started_at"],
"checkpoint_status": checkpoint_status,
"current_phase": state.get("current_phase") if state else None,
}
# ========== Traceability / Quality Endpoints ==========
@router.get("/traceability")
async def get_traceability(
chunk_id: str = Query(..., description="Chunk ID or identifier"),
regulation: str = Query(..., description="Regulation code"),
):
"""
Get traceability information for a specific chunk.
Returns:
- The chunk details
- Requirements extracted from this chunk
- Controls derived from those requirements
Note: This is a placeholder that will be enhanced once the
requirements extraction pipeline is fully implemented.
"""
async with httpx.AsyncClient(timeout=30.0) as client:
try:
# Try to find the chunk by scrolling through points with the regulation filter
# In a production system, we would have proper IDs and indexing
# For now, return placeholder structure
# The actual implementation will query:
# 1. The chunk from Qdrant
# 2. Requirements from a requirements collection/table
# 3. Controls from a controls collection/table
return {
"chunk_id": chunk_id,
"regulation": regulation,
"requirements": [],
"controls": [],
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
}
except Exception as e:
logger.error(f"Failed to get traceability: {e}")
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")
# Corpus routes and state
from legal_corpus_routes import ( # noqa: F401
router,
REGULATIONS,
COLLECTION_NAME,
QDRANT_URL,
EMBEDDING_SERVICE_URL,
ingestion_state,
custom_documents,
SearchRequest,
IngestRequest,
AddLinkRequest,
)
# Pipeline routes and state
from legal_corpus_pipeline import ( # noqa: F401
pipeline_router,
pipeline_process_state,
StartPipelineRequest,
)
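# Hedged sketch (not part of the file): existing call sites keep working because
# the barrel re-exports the split modules' public names unchanged, e.g.:
#
#   from legal_corpus_api import router, pipeline_router, REGULATIONS
#   app.include_router(router)
#   app.include_router(pipeline_router)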

View File

@@ -0,0 +1,166 @@
"""
Legal Corpus API - Background Ingestion Tasks
Background tasks for ingesting uploaded documents and URL links
into the legal corpus vector database.
Extracted from legal_corpus_routes.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Dict, Any, Optional, List
import httpx
logger = logging.getLogger(__name__)
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
"""Background task to ingest an uploaded document."""
try:
doc_record["status"] = "processing"
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
import fitz
doc = fitz.open(doc_record["file_path"])
full_text = ""
for page in doc:
full_text += page.get_text()
doc.close()
if not full_text.strip():
doc_record["status"] = "error"
doc_record["error"] = "No text could be extracted from PDF"
return
chunks = ingestion.chunk_text(full_text, doc_record["code"])
for chunk in chunks:
chunk["regulation_code"] = doc_record["code"]
chunk["regulation_name"] = doc_record["title"]
chunk["document_type"] = doc_record["document_type"]
chunk["source_url"] = f"upload://{doc_record['filename']}"
if chunks:
await ingestion.embed_and_upsert(chunks)
doc_record["chunk_count"] = len(chunks)
doc_record["status"] = "indexed"
logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
else:
doc_record["status"] = "error"
doc_record["error"] = "No chunks generated from document"
except Exception as e:
logger.error(f"Failed to ingest uploaded document: {e}")
doc_record["status"] = "error"
doc_record["error"] = str(e)
async def ingest_link_document(doc_record: Dict[str, Any]):
"""Background task to ingest content from a URL."""
try:
doc_record["status"] = "fetching"
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.get(doc_record["url"], follow_redirects=True)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "application/pdf" in content_type:
import tempfile
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(response.content)
pdf_path = f.name
import fitz
pdf_doc = fitz.open(pdf_path)
full_text = ""
for page in pdf_doc:
full_text += page.get_text()
pdf_doc.close()
os.unlink(pdf_path)
elif "text/html" in content_type:
from bs4 import BeautifulSoup
soup = BeautifulSoup(response.text, "html.parser")
for script in soup(["script", "style", "nav", "footer", "header"]):
script.decompose()
full_text = soup.get_text(separator="\n", strip=True)
else:
full_text = response.text
if not full_text.strip():
doc_record["status"] = "error"
doc_record["error"] = "No text could be extracted from URL"
return
doc_record["status"] = "processing"
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
chunks = ingestion.chunk_text(full_text, doc_record["code"])
for chunk in chunks:
chunk["regulation_code"] = doc_record["code"]
chunk["regulation_name"] = doc_record["title"]
chunk["document_type"] = doc_record["document_type"]
chunk["source_url"] = doc_record["url"]
if chunks:
await ingestion.embed_and_upsert(chunks)
doc_record["chunk_count"] = len(chunks)
doc_record["status"] = "indexed"
logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
else:
doc_record["status"] = "error"
doc_record["error"] = "No chunks generated from content"
except httpx.HTTPError as e:
logger.error(f"Failed to fetch URL: {e}")
doc_record["status"] = "error"
doc_record["error"] = f"Failed to fetch URL: {str(e)}"
except Exception as e:
logger.error(f"Failed to ingest URL content: {e}")
doc_record["status"] = "error"
doc_record["error"] = str(e)
async def run_ingestion(
force: bool,
regulations: Optional[List[str]],
ingestion_state: Dict[str, Any],
all_regulations: List[Dict[str, str]],
):
"""Background task for running full corpus ingestion."""
try:
from legal_corpus_ingestion import LegalCorpusIngestion
ingestion = LegalCorpusIngestion()
regs_to_process = regulations or [r["code"] for r in all_regulations]
for i, reg_code in enumerate(regs_to_process):
ingestion_state["current_regulation"] = reg_code
ingestion_state["processed"] = i
try:
await ingestion.ingest_single(reg_code, force=force)
except Exception as e:
logger.error(f"Failed to ingest {reg_code}: {e}")
ingestion_state["completed"] = True
ingestion_state["processed"] = len(regs_to_process)
except Exception as e:
logger.error(f"Ingestion failed: {e}")
ingestion_state["error"] = str(e)
finally:
ingestion_state["running"] = False

View File

@@ -0,0 +1,206 @@
"""
Legal Corpus API - Pipeline Routes
Pipeline checkpoints, history, start/stop, and status endpoints.
Extracted from legal_corpus_api.py to keep files under 500 LOC.
"""
import os
import asyncio
from datetime import datetime
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
class StartPipelineRequest(BaseModel):
force_reindex: bool = False
skip_ingestion: bool = False
# Create a separate router for pipeline-related endpoints
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
@pipeline_router.get("/checkpoints")
async def get_pipeline_checkpoints():
"""
Get current pipeline checkpoint state.
Returns the current state of the compliance pipeline including:
- Pipeline ID and overall status
- Start and completion times
- All checkpoints with their validations and metrics
- Summary data
"""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state is None:
return {
"status": "no_data",
"message": "No pipeline run data available yet.",
"pipeline_id": None,
"checkpoints": [],
"summary": {}
}
# Enrich with validation summary
validation_summary = {
"passed": 0,
"warning": 0,
"failed": 0,
"total": 0
}
for checkpoint in state.get("checkpoints", []):
for validation in checkpoint.get("validations", []):
validation_summary["total"] += 1
status = validation.get("status", "not_run")
if status in validation_summary:
validation_summary[status] += 1
state["validation_summary"] = validation_summary
return state
@pipeline_router.get("/checkpoints/history")
async def get_pipeline_history():
"""
Get list of previous pipeline runs (if stored).
For now, returns only current run.
"""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state is None:
return {"runs": []}
return {
"runs": [{
"pipeline_id": state.get("pipeline_id"),
"status": state.get("status"),
"started_at": state.get("started_at"),
"completed_at": state.get("completed_at"),
}]
}
# Pipeline state for start/stop
pipeline_process_state = {
"running": False,
"pid": None,
"started_at": None,
}
@pipeline_router.post("/start")
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
"""
Start the compliance pipeline in the background.
This runs the full_compliance_pipeline.py script which:
1. Ingests all legal documents (unless skip_ingestion=True)
2. Extracts requirements and controls
3. Generates compliance measures
4. Creates checkpoint data for monitoring
"""
global pipeline_process_state
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
if state and state.get("status") == "running":
raise HTTPException(
status_code=409,
detail="Pipeline is already running"
)
if pipeline_process_state["running"]:
raise HTTPException(
status_code=409,
detail="Pipeline start already in progress"
)
pipeline_process_state["running"] = True
pipeline_process_state["started_at"] = datetime.now().isoformat()
background_tasks.add_task(
run_pipeline_background,
request.force_reindex,
request.skip_ingestion
)
return {
"status": "starting",
"message": "Compliance pipeline is starting in background",
"started_at": pipeline_process_state["started_at"],
}
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
"""Background task to run the compliance pipeline."""
global pipeline_process_state
try:
import subprocess
import sys
cmd = [sys.executable, "full_compliance_pipeline.py"]
if force_reindex:
cmd.append("--force-reindex")
if skip_ingestion:
cmd.append("--skip-ingestion")
logger.info(f"Starting pipeline: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=os.path.dirname(os.path.abspath(__file__)),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
pipeline_process_state["pid"] = process.pid
while process.poll() is None:
await asyncio.sleep(5)
return_code = process.returncode
if return_code != 0:
output = process.stdout.read() if process.stdout else ""
logger.error(f"Pipeline failed with code {return_code}: {output}")
else:
logger.info("Pipeline completed successfully")
except Exception as e:
logger.error(f"Failed to run pipeline: {e}")
finally:
pipeline_process_state["running"] = False
pipeline_process_state["pid"] = None
@pipeline_router.get("/status")
async def get_pipeline_status():
"""Get current pipeline running status."""
from pipeline_checkpoints import CheckpointManager
state = CheckpointManager.load_state()
checkpoint_status = state.get("status") if state else "no_data"
return {
"process_running": pipeline_process_state["running"],
"process_pid": pipeline_process_state["pid"],
"process_started_at": pipeline_process_state["started_at"],
"checkpoint_status": checkpoint_status,
"current_phase": state.get("current_phase") if state else None,
}
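# Hedged usage sketch (base URL is an assumption): starting the pipeline and
# polling its status from an admin client.
async def _example_pipeline_client(base_url: str = "http://localhost:8000") -> None:
    import httpx
    async with httpx.AsyncClient(timeout=30.0) as client:
        res = await client.post(
            f"{base_url}/api/v1/admin/pipeline/start",
            json={"force_reindex": False, "skip_ingestion": True},
        )
        print(res.status_code, res.json())  # 200 {"status": "starting", ...} or 409
        status = await client.get(f"{base_url}/api/v1/admin/pipeline/status")
        print(status.json())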

View File

@@ -0,0 +1,368 @@
"""
Legal Corpus API - Corpus Routes
Endpoints for the RAG page in admin-v2:
- GET /status - Collection status with chunk counts
- GET /search - Semantic search
- POST /ingest - Trigger ingestion
- GET /ingestion-status - Ingestion status
- GET /regulations - List regulations
- GET /custom-documents - List custom docs
- POST /upload - Upload document
- POST /add-link - Add link for ingestion
- DELETE /custom-documents/{id} - Delete custom doc
- GET /traceability - Traceability info
Extracted from legal_corpus_api.py to keep files under 500 LOC.
"""
import os
import httpx
import uuid
import shutil
from datetime import datetime
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
from pydantic import BaseModel
import logging
from legal_corpus_ingest_tasks import (
ingest_uploaded_document,
ingest_link_document,
run_ingestion,
)
logger = logging.getLogger(__name__)
# Configuration
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
COLLECTION_NAME = "bp_legal_corpus"
# All regulations for status endpoint
REGULATIONS = [
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
]
# Ingestion state (in-memory for now)
ingestion_state = {
"running": False,
"completed": False,
"current_regulation": None,
"processed": 0,
"total": len(REGULATIONS),
"error": None,
}
class SearchRequest(BaseModel):
query: str
regulations: Optional[List[str]] = None
top_k: int = 5
class IngestRequest(BaseModel):
force: bool = False
regulations: Optional[List[str]] = None
class AddLinkRequest(BaseModel):
url: str
title: str
code: str
document_type: str = "custom"
# Store for custom documents (in-memory for now)
custom_documents: List[Dict[str, Any]] = []
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
@router.get("/status")
async def get_legal_corpus_status():
"""Get status of the legal corpus collection including chunk counts per regulation."""
async with httpx.AsyncClient(timeout=30.0) as client:
try:
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
if collection_res.status_code != 200:
return {
"collection": COLLECTION_NAME,
"totalPoints": 0,
"vectorSize": 1024,
"status": "not_found",
"regulations": {},
}
collection_data = collection_res.json()
result = collection_data.get("result", {})
regulation_counts = {}
for reg in REGULATIONS:
count_res = await client.post(
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
json={
"filter": {
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
}
},
)
if count_res.status_code == 200:
count_data = count_res.json()
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
else:
regulation_counts[reg["code"]] = 0
return {
"collection": COLLECTION_NAME,
"totalPoints": result.get("points_count", 0),
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
"status": result.get("status", "unknown"),
"regulations": regulation_counts,
}
except httpx.RequestError as e:
logger.error(f"Failed to get Qdrant status: {e}")
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
@router.get("/search")
async def search_legal_corpus(
query: str = Query(..., description="Search query"),
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
):
"""Semantic search in legal corpus using BGE-M3 embeddings."""
async with httpx.AsyncClient(timeout=60.0) as client:
try:
embed_res = await client.post(
f"{EMBEDDING_SERVICE_URL}/embed",
json={"texts": [query]},
)
if embed_res.status_code != 200:
raise HTTPException(status_code=500, detail="Embedding service error")
embed_data = embed_res.json()
query_vector = embed_data["embeddings"][0]
search_request = {
"vector": query_vector,
"limit": top_k,
"with_payload": True,
}
if regulations:
reg_codes = [r.strip() for r in regulations.split(",")]
search_request["filter"] = {
"should": [
{"key": "regulation_code", "match": {"value": code}}
for code in reg_codes
]
}
search_res = await client.post(
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
json=search_request,
)
if search_res.status_code != 200:
raise HTTPException(status_code=500, detail="Search failed")
search_data = search_res.json()
results = []
for point in search_data.get("result", []):
payload = point.get("payload", {})
results.append({
"text": payload.get("text", ""),
"regulation_code": payload.get("regulation_code", ""),
"regulation_name": payload.get("regulation_name", ""),
"article": payload.get("article"),
"paragraph": payload.get("paragraph"),
"source_url": payload.get("source_url", ""),
"score": point.get("score", 0),
})
return {"results": results, "query": query, "count": len(results)}
except httpx.RequestError as e:
logger.error(f"Search failed: {e}")
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
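# Hedged request sketch (host is an assumption): the regulations parameter maps
# to a Qdrant "should" filter over regulation_code, i.e. an OR across codes:
#
#   GET /api/v1/admin/legal-corpus/search?query=Auftragsverarbeitung&regulations=GDPR,SCC&top_k=5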
@router.post("/ingest")
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
"""Trigger legal corpus ingestion in background."""
global ingestion_state
if ingestion_state["running"]:
raise HTTPException(status_code=409, detail="Ingestion already running")
ingestion_state = {
"running": True,
"completed": False,
"current_regulation": None,
"processed": 0,
"total": len(REGULATIONS),
"error": None,
}
background_tasks.add_task(run_ingestion, request.force, request.regulations, ingestion_state, REGULATIONS)
return {
"status": "started",
"job_id": "manual-trigger",
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
}
@router.get("/ingestion-status")
async def get_ingestion_status():
"""Get current ingestion status."""
return ingestion_state
@router.get("/regulations")
async def get_regulations():
"""Get list of all supported regulations."""
return {"regulations": REGULATIONS}
@router.get("/custom-documents")
async def get_custom_documents():
"""Get list of custom documents added by user."""
return {"documents": custom_documents}
@router.post("/upload")
async def upload_document(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
title: str = Form(...),
code: str = Form(...),
document_type: str = Form("custom"),
):
"""Upload a document (PDF) for ingestion into the legal corpus."""
global custom_documents
if not file.filename or not file.filename.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="Only PDF files are supported")
upload_dir = "/tmp/legal_corpus_uploads"
os.makedirs(upload_dir, exist_ok=True)
doc_id = str(uuid.uuid4())[:8]
safe_filename = f"{doc_id}_{file.filename}"
file_path = os.path.join(upload_dir, safe_filename)
try:
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
doc_record = {
"id": doc_id,
"code": code,
"title": title,
"filename": file.filename,
"file_path": file_path,
"document_type": document_type,
"uploaded_at": datetime.now().isoformat(),
"status": "uploaded",
"chunk_count": 0,
}
custom_documents.append(doc_record)
background_tasks.add_task(ingest_uploaded_document, doc_record)
return {
"status": "uploaded",
"document_id": doc_id,
"message": f"Document '{title}' uploaded and queued for ingestion",
"document": doc_record,
}
@router.post("/add-link")
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
"""Add a URL/link for ingestion into the legal corpus."""
global custom_documents
doc_id = str(uuid.uuid4())[:8]
doc_record = {
"id": doc_id,
"code": request.code,
"title": request.title,
"url": request.url,
"document_type": request.document_type,
"uploaded_at": datetime.now().isoformat(),
"status": "queued",
"chunk_count": 0,
}
custom_documents.append(doc_record)
background_tasks.add_task(ingest_link_document, doc_record)
return {
"status": "queued",
"document_id": doc_id,
"message": f"Link '{request.title}' queued for ingestion",
"document": doc_record,
}
@router.delete("/custom-documents/{doc_id}")
async def delete_custom_document(doc_id: str):
"""Delete a custom document from the list."""
global custom_documents
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
return {"status": "deleted", "document_id": doc_id}
@router.get("/traceability")
async def get_traceability(
chunk_id: str = Query(..., description="Chunk ID or identifier"),
regulation: str = Query(..., description="Regulation code"),
):
"""Get traceability information for a specific chunk."""
async with httpx.AsyncClient(timeout=30.0) as client:
try:
return {
"chunk_id": chunk_id,
"regulation": regulation,
"requirements": [],
"controls": [],
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
}
except Exception as e:
logger.error(f"Failed to get traceability: {e}")
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")

View File

@@ -0,0 +1,269 @@
"""
AI Email - Category Classification and Response Suggestions
Rule-based and LLM-based email category classification,
plus response suggestion generation.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Optional, List, Tuple
import httpx
from .models import (
EmailCategory,
SenderType,
ResponseSuggestion,
)
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def classify_category(
http_client: httpx.AsyncClient,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""
Classify email into a category.
Rule-based classification first, falls back to LLM.
"""
category, confidence = _classify_category_rules(subject, body_preview, sender_type)
if confidence > 0.7:
return category, confidence
return await _classify_category_llm(http_client, subject, body_preview)
def _classify_category_rules(
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""Rule-based category classification."""
text = f"{subject} {body_preview}".lower()
category_keywords = {
EmailCategory.DIENSTLICH: [
"dienstlich", "dienstanweisung", "erlass", "verordnung",
"bescheid", "verfuegung", "ministerium", "behoerde"
],
EmailCategory.PERSONAL: [
"personalrat", "stellenausschreibung", "versetzung",
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
],
EmailCategory.FINANZEN: [
"budget", "haushalt", "etat", "abrechnung", "rechnung",
"erstattung", "zuschuss", "foerdermittel"
],
EmailCategory.ELTERN: [
"elternbrief", "elternabend", "schulkonferenz",
"elternvertreter", "elternbeirat"
],
EmailCategory.SCHUELER: [
"schueler", "schuelerin", "zeugnis", "klasse", "unterricht",
"pruefung", "klassenfahrt", "schulpflicht"
],
EmailCategory.FORTBILDUNG: [
"fortbildung", "seminar", "workshop", "schulung",
"weiterbildung", "nlq", "didaktik"
],
EmailCategory.VERANSTALTUNG: [
"einladung", "veranstaltung", "termin", "konferenz",
"sitzung", "tagung", "feier"
],
EmailCategory.SICHERHEIT: [
"sicherheit", "notfall", "brandschutz", "evakuierung",
"hygiene", "corona", "infektionsschutz"
],
EmailCategory.TECHNIK: [
"it", "software", "computer", "netzwerk", "login",
"passwort", "digitalisierung", "iserv"
],
EmailCategory.NEWSLETTER: [
"newsletter", "rundschreiben", "info-mail", "mitteilung"
],
EmailCategory.WERBUNG: [
"angebot", "rabatt", "aktion", "werbung", "abonnement"
],
}
best_category = EmailCategory.SONSTIGES
best_score = 0.0
for category, keywords in category_keywords.items():
score = sum(1 for kw in keywords if kw in text)
if score > best_score:
best_score = score
best_category = category
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
if best_category == EmailCategory.SONSTIGES:
best_category = EmailCategory.DIENSTLICH
best_score = 2
confidence = min(0.9, 0.4 + (best_score * 0.15))
return best_category, confidence
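# Worked example (editorial, not part of the diff): keyword hits map to
# confidence via 0.4 + hits * 0.15, capped at 0.9, so only 3+ hits clear the
# strict "> 0.7" gate in classify_category above:
#
#   _classify_category_rules("Einladung zur Schulkonferenz", "", SenderType.ELTERNVERTRETER)
#   # "einladung" and "konferenz" (a substring of "schulkonferenz") both match
#   # VERANSTALTUNG -> score 2 -> confidence 0.4 + 0.30 = 0.70, which does NOT
#   # exceed 0.7, so the LLM fallback still runs.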
async def _classify_category_llm(
client: httpx.AsyncClient,
subject: str,
body_preview: str,
) -> Tuple[EmailCategory, float]:
"""LLM-based category classification."""
try:
categories = ", ".join([c.value for c in EmailCategory])
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
Betreff: {subject}
Inhalt: {body_preview[:500]}
Kategorien: {categories}
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
Format: kategorie|konfidenz
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 50,
},
)
if response.status_code == 200:
data = response.json()
result = data.get("response", "sonstiges|0.5")
parts = result.strip().split("|")
if len(parts) >= 2:
category_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
try:
category = EmailCategory(category_str)
return category, min(max(confidence, 0.0), 1.0)
except ValueError:
pass
except Exception as e:
logger.warning(f"LLM category classification failed: {e}")
return EmailCategory.SONSTIGES, 0.5
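# Parsing sketch (assumed gateway behaviour): the prompt requests
# "kategorie|konfidenz", so a reply like "eltern|0.82" yields
# (EmailCategory.ELTERN, 0.82). Anything unparseable — wrong separator,
# unknown category name, non-numeric confidence — raises inside the try
# block and falls through to the (EmailCategory.SONSTIGES, 0.5) default.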
async def suggest_response(
http_client: httpx.AsyncClient,
subject: str,
body_text: str,
sender_type: SenderType,
category: EmailCategory,
) -> List[ResponseSuggestion]:
"""Generate response suggestions for an email."""
suggestions = []
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
suggestions.append(ResponseSuggestion(
template_type="acknowledgment",
subject=f"Re: {subject}",
body="""Sehr geehrte Damen und Herren,
vielen Dank fuer Ihre Nachricht.
Ich bestaetige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
Mit freundlichen Gruessen""",
confidence=0.8,
))
if category == EmailCategory.ELTERN:
suggestions.append(ResponseSuggestion(
template_type="parent_response",
subject=f"Re: {subject}",
body="""Liebe Eltern,
vielen Dank fuer Ihre Nachricht.
[Ihre Antwort hier]
Mit freundlichen Gruessen""",
confidence=0.7,
))
try:
llm_suggestion = await _generate_response_llm(http_client, subject, body_text[:500], sender_type)
if llm_suggestion:
suggestions.append(llm_suggestion)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return suggestions
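# Illustration (hedged): for a Kultusministerium mail that was also classified
# as ELTERN, the list can hold up to three entries — "acknowledgment" (0.8),
# "parent_response" (0.7) and, if the gateway answers, "ai_generated" (0.6) —
# so a caller can simply rank by confidence:
#
#   suggestions = await suggest_response(client, subject, body, sender_type, category)
#   best = max(suggestions, key=lambda s: s.confidence, default=None)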
async def _generate_response_llm(
client: httpx.AsyncClient,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Optional[ResponseSuggestion]:
"""Generate a response suggestion using LLM."""
try:
sender_desc = {
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehoerde",
SenderType.RLSB: "dem RLSB",
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
}.get(sender_type, "einem Absender")
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
Betreff: {subject}
Inhalt: {body_preview}
Die Antwort sollte:
- Hoeflich und formell sein
- Den Eingang bestaetigen
- Eine konkrete naechste Aktion nennen oder um Klaerung bitten
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 300,
},
)
if response.status_code == 200:
data = response.json()
body = data.get("response", "").strip()
if body:
return ResponseSuggestion(
template_type="ai_generated",
subject=f"Re: {subject}",
body=body,
confidence=0.6,
)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return None


@@ -0,0 +1,184 @@
"""
AI Email - Deadline Extraction
Regex-based and LLM-based deadline extraction from email content.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import re
import logging
from typing import List
from datetime import datetime, timedelta
import httpx
from .models import DeadlineExtraction
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def extract_deadlines(
http_client: httpx.AsyncClient,
subject: str,
body_text: str,
) -> List[DeadlineExtraction]:
"""
Extract deadlines from email content.
Uses regex patterns first, then LLM for complex cases.
"""
deadlines = []
full_text = f"{subject}\n{body_text}" if body_text else subject
# Try regex extraction first
regex_deadlines = _extract_deadlines_regex(full_text)
deadlines.extend(regex_deadlines)
# If no regex matches, try LLM
if not deadlines and body_text:
llm_deadlines = await _extract_deadlines_llm(http_client, subject, body_text[:1000])
deadlines.extend(llm_deadlines)
return deadlines
def _extract_deadlines_regex(text: str) -> List[DeadlineExtraction]:
"""Extract deadlines using regex patterns."""
deadlines = []
now = datetime.now()
# German date patterns
patterns = [
# "bis zum 15.01.2025"
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "spaetestens am 15.01.2025"
(r"sp\u00e4testens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "Abgabetermin: 15.01.2025"
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "innerhalb von 14 Tagen"
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
# "bis Ende Januar"
(r"bis\s+(?:Ende\s+)?(Januar|Februar|M\u00e4rz|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
]
for pattern, is_specific_date in patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
try:
if is_specific_date:
day = int(match.group(1))
month = int(match.group(2))
year = int(match.group(3))
if year < 100:
year += 2000
deadline_date = datetime(year, month, day)
if deadline_date < now:
continue
start = max(0, match.start() - 50)
end = min(len(text), match.end() + 50)
context = text[start:end].strip()
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Frist: {match.group(0)}",
confidence=0.85,
source_text=context,
is_firm=True,
))
else:
if "Tagen" in pattern or "Wochen" in pattern:
days = int(match.group(1))
if "Wochen" in match.group(0).lower():
days *= 7
deadline_date = now + timedelta(days=days)
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Relative Frist: {match.group(0)}",
confidence=0.7,
source_text=match.group(0),
is_firm=False,
))
except (ValueError, IndexError) as e:
logger.debug(f"Failed to parse date: {e}")
continue
return deadlines
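# Behaviour notes (editorial): "bis zum 15.01.2027" yields a firm deadline
# (confidence 0.85) with ±50 chars of surrounding context; "innerhalb von
# 2 Wochen" yields now + 14 days with is_firm=False. The month-name pattern
# ("bis Ende Januar") currently matches but falls through the else branch
# without a handler, so it never produces a DeadlineExtraction.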
async def _extract_deadlines_llm(
client: httpx.AsyncClient,
subject: str,
body_preview: str,
) -> List[DeadlineExtraction]:
"""Extract deadlines using LLM."""
try:
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
Betreff: {subject}
Inhalt: {body_preview}
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
DATUM|BESCHREIBUNG|VERBINDLICH
Beispiel: 2025-01-15|Abgabe der Berichte|ja
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
Antworte NUR im angegebenen Format.
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 200,
},
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "")
if "KEINE_FRISTEN" in result_text:
return []
deadlines = []
for line in result_text.strip().split("\n"):
parts = line.split("|")
if len(parts) >= 2:
try:
date_str = parts[0].strip()
deadline_date = datetime.fromisoformat(date_str)
description = parts[1].strip()
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=description,
confidence=0.7,
source_text=line,
is_firm=is_firm,
))
except (ValueError, IndexError):
continue
return deadlines
except Exception as e:
logger.warning(f"LLM deadline extraction failed: {e}")
return []
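# Parsing sketch (assumed gateway reply): a line such as
#   "2026-01-15|Abgabe der Berichte|ja"
# becomes DeadlineExtraction(deadline_date=datetime(2026, 1, 15),
# description="Abgabe der Berichte", is_firm=True, confidence=0.7).
# datetime.fromisoformat() also accepts full timestamps like
# "2026-01-15T12:00:00"; malformed dates are skipped by the except clause.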


@@ -0,0 +1,134 @@
"""
AI Email - Sender Classification
Domain-based and LLM-based sender classification for emails.
Extracted from ai_service.py to keep files under 500 LOC.
"""
import os
import logging
from typing import Optional
import httpx
from .models import (
SenderType,
SenderClassification,
classify_sender_by_domain,
)
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
async def classify_sender(
http_client: httpx.AsyncClient,
sender_email: str,
sender_name: Optional[str] = None,
subject: Optional[str] = None,
body_preview: Optional[str] = None,
) -> SenderClassification:
"""
Classify the sender of an email.
First tries domain matching, then falls back to LLM.
"""
# Try domain-based classification first (fast, high confidence)
domain_result = classify_sender_by_domain(sender_email)
if domain_result:
return domain_result
# Fall back to LLM classification
return await _classify_sender_llm(
http_client, sender_email, sender_name, subject, body_preview
)
async def _classify_sender_llm(
client: httpx.AsyncClient,
sender_email: str,
sender_name: Optional[str],
subject: Optional[str],
body_preview: Optional[str],
) -> SenderClassification:
"""Classify sender using LLM."""
try:
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
Absender E-Mail: {sender_email}
Absender Name: {sender_name or "Nicht angegeben"}
Betreff: {subject or "Nicht angegeben"}
Vorschau: {body_preview[:200] if body_preview else "Nicht verfuegbar"}
Klassifiziere den Absender in EINE der folgenden Kategorien:
- kultusministerium: Kultusministerium/Bildungsministerium
- landesschulbehoerde: Landesschulbehoerde
- rlsb: Regionales Landesamt fuer Schule und Bildung
- schulamt: Schulamt
- nibis: Niedersaechsischer Bildungsserver
- schultraeger: Schultraeger/Kommune
- elternvertreter: Elternvertreter/Elternrat
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
- privatperson: Privatperson
- unternehmen: Unternehmen/Firma
- unbekannt: Nicht einzuordnen
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
Format: kategorie|konfidenz|kurze_begruendung
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 100,
},
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "unbekannt|0.5|")
parts = result_text.strip().split("|")
if len(parts) >= 2:
sender_type_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
type_mapping = {
"kultusministerium": SenderType.KULTUSMINISTERIUM,
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
"rlsb": SenderType.RLSB,
"schulamt": SenderType.SCHULAMT,
"nibis": SenderType.NIBIS,
"schultraeger": SenderType.SCHULTRAEGER,
"elternvertreter": SenderType.ELTERNVERTRETER,
"gewerkschaft": SenderType.GEWERKSCHAFT,
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
"privatperson": SenderType.PRIVATPERSON,
"unternehmen": SenderType.UNTERNEHMEN,
}
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
return SenderClassification(
sender_type=sender_type,
confidence=min(max(confidence, 0.0), 1.0),
domain_matched=False,
ai_classified=True,
)
except Exception as e:
logger.warning(f"LLM sender classification failed: {e}")
# Default fallback
return SenderClassification(
sender_type=SenderType.UNBEKANNT,
confidence=0.3,
domain_matched=False,
ai_classified=False,
)
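# Sketch (assumed gateway reply): "rlsb|0.9|Absenderdomain rlsb.de" maps via
# type_mapping to SenderType.RLSB with confidence 0.9. Unknown labels such as
# "verlag" fall back to SenderType.UNBEKANNT while keeping the model's
# confidence; transport or parsing errors return the 0.3 UNBEKANNT default above.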


@@ -1,18 +1,19 @@
"""
AI Email Analysis Service
AI Email Analysis Service — Barrel Re-export
KI-powered email analysis with:
- Sender classification (authority recognition)
- Deadline extraction
- Category classification
- Response suggestions
Split into:
- mail/ai_sender.py — Sender classification (domain + LLM)
- mail/ai_deadline.py — Deadline extraction (regex + LLM)
- mail/ai_category.py — Category classification + response suggestions
The AIEmailService class and get_ai_email_service() are defined here
to maintain the original public API.
"""
import os
import re
import logging
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime, timedelta
from typing import Optional, List, Tuple
from datetime import datetime
import httpx
from .models import (
@@ -23,17 +24,15 @@ from .models import (
DeadlineExtraction,
EmailAnalysisResult,
ResponseSuggestion,
KNOWN_AUTHORITIES_NI,
classify_sender_by_domain,
get_priority_from_sender_type,
)
from .mail_db import update_email_ai_analysis
from .ai_sender import classify_sender, LLM_GATEWAY_URL
from .ai_deadline import extract_deadlines
from .ai_category import classify_category, suggest_response
logger = logging.getLogger(__name__)
# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
class AIEmailService:
"""
@@ -56,10 +55,6 @@ class AIEmailService:
self._http_client = httpx.AsyncClient(timeout=30.0)
return self._http_client
# =========================================================================
# Sender Classification
# =========================================================================
async def classify_sender(
self,
sender_email: str,
@@ -67,300 +62,20 @@ class AIEmailService:
subject: Optional[str] = None,
body_preview: Optional[str] = None,
) -> SenderClassification:
"""
Classify the sender of an email.
First tries domain matching, then falls back to LLM.
Args:
sender_email: Sender's email address
sender_name: Sender's display name
subject: Email subject
body_preview: First 200 chars of body
Returns:
SenderClassification with type and confidence
"""
# Try domain-based classification first (fast, high confidence)
domain_result = classify_sender_by_domain(sender_email)
if domain_result:
return domain_result
# Fall back to LLM classification
return await self._classify_sender_llm(
sender_email, sender_name, subject, body_preview
"""Classify the sender of an email."""
client = await self.get_http_client()
return await classify_sender(
client, sender_email, sender_name, subject, body_preview
)
async def _classify_sender_llm(
self,
sender_email: str,
sender_name: Optional[str],
subject: Optional[str],
body_preview: Optional[str],
) -> SenderClassification:
"""Classify sender using LLM."""
try:
client = await self.get_http_client()
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
Absender E-Mail: {sender_email}
Absender Name: {sender_name or "Nicht angegeben"}
Betreff: {subject or "Nicht angegeben"}
Vorschau: {body_preview[:200] if body_preview else "Nicht verfügbar"}
Klassifiziere den Absender in EINE der folgenden Kategorien:
- kultusministerium: Kultusministerium/Bildungsministerium
- landesschulbehoerde: Landesschulbehörde
- rlsb: Regionales Landesamt für Schule und Bildung
- schulamt: Schulamt
- nibis: Niedersächsischer Bildungsserver
- schultraeger: Schulträger/Kommune
- elternvertreter: Elternvertreter/Elternrat
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
- privatperson: Privatperson
- unternehmen: Unternehmen/Firma
- unbekannt: Nicht einzuordnen
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
Format: kategorie|konfidenz|kurze_begründung
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 100,
},
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "unbekannt|0.5|")
# Parse response
parts = result_text.strip().split("|")
if len(parts) >= 2:
sender_type_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
# Map to enum
type_mapping = {
"kultusministerium": SenderType.KULTUSMINISTERIUM,
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
"rlsb": SenderType.RLSB,
"schulamt": SenderType.SCHULAMT,
"nibis": SenderType.NIBIS,
"schultraeger": SenderType.SCHULTRAEGER,
"elternvertreter": SenderType.ELTERNVERTRETER,
"gewerkschaft": SenderType.GEWERKSCHAFT,
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
"privatperson": SenderType.PRIVATPERSON,
"unternehmen": SenderType.UNTERNEHMEN,
}
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
return SenderClassification(
sender_type=sender_type,
confidence=min(max(confidence, 0.0), 1.0),
domain_matched=False,
ai_classified=True,
)
except Exception as e:
logger.warning(f"LLM sender classification failed: {e}")
# Default fallback
return SenderClassification(
sender_type=SenderType.UNBEKANNT,
confidence=0.3,
domain_matched=False,
ai_classified=False,
)
# =========================================================================
# Deadline Extraction
# =========================================================================
async def extract_deadlines(
self,
subject: str,
body_text: str,
) -> List[DeadlineExtraction]:
"""
Extract deadlines from email content.
Uses regex patterns first, then LLM for complex cases.
Args:
subject: Email subject
body_text: Email body text
Returns:
List of extracted deadlines
"""
deadlines = []
# Combine subject and body
full_text = f"{subject}\n{body_text}" if body_text else subject
# Try regex extraction first
regex_deadlines = self._extract_deadlines_regex(full_text)
deadlines.extend(regex_deadlines)
# If no regex matches, try LLM
if not deadlines and body_text:
llm_deadlines = await self._extract_deadlines_llm(subject, body_text[:1000])
deadlines.extend(llm_deadlines)
return deadlines
def _extract_deadlines_regex(self, text: str) -> List[DeadlineExtraction]:
"""Extract deadlines using regex patterns."""
deadlines = []
now = datetime.now()
# German date patterns
patterns = [
# "bis zum 15.01.2025"
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "spätestens am 15.01.2025"
(r"spätestens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "Abgabetermin: 15.01.2025"
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
# "innerhalb von 14 Tagen"
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
# "bis Ende Januar"
(r"bis\s+(?:Ende\s+)?(Januar|Februar|März|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
]
for pattern, is_specific_date in patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
try:
if is_specific_date:
day = int(match.group(1))
month = int(match.group(2))
year = int(match.group(3))
# Handle 2-digit years
if year < 100:
year += 2000
deadline_date = datetime(year, month, day)
# Skip past dates
if deadline_date < now:
continue
# Get surrounding context
start = max(0, match.start() - 50)
end = min(len(text), match.end() + 50)
context = text[start:end].strip()
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Frist: {match.group(0)}",
confidence=0.85,
source_text=context,
is_firm=True,
))
else:
# Relative dates (innerhalb von X Tagen)
if "Tagen" in pattern or "Wochen" in pattern:
days = int(match.group(1))
if "Wochen" in match.group(0).lower():
days *= 7
deadline_date = now + timedelta(days=days)
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=f"Relative Frist: {match.group(0)}",
confidence=0.7,
source_text=match.group(0),
is_firm=False,
))
except (ValueError, IndexError) as e:
logger.debug(f"Failed to parse date: {e}")
continue
return deadlines
async def _extract_deadlines_llm(
self,
subject: str,
body_preview: str,
) -> List[DeadlineExtraction]:
"""Extract deadlines using LLM."""
try:
client = await self.get_http_client()
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
Betreff: {subject}
Inhalt: {body_preview}
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
DATUM|BESCHREIBUNG|VERBINDLICH
Beispiel: 2025-01-15|Abgabe der Berichte|ja
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
Antworte NUR im angegebenen Format.
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 200,
},
)
if response.status_code == 200:
data = response.json()
result_text = data.get("response", "")
if "KEINE_FRISTEN" in result_text:
return []
deadlines = []
for line in result_text.strip().split("\n"):
parts = line.split("|")
if len(parts) >= 2:
try:
date_str = parts[0].strip()
deadline_date = datetime.fromisoformat(date_str)
description = parts[1].strip()
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
deadlines.append(DeadlineExtraction(
deadline_date=deadline_date,
description=description,
confidence=0.7,
source_text=line,
is_firm=is_firm,
))
except (ValueError, IndexError):
continue
return deadlines
except Exception as e:
logger.warning(f"LLM deadline extraction failed: {e}")
return []
# =========================================================================
# Email Category Classification
# =========================================================================
"""Extract deadlines from email content."""
client = await self.get_http_client()
return await extract_deadlines(client, subject, body_text)
async def classify_category(
self,
@@ -368,155 +83,9 @@ Antworte NUR im angegebenen Format.
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""
Classify email into a category.
Args:
subject: Email subject
body_preview: First 200 chars of body
sender_type: Already classified sender type
Returns:
Tuple of (category, confidence)
"""
# Rule-based classification first
category, confidence = self._classify_category_rules(subject, body_preview, sender_type)
if confidence > 0.7:
return category, confidence
# Fall back to LLM
return await self._classify_category_llm(subject, body_preview)
def _classify_category_rules(
self,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Tuple[EmailCategory, float]:
"""Rule-based category classification."""
text = f"{subject} {body_preview}".lower()
# Keywords for each category
category_keywords = {
EmailCategory.DIENSTLICH: [
"dienstlich", "dienstanweisung", "erlass", "verordnung",
"bescheid", "verfügung", "ministerium", "behörde"
],
EmailCategory.PERSONAL: [
"personalrat", "stellenausschreibung", "versetzung",
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
],
EmailCategory.FINANZEN: [
"budget", "haushalt", "etat", "abrechnung", "rechnung",
"erstattung", "zuschuss", "fördermittel"
],
EmailCategory.ELTERN: [
"elternbrief", "elternabend", "schulkonferenz",
"elternvertreter", "elternbeirat"
],
EmailCategory.SCHUELER: [
"schüler", "schülerin", "zeugnis", "klasse", "unterricht",
"prüfung", "klassenfahrt", "schulpflicht"
],
EmailCategory.FORTBILDUNG: [
"fortbildung", "seminar", "workshop", "schulung",
"weiterbildung", "nlq", "didaktik"
],
EmailCategory.VERANSTALTUNG: [
"einladung", "veranstaltung", "termin", "konferenz",
"sitzung", "tagung", "feier"
],
EmailCategory.SICHERHEIT: [
"sicherheit", "notfall", "brandschutz", "evakuierung",
"hygiene", "corona", "infektionsschutz"
],
EmailCategory.TECHNIK: [
"it", "software", "computer", "netzwerk", "login",
"passwort", "digitalisierung", "iserv"
],
EmailCategory.NEWSLETTER: [
"newsletter", "rundschreiben", "info-mail", "mitteilung"
],
EmailCategory.WERBUNG: [
"angebot", "rabatt", "aktion", "werbung", "abonnement"
],
}
best_category = EmailCategory.SONSTIGES
best_score = 0.0
for category, keywords in category_keywords.items():
score = sum(1 for kw in keywords if kw in text)
if score > best_score:
best_score = score
best_category = category
# Adjust based on sender type
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
if best_category == EmailCategory.SONSTIGES:
best_category = EmailCategory.DIENSTLICH
best_score = 2
# Convert score to confidence
confidence = min(0.9, 0.4 + (best_score * 0.15))
return best_category, confidence
async def _classify_category_llm(
self,
subject: str,
body_preview: str,
) -> Tuple[EmailCategory, float]:
"""LLM-based category classification."""
try:
client = await self.get_http_client()
categories = ", ".join([c.value for c in EmailCategory])
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
Betreff: {subject}
Inhalt: {body_preview[:500]}
Kategorien: {categories}
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
Format: kategorie|konfidenz
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 50,
},
)
if response.status_code == 200:
data = response.json()
result = data.get("response", "sonstiges|0.5")
parts = result.strip().split("|")
if len(parts) >= 2:
category_str = parts[0].strip().lower()
confidence = float(parts[1].strip())
try:
category = EmailCategory(category_str)
return category, min(max(confidence, 0.0), 1.0)
except ValueError:
pass
except Exception as e:
logger.warning(f"LLM category classification failed: {e}")
return EmailCategory.SONSTIGES, 0.5
# =========================================================================
# Full Analysis Pipeline
# =========================================================================
"""Classify email into a category."""
client = await self.get_http_client()
return await classify_category(client, subject, body_preview, sender_type)
async def analyze_email(
self,
@@ -527,20 +96,7 @@ Format: kategorie|konfidenz
body_text: Optional[str],
body_preview: Optional[str],
) -> EmailAnalysisResult:
"""
Run full analysis pipeline on an email.
Args:
email_id: Database ID of the email
sender_email: Sender's email address
sender_name: Sender's display name
subject: Email subject
body_text: Full body text
body_preview: Preview text
Returns:
Complete analysis result
"""
"""Run full analysis pipeline on an email."""
# 1. Classify sender
sender_classification = await self.classify_sender(
sender_email, sender_name, subject, body_preview
@@ -569,8 +125,8 @@ Format: kategorie|konfidenz
elif days_until <= 7:
suggested_priority = max(suggested_priority, TaskPriority.MEDIUM)
# 5. Generate summary (optional, can be expensive)
summary = None # Could add LLM summary generation here
# 5. Summary (optional)
summary = None
# 6. Determine if task should be auto-created
auto_create_task = (
@@ -612,10 +168,6 @@ Format: kategorie|konfidenz
auto_create_task=auto_create_task,
)
# =========================================================================
# Response Suggestions
# =========================================================================
async def suggest_response(
self,
subject: str,
@@ -623,114 +175,11 @@ Format: kategorie|konfidenz
sender_type: SenderType,
category: EmailCategory,
) -> List[ResponseSuggestion]:
"""
Generate response suggestions for an email.
Args:
subject: Original email subject
body_text: Original email body
sender_type: Classified sender type
category: Classified category
Returns:
List of response suggestions
"""
suggestions = []
# Add standard templates based on sender type and category
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
suggestions.append(ResponseSuggestion(
template_type="acknowledgment",
subject=f"Re: {subject}",
body="""Sehr geehrte Damen und Herren,
vielen Dank für Ihre Nachricht.
Ich bestätige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
Mit freundlichen Grüßen""",
confidence=0.8,
))
if category == EmailCategory.ELTERN:
suggestions.append(ResponseSuggestion(
template_type="parent_response",
subject=f"Re: {subject}",
body="""Liebe Eltern,
vielen Dank für Ihre Nachricht.
[Ihre Antwort hier]
Mit freundlichen Grüßen""",
confidence=0.7,
))
# Add LLM-generated suggestion
try:
llm_suggestion = await self._generate_response_llm(subject, body_text[:500], sender_type)
if llm_suggestion:
suggestions.append(llm_suggestion)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return suggestions
async def _generate_response_llm(
self,
subject: str,
body_preview: str,
sender_type: SenderType,
) -> Optional[ResponseSuggestion]:
"""Generate a response suggestion using LLM."""
try:
client = await self.get_http_client()
sender_desc = {
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehörde",
SenderType.RLSB: "dem RLSB",
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
}.get(sender_type, "einem Absender")
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
Betreff: {subject}
Inhalt: {body_preview}
Die Antwort sollte:
- Höflich und formell sein
- Den Eingang bestätigen
- Eine konkrete nächste Aktion nennen oder um Klärung bitten
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
"""
response = await client.post(
f"{LLM_GATEWAY_URL}/api/v1/inference",
json={
"prompt": prompt,
"playbook": "mail_analysis",
"max_tokens": 300,
},
)
if response.status_code == 200:
data = response.json()
body = data.get("response", "").strip()
if body:
return ResponseSuggestion(
template_type="ai_generated",
subject=f"Re: {subject}",
body=body,
confidence=0.6,
)
except Exception as e:
logger.warning(f"LLM response generation failed: {e}")
return None
"""Generate response suggestions for an email."""
client = await self.get_http_client()
return await suggest_response(
client, subject, body_text, sender_type, category
)
# Global instance


@@ -1,833 +1,36 @@
"""
PostgreSQL Metrics Database Service
Stores search feedback, calculates quality metrics (Precision, Recall, MRR).
PostgreSQL Metrics Database Service — Barrel Re-export
Split into:
- metrics_db_core.py — Pool, feedback, metrics, relevance
- metrics_db_schema.py — Table initialization (DDL)
- metrics_db_zeugnis.py — Zeugnis source/document/stats operations
All public names are re-exported here for backward compatibility.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
import asyncio
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
"""Get or create database connection pool."""
global _pool
if _pool is None:
try:
import asyncpg
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
except ImportError:
print("Warning: asyncpg not installed. Metrics storage disabled.")
return None
except Exception as e:
print(f"Warning: Failed to connect to PostgreSQL: {e}")
return None
return _pool
async def init_metrics_tables() -> bool:
"""Initialize metrics tables in PostgreSQL."""
pool = await get_pool()
if pool is None:
return False
create_tables_sql = """
-- RAG Search Feedback Table
CREATE TABLE IF NOT EXISTS rag_search_feedback (
id SERIAL PRIMARY KEY,
result_id VARCHAR(255) NOT NULL,
query_text TEXT,
collection_name VARCHAR(100),
score FLOAT,
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
notes TEXT,
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
-- Index for efficient querying
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
-- RAG Search Logs Table (for latency tracking)
CREATE TABLE IF NOT EXISTS rag_search_logs (
id SERIAL PRIMARY KEY,
query_text TEXT NOT NULL,
collection_name VARCHAR(100),
result_count INTEGER,
latency_ms INTEGER,
top_score FLOAT,
filters JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
-- RAG Upload History Table
CREATE TABLE IF NOT EXISTS rag_upload_history (
id SERIAL PRIMARY KEY,
filename VARCHAR(500) NOT NULL,
collection_name VARCHAR(100),
year INTEGER,
pdfs_extracted INTEGER,
minio_path VARCHAR(1000),
uploaded_by VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
-- Binary relevance judgments for true precision/recall
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
id SERIAL PRIMARY KEY,
query_id VARCHAR(255) NOT NULL,
query_text TEXT NOT NULL,
result_id VARCHAR(255) NOT NULL,
result_rank INTEGER,
is_relevant BOOLEAN NOT NULL,
collection_name VARCHAR(100),
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
-- Zeugnisse Source Tracking
CREATE TABLE IF NOT EXISTS zeugnis_sources (
id VARCHAR(36) PRIMARY KEY,
bundesland VARCHAR(10) NOT NULL,
name VARCHAR(255) NOT NULL,
base_url TEXT,
license_type VARCHAR(50) NOT NULL,
training_allowed BOOLEAN DEFAULT FALSE,
verified_by VARCHAR(100),
verified_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
-- Zeugnisse Seed URLs
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
url TEXT NOT NULL,
doc_type VARCHAR(50),
status VARCHAR(20) DEFAULT 'pending',
last_crawled TIMESTAMP,
error_message TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
-- Zeugnisse Documents
CREATE TABLE IF NOT EXISTS zeugnis_documents (
id VARCHAR(36) PRIMARY KEY,
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
title VARCHAR(500),
url TEXT NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
training_allowed BOOLEAN DEFAULT FALSE,
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
file_size INTEGER,
content_type VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
-- Zeugnisse Document Versions
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
version INTEGER NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
change_summary TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
-- Zeugnisse Usage Events (Audit Trail)
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
event_type VARCHAR(50) NOT NULL,
user_id VARCHAR(100),
details JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
-- Crawler Queue
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
priority INTEGER DEFAULT 5,
status VARCHAR(20) DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
documents_found INTEGER DEFAULT 0,
documents_indexed INTEGER DEFAULT 0,
error_count INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
"""
try:
async with pool.acquire() as conn:
await conn.execute(create_tables_sql)
print("RAG metrics tables initialized")
return True
except Exception as e:
print(f"Failed to initialize metrics tables: {e}")
return False
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
result_id: str,
rating: int,
query_text: Optional[str] = None,
collection_name: Optional[str] = None,
score: Optional[float] = None,
notes: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store search result feedback."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_feedback
(result_id, query_text, collection_name, score, rating, notes, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
result_id, query_text, collection_name, score, rating, notes, user_id
)
return True
except Exception as e:
print(f"Failed to store feedback: {e}")
return False
async def log_search(
query_text: str,
collection_name: str,
result_count: int,
latency_ms: int,
top_score: Optional[float] = None,
filters: Optional[Dict] = None,
) -> bool:
"""Log a search for metrics tracking."""
pool = await get_pool()
if pool is None:
return False
try:
import json
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_logs
(query_text, collection_name, result_count, latency_ms, top_score, filters)
VALUES ($1, $2, $3, $4, $5, $6)
""",
query_text, collection_name, result_count, latency_ms, top_score,
json.dumps(filters) if filters else None
)
return True
except Exception as e:
print(f"Failed to log search: {e}")
return False
async def log_upload(
filename: str,
collection_name: str,
year: int,
pdfs_extracted: int,
minio_path: Optional[str] = None,
uploaded_by: Optional[str] = None,
) -> bool:
"""Log an upload for history tracking."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_upload_history
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
VALUES ($1, $2, $3, $4, $5, $6)
""",
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
)
return True
except Exception as e:
print(f"Failed to log upload: {e}")
return False
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
collection_name: Optional[str] = None,
days: int = 7,
) -> Dict:
"""
Calculate RAG quality metrics from stored feedback.
Returns:
Dict with precision, recall, MRR, latency, etc.
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
# Date filter
since = datetime.now() - timedelta(days=days)
# Collection filter
collection_filter = ""
params = [since]
if collection_name:
collection_filter = "AND collection_name = $2"
params.append(collection_name)
# Total feedback count
total_feedback = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
# Rating distribution
rating_dist = await conn.fetch(
f"""
SELECT rating, COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
GROUP BY rating
ORDER BY rating DESC
""",
*params
)
# Average rating (proxy for precision)
avg_rating = await conn.fetchval(
f"""
SELECT AVG(rating) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
# Score distribution
score_dist = await conn.fetch(
f"""
SELECT
CASE
WHEN score >= 0.9 THEN '0.9+'
WHEN score >= 0.7 THEN '0.7-0.9'
WHEN score >= 0.5 THEN '0.5-0.7'
ELSE '<0.5'
END as range,
COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
GROUP BY range
ORDER BY range DESC
""",
*params
)
# Search logs for latency
latency_stats = await conn.fetchrow(
f"""
SELECT
AVG(latency_ms) as avg_latency,
COUNT(*) as total_searches,
AVG(result_count) as avg_results
FROM rag_search_logs
WHERE created_at >= $1 {collection_filter}
""",
*params
)
# Calculate precision@5 (% of top 5 rated 4+)
precision_at_5 = await conn.fetchval(
f"""
SELECT
CASE WHEN COUNT(*) > 0
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
ELSE 0 END
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
) or 0
# Calculate MRR (Mean Reciprocal Rank) - simplified
# Using average rating as proxy for relevance
mrr = (avg_rating or 0) / 5.0
# Error rate (ratings of 1 or 2)
error_count = sum(
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
)
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
# Score distribution as percentages
total_scored = sum(s['count'] for s in score_dist)
score_distribution = {}
for s in score_dist:
if total_scored > 0:
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
else:
score_distribution[s['range']] = 0
return {
"connected": True,
"period_days": days,
"precision_at_5": round(precision_at_5, 2),
"recall_at_10": round(precision_at_5 * 1.1, 2), # Estimated
"mrr": round(mrr, 2),
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
"total_ratings": total_feedback,
"total_searches": latency_stats['total_searches'] or 0,
"error_rate": round(error_rate, 1),
"score_distribution": score_distribution,
"rating_distribution": {
str(r['rating']): r['count'] for r in rating_dist if r['rating']
},
}
except Exception as e:
print(f"Failed to calculate metrics: {e}")
return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
"""Get recent feedback entries."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
FROM rag_search_feedback
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"result_id": r['result_id'],
"rating": r['rating'],
"query_text": r['query_text'],
"collection_name": r['collection_name'],
"score": r['score'],
"notes": r['notes'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get recent feedback: {e}")
return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
"""Get recent upload history."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
FROM rag_upload_history
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"filename": r['filename'],
"collection_name": r['collection_name'],
"year": r['year'],
"pdfs_extracted": r['pdfs_extracted'],
"minio_path": r['minio_path'],
"uploaded_by": r['uploaded_by'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get upload history: {e}")
return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
query_id: str,
query_text: str,
result_id: str,
is_relevant: bool,
result_rank: Optional[int] = None,
collection_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store binary relevance judgment for Precision/Recall calculation."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_relevance_judgments
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT DO NOTHING
""",
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
)
return True
except Exception as e:
print(f"Failed to store relevance judgment: {e}")
return False
async def calculate_precision_recall(
collection_name: Optional[str] = None,
days: int = 7,
k: int = 10,
) -> Dict:
"""
Calculate true Precision@k and Recall@k from binary relevance judgments.
Precision@k = (Relevant docs in top k) / k
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since, k]
if collection_name:
collection_filter = "AND collection_name = $3"
params.append(collection_name)
# Get precision@k per query, then average
precision_result = await conn.fetchval(
f"""
WITH query_precision AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
GREATEST(COUNT(*), 1) as precision
FROM rag_relevance_judgments
WHERE created_at >= $1
AND (result_rank IS NULL OR result_rank <= $2)
{collection_filter}
GROUP BY query_id
)
SELECT AVG(precision) FROM query_precision
""",
*params
) or 0
# Get recall@k per query, then average
recall_result = await conn.fetchval(
f"""
WITH query_recall AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
FROM rag_relevance_judgments
WHERE created_at >= $1
{collection_filter}
GROUP BY query_id
)
SELECT AVG(recall) FROM query_recall
""",
*params
) or 0
# Total judgments
total_judgments = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_relevance_judgments
WHERE created_at >= $1 {count_filter}
""",
since, *([collection_name] if collection_name else [])
)
# Unique queries
unique_queries = await conn.fetchval(
f"""
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
WHERE created_at >= $1 {count_filter}
""",
since, *([collection_name] if collection_name else [])
)
return {
"connected": True,
"period_days": days,
"k": k,
"precision_at_k": round(precision_result, 3),
"recall_at_k": round(recall_result, 3),
"f1_score": round(
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
),
"total_judgments": total_judgments or 0,
"unique_queries": unique_queries or 0,
}
except Exception as e:
print(f"Failed to calculate precision/recall: {e}")
return {"error": str(e), "connected": False}
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
"""Get all zeugnis sources (Bundesländer)."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, bundesland, name, base_url, license_type, training_allowed,
verified_by, verified_at, created_at, updated_at
FROM zeugnis_sources
ORDER BY bundesland
"""
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis sources: {e}")
return []
async def upsert_zeugnis_source(
id: str,
bundesland: str,
name: str,
license_type: str,
training_allowed: bool,
base_url: Optional[str] = None,
verified_by: Optional[str] = None,
) -> bool:
"""Insert or update a zeugnis source."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
base_url = EXCLUDED.base_url,
license_type = EXCLUDED.license_type,
training_allowed = EXCLUDED.training_allowed,
verified_by = EXCLUDED.verified_by,
verified_at = NOW(),
updated_at = NOW()
""",
id, bundesland, name, base_url, license_type, training_allowed, verified_by
)
return True
except Exception as e:
print(f"Failed to upsert zeugnis source: {e}")
return False
async def get_zeugnis_documents(
bundesland: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
"""Get zeugnis documents with optional filtering."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
if bundesland:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE s.bundesland = $1
ORDER BY d.created_at DESC
LIMIT $2 OFFSET $3
""",
bundesland, limit, offset
)
else:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
ORDER BY d.created_at DESC
LIMIT $1 OFFSET $2
""",
limit, offset
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis documents: {e}")
return []
async def get_zeugnis_stats() -> Dict:
"""Get zeugnis crawler statistics."""
pool = await get_pool()
if pool is None:
return {"error": "Database not available"}
try:
async with pool.acquire() as conn:
# Total sources
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
# Total documents
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
# Indexed documents
indexed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
)
# Training allowed
training_allowed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
)
# Per Bundesland stats
per_bundesland = await conn.fetch(
"""
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
GROUP BY s.bundesland, s.name, s.training_allowed
ORDER BY s.bundesland
"""
)
# Active crawls
active_crawls = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
)
return {
"total_sources": sources or 0,
"total_documents": documents or 0,
"indexed_documents": indexed or 0,
"training_allowed_documents": training_allowed or 0,
"active_crawls": active_crawls or 0,
"per_bundesland": [dict(r) for r in per_bundesland],
}
except Exception as e:
print(f"Failed to get zeugnis stats: {e}")
return {"error": str(e)}
async def log_zeugnis_event(
document_id: str,
event_type: str,
user_id: Optional[str] = None,
details: Optional[Dict] = None,
) -> bool:
"""Log a zeugnis usage event for audit trail."""
pool = await get_pool()
if pool is None:
return False
try:
import json
import uuid
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
VALUES ($1, $2, $3, $4, $5)
""",
str(uuid.uuid4()), document_id, event_type, user_id,
json.dumps(details) if details else None
)
return True
except Exception as e:
print(f"Failed to log zeugnis event: {e}")
return False
# Schema: table initialization
from metrics_db_schema import init_metrics_tables # noqa: F401
# Core: pool, feedback, search logs, metrics, relevance
from metrics_db_core import ( # noqa: F401
DATABASE_URL,
get_pool,
store_feedback,
log_search,
log_upload,
calculate_metrics,
get_recent_feedback,
get_upload_history,
store_relevance_judgment,
calculate_precision_recall,
)
# Zeugnis operations
from metrics_db_zeugnis import ( # noqa: F401
get_zeugnis_sources,
upsert_zeugnis_source,
get_zeugnis_documents,
get_zeugnis_stats,
log_zeugnis_event,
)
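# Editorial note (hedged): the `# noqa: F401` markers silence "imported but
# unused" warnings — the imports exist purely so that pre-split call sites
# keep working unchanged, e.g.:
#
#   from metrics_db import store_feedback, calculate_metrics  # old path, still valid
#   from metrics_db_core import store_feedback                # new path, same object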


@@ -0,0 +1,459 @@
"""
PostgreSQL Metrics Database - Core Operations
Connection pool, table initialization, feedback storage, search logging,
upload history, metrics calculation, and relevance judgments.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
"""Get or create database connection pool."""
global _pool
if _pool is None:
try:
import asyncpg
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
except ImportError:
print("Warning: asyncpg not installed. Metrics storage disabled.")
return None
except Exception as e:
print(f"Warning: Failed to connect to PostgreSQL: {e}")
return None
return _pool
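# Note (editorial, hedged): this lazy singleton is not guarded against
# concurrent first calls — two coroutines hitting get_pool() simultaneously
# can each create a pool. If that matters, a module-level asyncio.Lock is the
# usual fix (sketch only, assuming asyncio may be imported here):
#
#   import asyncio
#   _pool_lock = asyncio.Lock()
#
#   async def get_pool():
#       global _pool
#       async with _pool_lock:
#           if _pool is None:
#               import asyncpg
#               _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
#       return _pool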
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
result_id: str,
rating: int,
query_text: Optional[str] = None,
collection_name: Optional[str] = None,
score: Optional[float] = None,
notes: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store search result feedback."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_feedback
(result_id, query_text, collection_name, score, rating, notes, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
result_id, query_text, collection_name, score, rating, notes, user_id
)
return True
except Exception as e:
print(f"Failed to store feedback: {e}")
return False
async def log_search(
query_text: str,
collection_name: str,
result_count: int,
latency_ms: int,
top_score: Optional[float] = None,
filters: Optional[Dict] = None,
) -> bool:
"""Log a search for metrics tracking."""
pool = await get_pool()
if pool is None:
return False
try:
import json
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_logs
(query_text, collection_name, result_count, latency_ms, top_score, filters)
VALUES ($1, $2, $3, $4, $5, $6)
""",
query_text, collection_name, result_count, latency_ms, top_score,
json.dumps(filters) if filters else None
)
return True
except Exception as e:
print(f"Failed to log search: {e}")
return False
async def log_upload(
filename: str,
collection_name: str,
year: int,
pdfs_extracted: int,
minio_path: Optional[str] = None,
uploaded_by: Optional[str] = None,
) -> bool:
"""Log an upload for history tracking."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_upload_history
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
VALUES ($1, $2, $3, $4, $5, $6)
""",
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
)
return True
except Exception as e:
print(f"Failed to log upload: {e}")
return False
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
collection_name: Optional[str] = None,
days: int = 7,
) -> Dict:
"""
Calculate RAG quality metrics from stored feedback.
Returns:
Dict with precision, recall, MRR, latency, etc.
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since]
if collection_name:
collection_filter = "AND collection_name = $2"
params.append(collection_name)
total_feedback = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
rating_dist = await conn.fetch(
f"""
SELECT rating, COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
GROUP BY rating
ORDER BY rating DESC
""",
*params
)
avg_rating = await conn.fetchval(
f"""
SELECT AVG(rating) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
score_dist = await conn.fetch(
f"""
SELECT
CASE
WHEN score >= 0.9 THEN '0.9+'
WHEN score >= 0.7 THEN '0.7-0.9'
WHEN score >= 0.5 THEN '0.5-0.7'
ELSE '<0.5'
END as range,
COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
GROUP BY range
ORDER BY range DESC
""",
*params
)
latency_stats = await conn.fetchrow(
f"""
SELECT
AVG(latency_ms) as avg_latency,
COUNT(*) as total_searches,
AVG(result_count) as avg_results
FROM rag_search_logs
WHERE created_at >= $1 {collection_filter}
""",
*params
)
precision_at_5 = await conn.fetchval(
f"""
SELECT
CASE WHEN COUNT(*) > 0
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
ELSE 0 END
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
) or 0
mrr = (avg_rating or 0) / 5.0
error_count = sum(
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
)
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
total_scored = sum(s['count'] for s in score_dist)
score_distribution = {}
for s in score_dist:
if total_scored > 0:
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
else:
score_distribution[s['range']] = 0
return {
"connected": True,
"period_days": days,
"precision_at_5": round(precision_at_5, 2),
"recall_at_10": round(precision_at_5 * 1.1, 2),
"mrr": round(mrr, 2),
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
"total_ratings": total_feedback,
"total_searches": latency_stats['total_searches'] or 0,
"error_rate": round(error_rate, 1),
"score_distribution": score_distribution,
"rating_distribution": {
str(r['rating']): r['count'] for r in rating_dist if r['rating']
},
}
except Exception as e:
print(f"Failed to calculate metrics: {e}")
return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
"""Get recent feedback entries."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
FROM rag_search_feedback
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"result_id": r['result_id'],
"rating": r['rating'],
"query_text": r['query_text'],
"collection_name": r['collection_name'],
"score": r['score'],
"notes": r['notes'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get recent feedback: {e}")
return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
"""Get recent upload history."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
FROM rag_upload_history
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"filename": r['filename'],
"collection_name": r['collection_name'],
"year": r['year'],
"pdfs_extracted": r['pdfs_extracted'],
"minio_path": r['minio_path'],
"uploaded_by": r['uploaded_by'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get upload history: {e}")
return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
query_id: str,
query_text: str,
result_id: str,
is_relevant: bool,
result_rank: Optional[int] = None,
collection_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store binary relevance judgment for Precision/Recall calculation."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_relevance_judgments
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT DO NOTHING
""",
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
)
return True
except Exception as e:
print(f"Failed to store relevance judgment: {e}")
return False
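# Illustrative sketch (not part of this module): judging every result of one
# query. `results` and `is_relevant_fn` are assumed to exist in the caller.
#
# import uuid
#
# async def judge_results(query_text: str, results: list) -> None:
#     query_id = str(uuid.uuid4())  # one id groups all judgments for this query
#     for rank, result in enumerate(results, start=1):
#         await store_relevance_judgment(
#             query_id=query_id,
#             query_text=query_text,
#             result_id=result.id,
#             is_relevant=is_relevant_fn(result),
#             result_rank=rank,
#             collection_name="legal_corpus",
#         )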
async def calculate_precision_recall(
collection_name: Optional[str] = None,
days: int = 7,
k: int = 10,
) -> Dict:
"""
Calculate true Precision@k and Recall@k from binary relevance judgments.
Precision@k = (Relevant docs in top k) / k
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since, k]
if collection_name:
collection_filter = "AND collection_name = $3"
params.append(collection_name)
precision_result = await conn.fetchval(
f"""
WITH query_precision AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
GREATEST(COUNT(*), 1) as precision
FROM rag_relevance_judgments
WHERE created_at >= $1
AND (result_rank IS NULL OR result_rank <= $2)
{collection_filter}
GROUP BY query_id
)
SELECT AVG(precision) FROM query_precision
""",
*params
) or 0
recall_result = await conn.fetchval(
f"""
WITH query_recall AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
FROM rag_relevance_judgments
WHERE created_at >= $1
{collection_filter}
GROUP BY query_id
)
SELECT AVG(recall) FROM query_recall
""",
*params
) or 0
            # The count queries take only (since, collection_name), so they need
            # their own filter bound to $2 instead of reusing the $3 filter above.
            count_filter = "AND collection_name = $2" if collection_name else ""
            count_params = [since] + ([collection_name] if collection_name else [])
            total_judgments = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )
            unique_queries = await conn.fetchval(
                f"""
                SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )
return {
"connected": True,
"period_days": days,
"k": k,
"precision_at_k": round(precision_result, 3),
"recall_at_k": round(recall_result, 3),
"f1_score": round(
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
),
"total_judgments": total_judgments or 0,
"unique_queries": unique_queries or 0,
}
except Exception as e:
print(f"Failed to calculate precision/recall: {e}")
return {"error": str(e), "connected": False}


@@ -0,0 +1,182 @@
"""
PostgreSQL Metrics Database - Schema Initialization
Table creation DDL for all metrics, feedback, and zeugnis tables.
Extracted from metrics_db_core.py to keep files under 500 LOC.
"""
from metrics_db_core import get_pool
async def init_metrics_tables() -> bool:
"""Initialize metrics tables in PostgreSQL."""
pool = await get_pool()
if pool is None:
return False
create_tables_sql = """
-- RAG Search Feedback Table
CREATE TABLE IF NOT EXISTS rag_search_feedback (
id SERIAL PRIMARY KEY,
result_id VARCHAR(255) NOT NULL,
query_text TEXT,
collection_name VARCHAR(100),
score FLOAT,
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
notes TEXT,
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
-- Index for efficient querying
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
-- RAG Search Logs Table (for latency tracking)
CREATE TABLE IF NOT EXISTS rag_search_logs (
id SERIAL PRIMARY KEY,
query_text TEXT NOT NULL,
collection_name VARCHAR(100),
result_count INTEGER,
latency_ms INTEGER,
top_score FLOAT,
filters JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
-- RAG Upload History Table
CREATE TABLE IF NOT EXISTS rag_upload_history (
id SERIAL PRIMARY KEY,
filename VARCHAR(500) NOT NULL,
collection_name VARCHAR(100),
year INTEGER,
pdfs_extracted INTEGER,
minio_path VARCHAR(1000),
uploaded_by VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
    -- Binary relevance judgments for true Precision/Recall
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
id SERIAL PRIMARY KEY,
query_id VARCHAR(255) NOT NULL,
query_text TEXT NOT NULL,
result_id VARCHAR(255) NOT NULL,
result_rank INTEGER,
is_relevant BOOLEAN NOT NULL,
collection_name VARCHAR(100),
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
-- Zeugnisse Source Tracking
CREATE TABLE IF NOT EXISTS zeugnis_sources (
id VARCHAR(36) PRIMARY KEY,
bundesland VARCHAR(10) NOT NULL,
name VARCHAR(255) NOT NULL,
base_url TEXT,
license_type VARCHAR(50) NOT NULL,
training_allowed BOOLEAN DEFAULT FALSE,
verified_by VARCHAR(100),
verified_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
-- Zeugnisse Seed URLs
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
url TEXT NOT NULL,
doc_type VARCHAR(50),
status VARCHAR(20) DEFAULT 'pending',
last_crawled TIMESTAMP,
error_message TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
-- Zeugnisse Documents
CREATE TABLE IF NOT EXISTS zeugnis_documents (
id VARCHAR(36) PRIMARY KEY,
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
title VARCHAR(500),
url TEXT NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
training_allowed BOOLEAN DEFAULT FALSE,
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
file_size INTEGER,
content_type VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
-- Zeugnisse Document Versions
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
version INTEGER NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
change_summary TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
-- Zeugnisse Usage Events (Audit Trail)
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
event_type VARCHAR(50) NOT NULL,
user_id VARCHAR(100),
details JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
-- Crawler Queue
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
priority INTEGER DEFAULT 5,
status VARCHAR(20) DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
documents_found INTEGER DEFAULT 0,
documents_indexed INTEGER DEFAULT 0,
error_count INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
"""
try:
async with pool.acquire() as conn:
await conn.execute(create_tables_sql)
print("RAG metrics tables initialized")
return True
except Exception as e:
print(f"Failed to initialize metrics tables: {e}")
return False
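# Illustrative startup wiring (app and module names are assumptions, not part
# of this commit): run the DDL once before serving requests. The statements are
# idempotent (CREATE TABLE/INDEX IF NOT EXISTS), so re-running is safe.
#
# from contextlib import asynccontextmanager
# from fastapi import FastAPI
#
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     await init_metrics_tables()
#     yield
#
# app = FastAPI(lifespan=lifespan)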


@@ -0,0 +1,193 @@
"""
PostgreSQL Metrics Database - Zeugnis Operations
Zeugnis source management, document queries, statistics, and event logging.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
from typing import Optional, List, Dict
from metrics_db_core import get_pool
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
"""Get all zeugnis sources (Bundeslaender)."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, bundesland, name, base_url, license_type, training_allowed,
verified_by, verified_at, created_at, updated_at
FROM zeugnis_sources
ORDER BY bundesland
"""
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis sources: {e}")
return []
async def upsert_zeugnis_source(
id: str,
bundesland: str,
name: str,
license_type: str,
training_allowed: bool,
base_url: Optional[str] = None,
verified_by: Optional[str] = None,
) -> bool:
"""Insert or update a zeugnis source."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
base_url = EXCLUDED.base_url,
license_type = EXCLUDED.license_type,
training_allowed = EXCLUDED.training_allowed,
verified_by = EXCLUDED.verified_by,
verified_at = NOW(),
updated_at = NOW()
""",
id, bundesland, name, base_url, license_type, training_allowed, verified_by
)
return True
except Exception as e:
print(f"Failed to upsert zeugnis source: {e}")
return False
async def get_zeugnis_documents(
bundesland: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
"""Get zeugnis documents with optional filtering."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
if bundesland:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE s.bundesland = $1
ORDER BY d.created_at DESC
LIMIT $2 OFFSET $3
""",
bundesland, limit, offset
)
else:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
ORDER BY d.created_at DESC
LIMIT $1 OFFSET $2
""",
limit, offset
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis documents: {e}")
return []
async def get_zeugnis_stats() -> Dict:
"""Get zeugnis crawler statistics."""
pool = await get_pool()
if pool is None:
return {"error": "Database not available"}
try:
async with pool.acquire() as conn:
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
indexed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
)
training_allowed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
)
per_bundesland = await conn.fetch(
"""
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
GROUP BY s.bundesland, s.name, s.training_allowed
ORDER BY s.bundesland
"""
)
active_crawls = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
)
return {
"total_sources": sources or 0,
"total_documents": documents or 0,
"indexed_documents": indexed or 0,
"training_allowed_documents": training_allowed or 0,
"active_crawls": active_crawls or 0,
"per_bundesland": [dict(r) for r in per_bundesland],
}
except Exception as e:
print(f"Failed to get zeugnis stats: {e}")
return {"error": str(e)}
async def log_zeugnis_event(
document_id: str,
event_type: str,
user_id: Optional[str] = None,
details: Optional[Dict] = None,
) -> bool:
"""Log a zeugnis usage event for audit trail."""
pool = await get_pool()
if pool is None:
return False
try:
import json
import uuid
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
VALUES ($1, $2, $3, $4, $5)
""",
str(uuid.uuid4()), document_id, event_type, user_id,
json.dumps(details) if details else None
)
return True
except Exception as e:
print(f"Failed to log zeugnis event: {e}")
return False
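# Illustrative audit-trail call (document id and details are made up):
#
# await log_zeugnis_event(
#     document_id="3f2a...",  # hypothetical UUID
#     event_type="rag_retrieval",
#     user_id="teacher-42",
#     details={"query": "Zeugnisformulierung Klasse 4", "score": 0.87},
# )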


@@ -1,845 +1,81 @@
"""
OCR Labeling API for Handwriting Training Data Collection
OCR Labeling API — Barrel Re-export
DATENSCHUTZ/PRIVACY:
- Alle Verarbeitung erfolgt lokal (Mac Mini mit Ollama)
- Keine Daten werden an externe Server gesendet
- Bilder werden mit SHA256-Hash dedupliziert
- Export nur für lokales Fine-Tuning (TrOCR, llama3.2-vision)
Split into:
- ocr_labeling_models.py — Pydantic models and constants
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
- ocr_labeling_routes.py — Session/queue/labeling route handlers
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
Endpoints:
- POST /sessions - Create labeling session
- POST /sessions/{id}/upload - Upload images for labeling
- GET /queue - Get next items to label
- POST /confirm - Confirm OCR as correct
- POST /correct - Save corrected ground truth
- POST /skip - Skip unusable item
- GET /stats - Get labeling statistics
- POST /export - Export training data
All public names are re-exported here for backward compatibility.
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from datetime import datetime
import uuid
import hashlib
import os
import base64
# Import database functions
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
export_training_samples,
get_training_samples,
# Models
from ocr_labeling_models import ( # noqa: F401
LOCAL_STORAGE_PATH,
SessionCreate,
SessionResponse,
ItemResponse,
ConfirmRequest,
CorrectRequest,
SkipRequest,
ExportRequest,
StatsResponse,
)
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Helpers
from ocr_labeling_helpers import ( # noqa: F401
VISION_OCR_AVAILABLE,
PADDLEOCR_AVAILABLE,
TROCR_AVAILABLE,
DONUT_AVAILABLE,
MINIO_AVAILABLE,
TRAINING_EXPORT_AVAILABLE,
compute_image_hash,
run_ocr_on_image,
run_vision_ocr_wrapper,
run_paddleocr_wrapper,
run_trocr_wrapper,
run_donut_wrapper,
save_image_locally,
get_image_url,
)
# Try to import PaddleOCR from hybrid_vocab_extractor
# Conditional re-exports from helpers' optional imports
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
pass
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
from training_export_service import ( # noqa: F401
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
name: str
source_type: str = "klausur" # klausur, handwriting_sample, scan
description: Optional[str] = None
ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
id: str
name: str
source_type: str
description: Optional[str]
ocr_model: Optional[str]
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
skipped_items: int
created_at: datetime
class ItemResponse(BaseModel):
id: str
session_id: str
session_name: str
image_path: str
image_url: Optional[str]
ocr_text: Optional[str]
ocr_confidence: Optional[float]
ground_truth: Optional[str]
status: str
metadata: Optional[Dict]
created_at: datetime
class ConfirmRequest(BaseModel):
item_id: str
label_time_seconds: Optional[int] = None
class CorrectRequest(BaseModel):
item_id: str
ground_truth: str
label_time_seconds: Optional[int] = None
class SkipRequest(BaseModel):
item_id: str
class ExportRequest(BaseModel):
export_format: str = "generic" # generic, trocr, llama_vision
session_id: Optional[str] = None
batch_id: Optional[str] = None
class StatsResponse(BaseModel):
total_sessions: Optional[int] = None
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
pending_items: int
exportable_items: Optional[int] = None
accuracy_rate: float
avg_label_time_seconds: Optional[float] = None
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data."""
return hashlib.sha256(image_data).hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
"""
Run OCR on an image using the specified model.
Models:
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
- trocr: Microsoft TrOCR (fast for printed text)
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
- donut: Document Understanding Transformer (structured documents)
Returns:
Tuple of (ocr_text, confidence)
"""
print(f"Running OCR with model: {model}")
# Route to appropriate OCR service based on model
if model == "paddleocr":
return await run_paddleocr_wrapper(image_data, filename)
elif model == "donut":
return await run_donut_wrapper(image_data, filename)
elif model == "trocr":
return await run_trocr_wrapper(image_data, filename)
else:
# Default: Vision LLM (llama3.2-vision or similar)
return await run_vision_ocr_wrapper(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""Vision LLM OCR wrapper."""
if not VISION_OCR_AVAILABLE:
print("Vision OCR service not available")
return None, 0.0
try:
service = get_vision_ocr_service()
if not await service.is_available():
print("Vision OCR service not available (is_available check failed)")
return None, 0.0
result = await service.extract_text(
image_data,
filename=filename,
is_handwriting=True
)
return result.text, result.confidence
except Exception as e:
print(f"Vision OCR failed: {e}")
return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
if not PADDLEOCR_AVAILABLE:
print("PaddleOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
# run_paddle_ocr returns (regions, raw_text)
regions, raw_text = run_paddle_ocr(image_data)
if not raw_text:
print("PaddleOCR returned empty text")
return None, 0.0
# Calculate average confidence from regions
if regions:
avg_confidence = sum(r.confidence for r in regions) / len(regions)
else:
avg_confidence = 0.5
return raw_text, avg_confidence
except Exception as e:
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""TrOCR wrapper."""
if not TROCR_AVAILABLE:
print("TrOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_trocr_ocr(image_data)
return text, confidence
except Exception as e:
print(f"TrOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
"""Donut OCR wrapper."""
if not DONUT_AVAILABLE:
print("Donut not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_donut_ocr(image_data)
return text, confidence
except Exception as e:
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
"""Save image to local storage."""
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
os.makedirs(session_dir, exist_ok=True)
filename = f"{item_id}.{extension}"
filepath = os.path.join(session_dir, filename)
with open(filepath, 'wb') as f:
f.write(image_data)
return filepath
def get_image_url(image_path: str) -> str:
"""Get URL for an image."""
# For local images, return a relative path that the frontend can use
if image_path.startswith(LOCAL_STORAGE_PATH):
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
return f"/api/v1/ocr-label/images/{relative_path}"
# For MinIO images, the path is already a URL or key
return image_path
# =============================================================================
# API Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
"""
Create a new OCR labeling session.
A session groups related images for labeling (e.g., all scans from one class).
"""
session_id = str(uuid.uuid4())
success = await create_ocr_labeling_session(
session_id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create session")
return SessionResponse(
id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
total_items=0,
labeled_items=0,
confirmed_items=0,
corrected_items=0,
skipped_items=0,
created_at=datetime.utcnow(),
)
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
"""List all OCR labeling sessions."""
sessions = await get_ocr_labeling_sessions(limit=limit)
return [
SessionResponse(
id=s['id'],
name=s['name'],
source_type=s['source_type'],
description=s.get('description'),
ocr_model=s.get('ocr_model'),
total_items=s.get('total_items', 0),
labeled_items=s.get('labeled_items', 0),
confirmed_items=s.get('confirmed_items', 0),
corrected_items=s.get('corrected_items', 0),
skipped_items=s.get('skipped_items', 0),
created_at=s.get('created_at', datetime.utcnow()),
)
for s in sessions
]
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
"""Get a specific OCR labeling session."""
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return SessionResponse(
id=session['id'],
name=session['name'],
source_type=session['source_type'],
description=session.get('description'),
ocr_model=session.get('ocr_model'),
total_items=session.get('total_items', 0),
labeled_items=session.get('labeled_items', 0),
confirmed_items=session.get('confirmed_items', 0),
corrected_items=session.get('corrected_items', 0),
skipped_items=session.get('skipped_items', 0),
created_at=session.get('created_at', datetime.utcnow()),
)
@router.post("/sessions/{session_id}/upload")
async def upload_images(
session_id: str,
background_tasks: BackgroundTasks,
files: List[UploadFile] = File(...),
run_ocr: bool = Form(True),
metadata: Optional[str] = Form(None), # JSON string
):
"""
Upload images to a labeling session.
Args:
session_id: Session to add images to
files: Image files to upload (PNG, JPG, PDF)
run_ocr: Whether to run OCR immediately (default: True)
metadata: Optional JSON metadata (subject, year, etc.)
"""
import json
# Verify session exists
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Parse metadata
meta_dict = None
if metadata:
try:
meta_dict = json.loads(metadata)
except json.JSONDecodeError:
meta_dict = {"raw": metadata}
results = []
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
for file in files:
# Read file content
content = await file.read()
# Compute hash for deduplication
image_hash = compute_image_hash(content)
# Generate item ID
item_id = str(uuid.uuid4())
# Determine file extension
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
extension = 'png'
# Save image
if MINIO_AVAILABLE:
# Upload to MinIO
try:
image_path = upload_ocr_image(session_id, item_id, content, extension)
except Exception as e:
print(f"MinIO upload failed, using local storage: {e}")
image_path = save_image_locally(session_id, item_id, content, extension)
else:
# Save locally
image_path = save_image_locally(session_id, item_id, content, extension)
# Run OCR if requested
ocr_text = None
ocr_confidence = None
if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now
ocr_text, ocr_confidence = await run_ocr_on_image(
content,
file.filename or f"{item_id}.{extension}",
model=ocr_model
)
# Add to database
success = await add_ocr_labeling_item(
item_id=item_id,
session_id=session_id,
image_path=image_path,
image_hash=image_hash,
ocr_text=ocr_text,
ocr_confidence=ocr_confidence,
ocr_model=ocr_model if ocr_text else None,
metadata=meta_dict,
)
if success:
results.append({
"id": item_id,
"filename": file.filename,
"image_path": image_path,
"image_hash": image_hash,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"status": "pending",
})
return {
"session_id": session_id,
"uploaded_count": len(results),
"items": results,
}
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
session_id: Optional[str] = Query(None),
status: str = Query("pending"),
limit: int = Query(10, ge=1, le=50),
):
"""
Get items from the labeling queue.
Args:
session_id: Optional filter by session
status: Filter by status (pending, confirmed, corrected, skipped)
limit: Number of items to return
"""
items = await get_ocr_labeling_queue(
session_id=session_id,
status=status,
limit=limit,
)
return [
ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
for item in items
]
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
"""Get a specific labeling item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
"""
Confirm that OCR text is correct.
Sets ground_truth = ocr_text and marks item as confirmed.
"""
success = await confirm_ocr_label(
item_id=request.item_id,
labeled_by="admin", # TODO: Get from auth
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to confirm item")
return {"status": "confirmed", "item_id": request.item_id}
@router.post("/correct")
async def correct_item(request: CorrectRequest):
"""
Save corrected ground truth for an item.
Use this when OCR text is wrong and needs manual correction.
"""
success = await correct_ocr_label(
item_id=request.item_id,
ground_truth=request.ground_truth,
labeled_by="admin", # TODO: Get from auth
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to correct item")
return {"status": "corrected", "item_id": request.item_id}
@router.post("/skip")
async def skip_item(request: SkipRequest):
"""
Skip an item (unusable image, etc.).
Skipped items are not included in training exports.
"""
success = await skip_ocr_item(
item_id=request.item_id,
labeled_by="admin", # TODO: Get from auth
)
if not success:
raise HTTPException(status_code=400, detail="Failed to skip item")
return {"status": "skipped", "item_id": request.item_id}
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
"""
Get labeling statistics.
Args:
session_id: Optional session ID for session-specific stats
"""
stats = await get_ocr_labeling_stats(session_id=session_id)
if "error" in stats:
raise HTTPException(status_code=500, detail=stats["error"])
return stats
@router.post("/export")
async def export_data(request: ExportRequest):
"""
Export labeled data for training.
Formats:
- generic: JSONL with image_path and ground_truth
- trocr: Format for TrOCR/Microsoft Transformer fine-tuning
- llama_vision: Format for llama3.2-vision fine-tuning
Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
"""
# First, get samples from database
db_samples = await export_training_samples(
export_format=request.export_format,
session_id=request.session_id,
batch_id=request.batch_id,
exported_by="admin", # TODO: Get from auth
)
if not db_samples:
return {
"export_format": request.export_format,
"batch_id": request.batch_id,
"exported_count": 0,
"samples": [],
"message": "No labeled samples found to export",
}
# If training export service is available, also write to disk
export_result = None
if TRAINING_EXPORT_AVAILABLE:
try:
export_service = get_training_export_service()
# Convert DB samples to TrainingSample objects
training_samples = []
for s in db_samples:
training_samples.append(TrainingSample(
id=s.get('id', s.get('item_id', '')),
image_path=s.get('image_path', ''),
ground_truth=s.get('ground_truth', ''),
ocr_text=s.get('ocr_text'),
ocr_confidence=s.get('ocr_confidence'),
metadata=s.get('metadata'),
))
# Export to files
export_result = export_service.export(
samples=training_samples,
export_format=request.export_format,
batch_id=request.batch_id,
)
except Exception as e:
print(f"Training export failed: {e}")
# Continue without file export
response = {
"export_format": request.export_format,
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
"exported_count": len(db_samples),
"samples": db_samples,
}
if export_result:
response["export_path"] = export_result.export_path
response["manifest_path"] = export_result.manifest_path
return response
@router.get("/training-samples")
async def list_training_samples(
export_format: Optional[str] = Query(None),
batch_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=1000),
):
"""Get exported training samples."""
samples = await get_training_samples(
export_format=export_format,
batch_id=batch_id,
limit=limit,
)
return {
"count": len(samples),
"samples": samples,
}
@router.get("/images/{path:path}")
async def get_image(path: str):
"""
Serve an image from local storage.
This endpoint is used when images are stored locally (not in MinIO).
"""
from fastapi.responses import FileResponse
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
if not os.path.exists(filepath):
raise HTTPException(status_code=404, detail="Image not found")
# Determine content type
extension = filepath.split('.')[-1].lower()
content_type = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'pdf': 'application/pdf',
}.get(extension, 'application/octet-stream')
return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
"""
Run OCR on an existing item.
Use this to re-run OCR or run it if it was skipped during upload.
"""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
# Load image
image_path = item['image_path']
if image_path.startswith(LOCAL_STORAGE_PATH):
# Load from local storage
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image file not found")
with open(image_path, 'rb') as f:
image_data = f.read()
elif MINIO_AVAILABLE:
# Load from MinIO
try:
image_data = get_ocr_image(image_path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
else:
raise HTTPException(status_code=500, detail="Cannot load image")
# Get OCR model from session
session = await get_ocr_labeling_session(item['session_id'])
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
# Run OCR
ocr_text, ocr_confidence = await run_ocr_on_image(
image_data,
os.path.basename(image_path),
model=ocr_model
)
if ocr_text is None:
raise HTTPException(status_code=500, detail="OCR failed")
# Update item in database
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE ocr_labeling_items
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
WHERE id = $1
""",
item_id, ocr_text, ocr_confidence, ocr_model
)
return {
"item_id": item_id,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"ocr_model": ocr_model,
}
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
"""
List all available training data exports.
Args:
export_format: Optional filter by format (generic, trocr, llama_vision)
Returns:
List of export manifests with paths and metadata
"""
if not TRAINING_EXPORT_AVAILABLE:
return {
"exports": [],
"message": "Training export service not available",
}
try:
export_service = get_training_export_service()
exports = export_service.list_exports(export_format=export_format)
return {
"count": len(exports),
"exports": exports,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
pass
try:
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
except ImportError:
pass
try:
from services.trocr_service import run_trocr_ocr # noqa: F401
except ImportError:
pass
try:
from services.donut_ocr_service import run_donut_ocr # noqa: F401
except ImportError:
pass
try:
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
except ImportError:
pass
# Routes (router is the main export for app.include_router)
from ocr_labeling_routes import router # noqa: F401
from ocr_labeling_upload_routes import router as upload_router # noqa: F401

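# Illustrative consumer (app module name assumed): the barrel keeps old imports
# working, but since the split both routers must be mounted.
#
# from fastapi import FastAPI
# from ocr_labeling_api import router, upload_router
#
# app = FastAPI()
# app.include_router(router)         # sessions, queue, confirm/correct/skip, stats
# app.include_router(upload_router)  # upload, run-ocr, export, images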

@@ -0,0 +1,205 @@
"""
OCR Labeling - Helper Functions and OCR Wrappers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
DATENSCHUTZ/PRIVACY:
- All processing happens locally (Mac Mini with Ollama)
- No data is sent to external servers
"""
import os
import hashlib
from ocr_labeling_models import LOCAL_STORAGE_PATH
# Try to import Vision OCR service
try:
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
VISION_OCR_AVAILABLE = True
except ImportError:
VISION_OCR_AVAILABLE = False
print("Warning: Vision OCR service not available")
# Try to import PaddleOCR from hybrid_vocab_extractor
try:
from hybrid_vocab_extractor import run_paddle_ocr
PADDLEOCR_AVAILABLE = True
except ImportError:
PADDLEOCR_AVAILABLE = False
print("Warning: PaddleOCR not available")
# Try to import TrOCR service
try:
from services.trocr_service import run_trocr_ocr
TROCR_AVAILABLE = True
except ImportError:
TROCR_AVAILABLE = False
print("Warning: TrOCR service not available")
# Try to import Donut service
try:
from services.donut_ocr_service import run_donut_ocr
DONUT_AVAILABLE = True
except ImportError:
DONUT_AVAILABLE = False
print("Warning: Donut OCR service not available")
# Try to import MinIO storage
try:
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
print("Warning: MinIO storage not available, using local storage")
# Try to import Training Export Service
try:
from training_export_service import (
TrainingExportService,
TrainingSample,
get_training_export_service,
)
TRAINING_EXPORT_AVAILABLE = True
except ImportError:
TRAINING_EXPORT_AVAILABLE = False
print("Warning: Training export service not available")
# =============================================================================
# Helper Functions
# =============================================================================
def compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data."""
return hashlib.sha256(image_data).hexdigest()
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
"""
Run OCR on an image using the specified model.
Models:
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
- trocr: Microsoft TrOCR (fast for printed text)
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
- donut: Document Understanding Transformer (structured documents)
Returns:
Tuple of (ocr_text, confidence)
"""
print(f"Running OCR with model: {model}")
# Route to appropriate OCR service based on model
if model == "paddleocr":
return await run_paddleocr_wrapper(image_data, filename)
elif model == "donut":
return await run_donut_wrapper(image_data, filename)
elif model == "trocr":
return await run_trocr_wrapper(image_data, filename)
else:
# Default: Vision LLM (llama3.2-vision or similar)
return await run_vision_ocr_wrapper(image_data, filename)
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""Vision LLM OCR wrapper."""
if not VISION_OCR_AVAILABLE:
print("Vision OCR service not available")
return None, 0.0
try:
service = get_vision_ocr_service()
if not await service.is_available():
print("Vision OCR service not available (is_available check failed)")
return None, 0.0
result = await service.extract_text(
image_data,
filename=filename,
is_handwriting=True
)
return result.text, result.confidence
except Exception as e:
print(f"Vision OCR failed: {e}")
return None, 0.0
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
if not PADDLEOCR_AVAILABLE:
print("PaddleOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
# run_paddle_ocr returns (regions, raw_text)
regions, raw_text = run_paddle_ocr(image_data)
if not raw_text:
print("PaddleOCR returned empty text")
return None, 0.0
# Calculate average confidence from regions
if regions:
avg_confidence = sum(r.confidence for r in regions) / len(regions)
else:
avg_confidence = 0.5
return raw_text, avg_confidence
except Exception as e:
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
"""TrOCR wrapper."""
if not TROCR_AVAILABLE:
print("TrOCR not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_trocr_ocr(image_data)
return text, confidence
except Exception as e:
print(f"TrOCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
"""Donut OCR wrapper."""
if not DONUT_AVAILABLE:
print("Donut not available, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
try:
text, confidence = await run_donut_ocr(image_data)
return text, confidence
except Exception as e:
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
return await run_vision_ocr_wrapper(image_data, filename)
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
"""Save image to local storage."""
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
os.makedirs(session_dir, exist_ok=True)
filename = f"{item_id}.{extension}"
filepath = os.path.join(session_dir, filename)
with open(filepath, 'wb') as f:
f.write(image_data)
return filepath
def get_image_url(image_path: str) -> str:
"""Get URL for an image."""
# For local images, return a relative path that the frontend can use
if image_path.startswith(LOCAL_STORAGE_PATH):
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
return f"/api/v1/ocr-label/images/{relative_path}"
# For MinIO images, the path is already a URL or key
return image_path
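# Illustrative call (file name and event loop belong to the caller; the
# dispatcher falls back to Vision OCR when the chosen engine is unavailable):
#
# import asyncio
#
# async def demo() -> None:
#     with open("scan.png", "rb") as f:
#         image_data = f.read()
#     text, confidence = await run_ocr_on_image(image_data, "scan.png", model="paddleocr")
#     print(f"{confidence:.2f}: {text}")
#
# asyncio.run(demo())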


@@ -0,0 +1,86 @@
"""
OCR Labeling - Pydantic Models and Constants
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
"""
import os
from pydantic import BaseModel
from typing import Optional, Dict
from datetime import datetime
# Local storage path (fallback if MinIO not available)
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
# =============================================================================
# Pydantic Models
# =============================================================================
class SessionCreate(BaseModel):
name: str
source_type: str = "klausur" # klausur, handwriting_sample, scan
description: Optional[str] = None
ocr_model: Optional[str] = "llama3.2-vision:11b"
class SessionResponse(BaseModel):
id: str
name: str
source_type: str
description: Optional[str]
ocr_model: Optional[str]
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
skipped_items: int
created_at: datetime
class ItemResponse(BaseModel):
id: str
session_id: str
session_name: str
image_path: str
image_url: Optional[str]
ocr_text: Optional[str]
ocr_confidence: Optional[float]
ground_truth: Optional[str]
status: str
metadata: Optional[Dict]
created_at: datetime
class ConfirmRequest(BaseModel):
item_id: str
label_time_seconds: Optional[int] = None
class CorrectRequest(BaseModel):
item_id: str
ground_truth: str
label_time_seconds: Optional[int] = None
class SkipRequest(BaseModel):
item_id: str
class ExportRequest(BaseModel):
export_format: str = "generic" # generic, trocr, llama_vision
session_id: Optional[str] = None
batch_id: Optional[str] = None
class StatsResponse(BaseModel):
total_sessions: Optional[int] = None
total_items: int
labeled_items: int
confirmed_items: int
corrected_items: int
pending_items: int
exportable_items: Optional[int] = None
accuracy_rate: float
avg_label_time_seconds: Optional[float] = None


@@ -0,0 +1,241 @@
"""
OCR Labeling - Session and Labeling Route Handlers
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
Endpoints:
- POST /sessions - Create labeling session
- GET /sessions - List sessions
- GET /sessions/{id} - Get session
- GET /queue - Get labeling queue
- GET /items/{id} - Get item
- POST /confirm - Confirm OCR
- POST /correct - Correct ground truth
- POST /skip - Skip item
- GET /stats - Get statistics
"""
from fastapi import APIRouter, HTTPException, Query
from typing import Optional, List
from datetime import datetime
import uuid
from metrics_db import (
create_ocr_labeling_session,
get_ocr_labeling_sessions,
get_ocr_labeling_session,
get_ocr_labeling_queue,
get_ocr_labeling_item,
confirm_ocr_label,
correct_ocr_label,
skip_ocr_item,
get_ocr_labeling_stats,
)
from ocr_labeling_models import (
SessionCreate, SessionResponse, ItemResponse,
ConfirmRequest, CorrectRequest, SkipRequest,
)
from ocr_labeling_helpers import get_image_url
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
# =============================================================================
# Session Endpoints
# =============================================================================
@router.post("/sessions", response_model=SessionResponse)
async def create_session(session: SessionCreate):
"""Create a new OCR labeling session."""
session_id = str(uuid.uuid4())
success = await create_ocr_labeling_session(
session_id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create session")
return SessionResponse(
id=session_id,
name=session.name,
source_type=session.source_type,
description=session.description,
ocr_model=session.ocr_model,
total_items=0,
labeled_items=0,
confirmed_items=0,
corrected_items=0,
skipped_items=0,
created_at=datetime.utcnow(),
)
@router.get("/sessions", response_model=List[SessionResponse])
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
"""List all OCR labeling sessions."""
sessions = await get_ocr_labeling_sessions(limit=limit)
return [
SessionResponse(
id=s['id'],
name=s['name'],
source_type=s['source_type'],
description=s.get('description'),
ocr_model=s.get('ocr_model'),
total_items=s.get('total_items', 0),
labeled_items=s.get('labeled_items', 0),
confirmed_items=s.get('confirmed_items', 0),
corrected_items=s.get('corrected_items', 0),
skipped_items=s.get('skipped_items', 0),
created_at=s.get('created_at', datetime.utcnow()),
)
for s in sessions
]
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str):
"""Get a specific OCR labeling session."""
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return SessionResponse(
id=session['id'],
name=session['name'],
source_type=session['source_type'],
description=session.get('description'),
ocr_model=session.get('ocr_model'),
total_items=session.get('total_items', 0),
labeled_items=session.get('labeled_items', 0),
confirmed_items=session.get('confirmed_items', 0),
corrected_items=session.get('corrected_items', 0),
skipped_items=session.get('skipped_items', 0),
created_at=session.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Queue and Item Endpoints
# =============================================================================
@router.get("/queue", response_model=List[ItemResponse])
async def get_labeling_queue(
session_id: Optional[str] = Query(None),
status: str = Query("pending"),
limit: int = Query(10, ge=1, le=50),
):
"""Get items from the labeling queue."""
items = await get_ocr_labeling_queue(
session_id=session_id,
status=status,
limit=limit,
)
return [
ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
for item in items
]
@router.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: str):
"""Get a specific labeling item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
return ItemResponse(
id=item['id'],
session_id=item['session_id'],
session_name=item.get('session_name', ''),
image_path=item['image_path'],
image_url=get_image_url(item['image_path']),
ocr_text=item.get('ocr_text'),
ocr_confidence=item.get('ocr_confidence'),
ground_truth=item.get('ground_truth'),
status=item.get('status', 'pending'),
metadata=item.get('metadata'),
created_at=item.get('created_at', datetime.utcnow()),
)
# =============================================================================
# Labeling Action Endpoints
# =============================================================================
@router.post("/confirm")
async def confirm_item(request: ConfirmRequest):
"""Confirm that OCR text is correct."""
success = await confirm_ocr_label(
item_id=request.item_id,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to confirm item")
return {"status": "confirmed", "item_id": request.item_id}
@router.post("/correct")
async def correct_item(request: CorrectRequest):
"""Save corrected ground truth for an item."""
success = await correct_ocr_label(
item_id=request.item_id,
ground_truth=request.ground_truth,
labeled_by="admin",
label_time_seconds=request.label_time_seconds,
)
if not success:
raise HTTPException(status_code=400, detail="Failed to correct item")
return {"status": "corrected", "item_id": request.item_id}
@router.post("/skip")
async def skip_item(request: SkipRequest):
"""Skip an item (unusable image, etc.)."""
success = await skip_ocr_item(
item_id=request.item_id,
labeled_by="admin",
)
if not success:
raise HTTPException(status_code=400, detail="Failed to skip item")
return {"status": "skipped", "item_id": request.item_id}
@router.get("/stats")
async def get_stats(session_id: Optional[str] = Query(None)):
"""Get labeling statistics."""
stats = await get_ocr_labeling_stats(session_id=session_id)
if "error" in stats:
raise HTTPException(status_code=500, detail=stats["error"])
return stats
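# Illustrative HTTP usage (host/port assumed; requires the router to be mounted):
#
# import httpx
#
# r = httpx.post(
#     "http://localhost:8000/api/v1/ocr-label/sessions",
#     json={"name": "Klasse 10b Scans", "source_type": "klausur"},
# )
# session_id = r.json()["id"]
# queue = httpx.get(
#     "http://localhost:8000/api/v1/ocr-label/queue",
#     params={"session_id": session_id, "limit": 5},
# ).json()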


@@ -0,0 +1,313 @@
"""
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
Endpoints:
- POST /sessions/{id}/upload - Upload images for labeling
- POST /run-ocr/{item_id} - Run OCR on existing item
- POST /export - Export training data
- GET /training-samples - List training samples
- GET /images/{path} - Serve images from local storage
- GET /exports - List exports
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from typing import Optional, List
import uuid
import os
from metrics_db import (
get_ocr_labeling_session,
add_ocr_labeling_item,
get_ocr_labeling_item,
export_training_samples,
get_training_samples,
)
from ocr_labeling_models import (
ExportRequest,
LOCAL_STORAGE_PATH,
)
from ocr_labeling_helpers import (
compute_image_hash, run_ocr_on_image,
save_image_locally,
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
)
# Conditional imports
try:
from minio_storage import upload_ocr_image, get_ocr_image
except ImportError:
pass
try:
from training_export_service import TrainingSample, get_training_export_service
except ImportError:
pass
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
@router.post("/sessions/{session_id}/upload")
async def upload_images(
session_id: str,
files: List[UploadFile] = File(...),
run_ocr: bool = Form(True),
metadata: Optional[str] = Form(None),
):
"""
Upload images to a labeling session.
Args:
session_id: Session to add images to
files: Image files to upload (PNG, JPG, PDF)
run_ocr: Whether to run OCR immediately (default: True)
metadata: Optional JSON metadata (subject, year, etc.)
"""
import json
session = await get_ocr_labeling_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
meta_dict = None
if metadata:
try:
meta_dict = json.loads(metadata)
except json.JSONDecodeError:
meta_dict = {"raw": metadata}
results = []
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
for file in files:
content = await file.read()
image_hash = compute_image_hash(content)
item_id = str(uuid.uuid4())
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
extension = 'png'
if MINIO_AVAILABLE:
try:
image_path = upload_ocr_image(session_id, item_id, content, extension)
except Exception as e:
print(f"MinIO upload failed, using local storage: {e}")
image_path = save_image_locally(session_id, item_id, content, extension)
else:
image_path = save_image_locally(session_id, item_id, content, extension)
ocr_text = None
ocr_confidence = None
if run_ocr and extension != 'pdf':
ocr_text, ocr_confidence = await run_ocr_on_image(
content,
file.filename or f"{item_id}.{extension}",
model=ocr_model
)
success = await add_ocr_labeling_item(
item_id=item_id,
session_id=session_id,
image_path=image_path,
image_hash=image_hash,
ocr_text=ocr_text,
ocr_confidence=ocr_confidence,
ocr_model=ocr_model if ocr_text else None,
metadata=meta_dict,
)
if success:
results.append({
"id": item_id,
"filename": file.filename,
"image_path": image_path,
"image_hash": image_hash,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"status": "pending",
})
return {
"session_id": session_id,
"uploaded_count": len(results),
"items": results,
}
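# Usage sketch for the upload endpoint (illustrative only — host/port and the
# file names are assumptions): multipart upload of two pages, OCR run inline.
async def _example_upload(session_id: str):
    import httpx
    files = [
        ("files", ("page1.png", open("page1.png", "rb"), "image/png")),
        ("files", ("page2.jpg", open("page2.jpg", "rb"), "image/jpeg")),
    ]
    data = {"run_ocr": "true", "metadata": '{"subject": "english", "year": 7}'}
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(
            f"http://localhost:8000/api/v1/ocr-label/sessions/{session_id}/upload",
            files=files, data=data,
        )
        resp.raise_for_status()
        return resp.json()["uploaded_count"]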
@router.post("/export")
async def export_data(request: ExportRequest):
"""Export labeled data for training."""
db_samples = await export_training_samples(
export_format=request.export_format,
session_id=request.session_id,
batch_id=request.batch_id,
exported_by="admin",
)
if not db_samples:
return {
"export_format": request.export_format,
"batch_id": request.batch_id,
"exported_count": 0,
"samples": [],
"message": "No labeled samples found to export",
}
export_result = None
if TRAINING_EXPORT_AVAILABLE:
try:
export_service = get_training_export_service()
training_samples = []
for s in db_samples:
training_samples.append(TrainingSample(
id=s.get('id', s.get('item_id', '')),
image_path=s.get('image_path', ''),
ground_truth=s.get('ground_truth', ''),
ocr_text=s.get('ocr_text'),
ocr_confidence=s.get('ocr_confidence'),
metadata=s.get('metadata'),
))
export_result = export_service.export(
samples=training_samples,
export_format=request.export_format,
batch_id=request.batch_id,
)
except Exception as e:
print(f"Training export failed: {e}")
response = {
"export_format": request.export_format,
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
"exported_count": len(db_samples),
"samples": db_samples,
}
if export_result:
response["export_path"] = export_result.export_path
response["manifest_path"] = export_result.manifest_path
return response
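# Usage sketch for /export (illustrative only — host/port and the "jsonl"
# format value are assumptions; use whatever formats your
# training_export_service actually supports):
async def _example_export():
    import httpx
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "http://localhost:8000/api/v1/ocr-label/export",
            json={"export_format": "jsonl"},
        )
        body = resp.json()
        print(body["exported_count"], body.get("export_path"))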
@router.get("/training-samples")
async def list_training_samples(
export_format: Optional[str] = Query(None),
batch_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=1000),
):
"""Get exported training samples."""
samples = await get_training_samples(
export_format=export_format,
batch_id=batch_id,
limit=limit,
)
return {
"count": len(samples),
"samples": samples,
}
@router.get("/images/{path:path}")
async def get_image(path: str):
"""Serve an image from local storage."""
from fastapi.responses import FileResponse
base = os.path.abspath(LOCAL_STORAGE_PATH)
filepath = os.path.abspath(os.path.join(base, path))  # confine to storage root
if not filepath.startswith(base + os.sep):
    raise HTTPException(status_code=400, detail="Invalid path")  # blocks "../" traversal
if not os.path.exists(filepath):
    raise HTTPException(status_code=404, detail="Image not found")
extension = filepath.split('.')[-1].lower()
content_type = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'pdf': 'application/pdf',
}.get(extension, 'application/octet-stream')
return FileResponse(filepath, media_type=content_type)
@router.post("/run-ocr/{item_id}")
async def run_ocr_for_item(item_id: str):
"""Run OCR on an existing item."""
item = await get_ocr_labeling_item(item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
image_path = item['image_path']
if image_path.startswith(LOCAL_STORAGE_PATH):
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image file not found")
with open(image_path, 'rb') as f:
image_data = f.read()
elif MINIO_AVAILABLE:
try:
image_data = get_ocr_image(image_path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
else:
raise HTTPException(status_code=500, detail="Cannot load image")
session = await get_ocr_labeling_session(item['session_id'])
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
ocr_text, ocr_confidence = await run_ocr_on_image(
image_data,
os.path.basename(image_path),
model=ocr_model
)
if ocr_text is None:
raise HTTPException(status_code=500, detail="OCR failed")
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE ocr_labeling_items
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
WHERE id = $1
""",
item_id, ocr_text, ocr_confidence, ocr_model
)
return {
"item_id": item_id,
"ocr_text": ocr_text,
"ocr_confidence": ocr_confidence,
"ocr_model": ocr_model,
}
@router.get("/exports")
async def list_exports(export_format: Optional[str] = Query(None)):
"""List all available training data exports."""
if not TRAINING_EXPORT_AVAILABLE:
return {
"exports": [],
"message": "Training export service not available",
}
try:
export_service = get_training_export_service()
exports = export_service.list_exports(export_format=export_format)
return {
"count": len(exports),
"exports": exports,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")

View File

@@ -1,705 +1,23 @@
"""
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints.
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
Extracted from ocr_pipeline_api.py — contains:
- POST /sessions/{session_id}/reprocess (clear downstream + restart from step)
- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming)
Split into submodules:
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
from fastapi import APIRouter
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
logger = logging.getLogger(__name__)
from ocr_pipeline_reprocess import router as _reprocess_router
from ocr_pipeline_auto_steps import router as _steps_router
# Combine both sub-routers into a single router for backwards compatibility.
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
router.include_router(_reprocess_router)
router.include_router(_steps_router)
# ---------------------------------------------------------------------------
# Reprocess endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
"""Re-run pipeline from a specific step, clearing downstream data.
Body: {"from_step": 5} (1-indexed step number)
Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) →
Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10)
Clears downstream results (thresholds mirror the code below):
- from_step <= 1: orientation_result + all downstream
- from_step <= 2: deskew_result + all downstream
- from_step <= 3: dewarp_result + all downstream
- from_step <= 4: crop_result + all downstream
- from_step <= 6: column_result, row_result, word_result
- from_step <= 7: row_result, word_result
- from_step <= 8: word_result (cells, vocab_entries)
- from_step == 9: word_result.llm_review only
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
from_step = body.get("from_step", 1)
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
update_kwargs: Dict[str, Any] = {"current_step": from_step}
# Clear downstream data based on from_step
# New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) →
# Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11)
if from_step <= 8:
update_kwargs["word_result"] = None
elif from_step == 9:
# Only clear LLM review from word_result
word_result = session.get("word_result")
if word_result:
word_result.pop("llm_review", None)
word_result.pop("llm_corrections", None)
update_kwargs["word_result"] = word_result
if from_step <= 7:
update_kwargs["row_result"] = None
if from_step <= 6:
update_kwargs["column_result"] = None
if from_step <= 4:
update_kwargs["crop_result"] = None
if from_step <= 3:
update_kwargs["dewarp_result"] = None
if from_step <= 2:
update_kwargs["deskew_result"] = None
if from_step <= 1:
update_kwargs["orientation_result"] = None
await update_session_db(session_id, **update_kwargs)
# Also clear cache
if session_id in _cache:
for key in list(update_kwargs.keys()):
if key != "current_step":
_cache[session_id][key] = update_kwargs[key]
_cache[session_id]["current_step"] = from_step
logger.info(f"Session {session_id} reprocessing from step {from_step}")
return {
"session_id": session_id,
"from_step": from_step,
"cleared": [k for k in update_kwargs if k != "current_step"],
}
# ---------------------------------------------------------------------------
# VLM shear detection helper (used by dewarp step in auto-mode)
# ---------------------------------------------------------------------------
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
import re
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
import json
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
# ---------------------------------------------------------------------------
# Auto-mode orchestrator
# ---------------------------------------------------------------------------
class RunAutoRequest(BaseModel):
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
pronunciation: str = "british"
skip_llm_review: bool = False
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
"""Format a single SSE event line."""
import json as _json
payload = {"step": step, "status": status, **data}
return f"data: {_json.dumps(payload)}\n\n"
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
Steps:
1. Deskew — straighten the scan
2. Dewarp — correct vertical shear (ensemble CV or VLM)
3. Columns — detect column layout
4. Rows — detect row layout
5. Words — OCR each cell
6. LLM review — correct OCR errors (optional)
Steps before `from_step` are skipped; every step from `from_step` onward is re-run, overwriting any earlier results.
Yields SSE events of the form:
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
Final event:
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
"""
if req.from_step < 1 or req.from_step > 6:
raise HTTPException(status_code=400, detail="from_step must be 1-6")
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
if session_id not in _cache:
await _load_session_to_cache(session_id)
async def _generate():
steps_run: List[str] = []
steps_skipped: List[str] = []
error_step: Optional[str] = None
session = await get_session_db(session_id)
if not session:
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
return
cached = _get_cached(session_id)
# -----------------------------------------------------------------
# Step 1: Deskew
# -----------------------------------------------------------------
if req.from_step <= 1:
yield await _auto_sse_event("deskew", "start", {})
try:
t0 = time.time()
orig_bgr = cached.get("original_bgr")
if orig_bgr is None:
raise ValueError("Original image not loaded")
# Method 1: Hough lines
try:
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
except Exception:
deskewed_hough, angle_hough = orig_bgr, 0.0
# Method 2: Word alignment
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
# Pick best method
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
deskewed_bgr = deskewed_hough
method_used = "hough"
angle_applied = angle_hough
else:
method_used = "hough"
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
success, png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = png_buf.tobytes() if success else b""
deskew_result = {
"method_used": method_used,
"rotation_degrees": round(float(angle_applied), 3),
"duration_seconds": round(time.time() - t0, 2),
}
cached["deskewed_bgr"] = deskewed_bgr
cached["deskew_result"] = deskew_result
await update_session_db(
session_id,
deskewed_png=deskewed_png,
deskew_result=deskew_result,
auto_rotation_degrees=float(angle_applied),
current_step=3,
)
session = await get_session_db(session_id)
steps_run.append("deskew")
yield await _auto_sse_event("deskew", "done", deskew_result)
except Exception as e:
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
error_step = "deskew"
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("deskew")
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
# -----------------------------------------------------------------
# Step 2: Dewarp
# -----------------------------------------------------------------
if req.from_step <= 2:
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
try:
t0 = time.time()
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise ValueError("Deskewed image not available")
if req.dewarp_method == "vlm":
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success_enc else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success_enc else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(time.time() - t0, 2),
"detections": dewarp_info.get("detections", []),
}
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
session = await get_session_db(session_id)
steps_run.append("dewarp")
yield await _auto_sse_event("dewarp", "done", dewarp_result)
except Exception as e:
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
error_step = "dewarp"
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("dewarp")
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
# -----------------------------------------------------------------
# Step 3: Columns
# -----------------------------------------------------------------
if req.from_step <= 3:
yield await _auto_sse_event("columns", "start", {})
try:
t0 = time.time()
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if col_img is None:
raise ValueError("Cropped/dewarped image not available")
ocr_img = create_ocr_image(col_img)
h, w = ocr_img.shape[:2]
geo_result = detect_column_geometry(ocr_img, col_img)
if geo_result is None:
layout_img = create_layout_image(col_img)
regions = analyze_layout(layout_img, ocr_img)
cached["_word_dicts"] = None
cached["_inv"] = None
cached["_content_bounds"] = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
content_w = right_x - left_x
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
"duration_seconds": round(time.time() - t0, 2),
}
cached["column_result"] = column_result
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None, current_step=6)
session = await get_session_db(session_id)
steps_run.append("columns")
yield await _auto_sse_event("columns", "done", {
"column_count": len(columns),
"duration_seconds": column_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
error_step = "columns"
yield await _auto_sse_event("columns", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("columns")
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
# -----------------------------------------------------------------
# Step 4: Rows
# -----------------------------------------------------------------
if req.from_step <= 4:
yield await _auto_sse_event("rows", "start", {})
try:
t0 = time.time()
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
if not column_result or not column_result.get("columns"):
raise ValueError("Column detection must complete first")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
ocr_img_tmp = create_ocr_image(row_img)
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
if geo_result is None:
raise ValueError("Column geometry detection failed — cannot detect rows")
_g, lx, rx, ty, by, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (lx, rx, ty, by)
content_bounds = (lx, rx, ty, by)
left_x, right_x, top_y, bottom_y = content_bounds
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
row_list = [
{
"index": r.index, "x": r.x, "y": r.y,
"width": r.width, "height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
}
for r in row_geoms
]
row_result = {
"rows": row_list,
"row_count": len(row_list),
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
"duration_seconds": round(time.time() - t0, 2),
}
cached["row_result"] = row_result
await update_session_db(session_id, row_result=row_result, current_step=7)
session = await get_session_db(session_id)
steps_run.append("rows")
yield await _auto_sse_event("rows", "done", {
"row_count": len(row_list),
"content_rows": row_result["content_rows"],
"duration_seconds": row_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
error_step = "rows"
yield await _auto_sse_event("rows", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("rows")
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
# -----------------------------------------------------------------
# Step 5: Words (OCR)
# -----------------------------------------------------------------
if req.from_step <= 5:
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
try:
t0 = time.time()
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
row_result = session.get("row_result") or cached.get("row_result")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
row_geoms = [
RowGeometry(
index=r["index"], x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0), words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
word_dicts = cached.get("_word_dicts")
if word_dicts is not None:
content_bounds = cached.get("_content_bounds")
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
ocr_img = create_ocr_image(word_img)
img_h, img_w = word_img.shape[:2]
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=req.ocr_engine, img_bgr=word_img,
)
duration = time.time() - t0
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
# Apply IPA phonetic fixes directly to cell texts
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
word_result_data = {
"cells": cells,
"grid_shape": {
"rows": n_content_rows,
"cols": len(columns_meta),
"total_cells": len(cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
word_result_data["vocab_entries"] = entries
word_result_data["entries"] = entries
word_result_data["entry_count"] = len(entries)
word_result_data["summary"]["total_entries"] = len(entries)
await update_session_db(session_id, word_result=word_result_data, current_step=8)
cached["word_result"] = word_result_data
session = await get_session_db(session_id)
steps_run.append("words")
yield await _auto_sse_event("words", "done", {
"total_cells": len(cells),
"layout": word_result_data["layout"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": word_result_data["summary"],
})
except Exception as e:
logger.error(f"Auto-mode words failed for {session_id}: {e}")
error_step = "words"
yield await _auto_sse_event("words", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("words")
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
# -----------------------------------------------------------------
# Step 6: LLM Review (optional)
# -----------------------------------------------------------------
if req.from_step <= 6 and not req.skip_llm_review:
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
try:
session = await get_session_db(session_id)
word_result = session.get("word_result") or cached.get("word_result")
entries = (word_result or {}).get("entries") or (word_result or {}).get("vocab_entries") or []  # tolerate a missing words step
if not entries:
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
steps_skipped.append("llm_review")
else:
reviewed = await llm_review_entries(entries)
session = await get_session_db(session_id)
word_result_updated = dict(session.get("word_result") or {})
word_result_updated["entries"] = reviewed
word_result_updated["vocab_entries"] = reviewed
word_result_updated["llm_reviewed"] = True
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
cached["word_result"] = word_result_updated
steps_run.append("llm_review")
yield await _auto_sse_event("llm_review", "done", {
"entries_reviewed": len(reviewed),
"model": OLLAMA_REVIEW_MODEL,
})
except Exception as e:
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
steps_skipped.append("llm_review")
else:
steps_skipped.append("llm_review")
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
# -----------------------------------------------------------------
# Final event
# -----------------------------------------------------------------
yield await _auto_sse_event("complete", "done", {
"steps_run": steps_run,
"steps_skipped": steps_skipped,
})
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
__all__ = ["router"]

View File

@@ -0,0 +1,84 @@
"""
OCR Pipeline Auto-Mode Helpers.
VLM shear detection, SSE event formatting, and request models.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from typing import Any, Dict
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class RunAutoRequest(BaseModel):
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
pronunciation: str = "british"
skip_llm_review: bool = False
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
"""Format a single SSE event line."""
payload = {"step": step, "status": status, **data}
return f"data: {json.dumps(payload)}\n\n"
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
The VLM is shown the image and asked: are the column/table borders tilted?
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
Confidence is 0.0 if Ollama is unavailable or parsing fails.
"""
import httpx
import base64
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
"Are they perfectly vertical, or do they tilt slightly? "
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
)
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON from response (may have surrounding text)
match = re.search(r'\{[^}]+\}', text)
if match:
data = json.loads(match.group(0))
shear = float(data.get("shear_degrees", 0.0))
conf = float(data.get("confidence", 0.0))
# Clamp to reasonable range
shear = max(-3.0, min(3.0, shear))
conf = max(0.0, min(1.0, conf))
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
except Exception as e:
logger.warning(f"VLM dewarp failed: {e}")
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}

View File

@@ -0,0 +1,528 @@
"""
OCR Pipeline Auto-Mode Orchestrator.
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_detect_header_footer_gaps,
_detect_sub_columns,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
analyze_layout,
build_cell_grid,
classify_column_types,
create_layout_image,
create_ocr_image,
deskew_image,
deskew_image_by_word_alignment,
detect_column_geometry,
detect_row_geometry,
_apply_shear,
dewarp_image,
llm_review_entries,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_auto_helpers import (
RunAutoRequest,
auto_sse_event as _auto_sse_event,
detect_shear_with_vlm as _detect_shear_with_vlm,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/run-auto")
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
Steps:
1. Deskew -- straighten the scan
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
3. Columns -- detect column layout
4. Rows -- detect row layout
5. Words -- OCR each cell
6. LLM review -- correct OCR errors (optional)
Steps before `from_step` are skipped; every step from `from_step` onward is re-run, overwriting any earlier results.
Yields SSE events of the form:
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
Final event:
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
"""
if req.from_step < 1 or req.from_step > 6:
raise HTTPException(status_code=400, detail="from_step must be 1-6")
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
if session_id not in _cache:
await _load_session_to_cache(session_id)
async def _generate():
steps_run: List[str] = []
steps_skipped: List[str] = []
error_step: Optional[str] = None
session = await get_session_db(session_id)
if not session:
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
return
cached = _get_cached(session_id)
# Step 1: Deskew
if req.from_step <= 1:
yield await _auto_sse_event("deskew", "start", {})
try:
t0 = time.time()
orig_bgr = cached.get("original_bgr")
if orig_bgr is None:
raise ValueError("Original image not loaded")
try:
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
except Exception:
deskewed_hough, angle_hough = orig_bgr, 0.0
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
orig_bytes = png_orig.tobytes() if success_enc else b""
try:
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
except Exception:
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
method_used = "word_alignment"
angle_applied = angle_wa
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
if deskewed_bgr is None:
deskewed_bgr = deskewed_hough
method_used = "hough"
angle_applied = angle_hough
else:
method_used = "hough"
angle_applied = angle_hough
deskewed_bgr = deskewed_hough
success, png_buf = cv2.imencode(".png", deskewed_bgr)
deskewed_png = png_buf.tobytes() if success else b""
deskew_result = {
"method_used": method_used,
"rotation_degrees": round(float(angle_applied), 3),
"duration_seconds": round(time.time() - t0, 2),
}
cached["deskewed_bgr"] = deskewed_bgr
cached["deskew_result"] = deskew_result
await update_session_db(
session_id,
deskewed_png=deskewed_png,
deskew_result=deskew_result,
auto_rotation_degrees=float(angle_applied),
current_step=3,
)
session = await get_session_db(session_id)
steps_run.append("deskew")
yield await _auto_sse_event("deskew", "done", deskew_result)
except Exception as e:
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
error_step = "deskew"
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("deskew")
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
# Step 2: Dewarp
if req.from_step <= 2:
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
try:
t0 = time.time()
deskewed_bgr = cached.get("deskewed_bgr")
if deskewed_bgr is None:
raise ValueError("Deskewed image not available")
if req.dewarp_method == "vlm":
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
img_bytes = png_buf.tobytes() if success_enc else b""
vlm_det = await _detect_shear_with_vlm(img_bytes)
shear_deg = vlm_det["shear_degrees"]
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
else:
dewarped_bgr = deskewed_bgr
dewarp_info = {
"method": vlm_det["method"],
"shear_degrees": shear_deg,
"confidence": vlm_det["confidence"],
"detections": [vlm_det],
}
else:
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
dewarped_png = png_buf.tobytes() if success_enc else b""
dewarp_result = {
"method_used": dewarp_info["method"],
"shear_degrees": dewarp_info["shear_degrees"],
"confidence": dewarp_info["confidence"],
"duration_seconds": round(time.time() - t0, 2),
"detections": dewarp_info.get("detections", []),
}
cached["dewarped_bgr"] = dewarped_bgr
cached["dewarp_result"] = dewarp_result
await update_session_db(
session_id,
dewarped_png=dewarped_png,
dewarp_result=dewarp_result,
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
current_step=4,
)
session = await get_session_db(session_id)
steps_run.append("dewarp")
yield await _auto_sse_event("dewarp", "done", dewarp_result)
except Exception as e:
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
error_step = "dewarp"
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("dewarp")
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
# Step 3: Columns
if req.from_step <= 3:
yield await _auto_sse_event("columns", "start", {})
try:
t0 = time.time()
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
if col_img is None:
raise ValueError("Cropped/dewarped image not available")
ocr_img = create_ocr_image(col_img)
h, w = ocr_img.shape[:2]
geo_result = detect_column_geometry(ocr_img, col_img)
if geo_result is None:
layout_img = create_layout_image(col_img)
regions = analyze_layout(layout_img, ocr_img)
cached["_word_dicts"] = None
cached["_inv"] = None
cached["_content_bounds"] = None
else:
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
content_w = right_x - left_x
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
top_y=top_y, header_y=header_y, footer_y=footer_y)
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
left_x=left_x, right_x=right_x, inv=inv)
columns = [asdict(r) for r in regions]
column_result = {
"columns": columns,
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
"duration_seconds": round(time.time() - t0, 2),
}
cached["column_result"] = column_result
await update_session_db(session_id, column_result=column_result,
row_result=None, word_result=None, current_step=6)
session = await get_session_db(session_id)
steps_run.append("columns")
yield await _auto_sse_event("columns", "done", {
"column_count": len(columns),
"duration_seconds": column_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
error_step = "columns"
yield await _auto_sse_event("columns", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("columns")
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
# Step 4: Rows
if req.from_step <= 4:
yield await _auto_sse_event("rows", "start", {})
try:
t0 = time.time()
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
if not column_result or not column_result.get("columns"):
raise ValueError("Column detection must complete first")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
word_dicts = cached.get("_word_dicts")
inv = cached.get("_inv")
content_bounds = cached.get("_content_bounds")
if word_dicts is None or inv is None or content_bounds is None:
ocr_img_tmp = create_ocr_image(row_img)
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
if geo_result is None:
raise ValueError("Column geometry detection failed -- cannot detect rows")
_g, lx, rx, ty, by, word_dicts, inv = geo_result
cached["_word_dicts"] = word_dicts
cached["_inv"] = inv
cached["_content_bounds"] = (lx, rx, ty, by)
content_bounds = (lx, rx, ty, by)
left_x, right_x, top_y, bottom_y = content_bounds
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
row_list = [
{
"index": r.index, "x": r.x, "y": r.y,
"width": r.width, "height": r.height,
"word_count": r.word_count,
"row_type": r.row_type,
"gap_before": r.gap_before,
}
for r in row_geoms
]
row_result = {
"rows": row_list,
"row_count": len(row_list),
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
"duration_seconds": round(time.time() - t0, 2),
}
cached["row_result"] = row_result
await update_session_db(session_id, row_result=row_result, current_step=7)
session = await get_session_db(session_id)
steps_run.append("rows")
yield await _auto_sse_event("rows", "done", {
"row_count": len(row_list),
"content_rows": row_result["content_rows"],
"duration_seconds": row_result["duration_seconds"],
})
except Exception as e:
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
error_step = "rows"
yield await _auto_sse_event("rows", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("rows")
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
# Step 5: Words (OCR)
if req.from_step <= 5:
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
try:
t0 = time.time()
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
session = await get_session_db(session_id)
column_result = session.get("column_result") or cached.get("column_result")
row_result = session.get("row_result") or cached.get("row_result")
col_regions = [
PageRegion(
type=c["type"], x=c["x"], y=c["y"],
width=c["width"], height=c["height"],
classification_confidence=c.get("classification_confidence", 1.0),
classification_method=c.get("classification_method", ""),
)
for c in column_result["columns"]
]
row_geoms = [
RowGeometry(
index=r["index"], x=r["x"], y=r["y"],
width=r["width"], height=r["height"],
word_count=r.get("word_count", 0), words=[],
row_type=r.get("row_type", "content"),
gap_before=r.get("gap_before", 0),
)
for r in row_result["rows"]
]
word_dicts = cached.get("_word_dicts")
if word_dicts is not None:
content_bounds = cached.get("_content_bounds")
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
for row in row_geoms:
row_y_rel = row.y - top_y
row_bottom_rel = row_y_rel + row.height
row.words = [
w for w in word_dicts
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
]
row.word_count = len(row.words)
ocr_img = create_ocr_image(word_img)
img_h, img_w = word_img.shape[:2]
cells, columns_meta = build_cell_grid(
ocr_img, col_regions, row_geoms, img_w, img_h,
ocr_engine=req.ocr_engine, img_bgr=word_img,
)
duration = time.time() - t0
col_types = {c['type'] for c in columns_meta}
is_vocab = bool(col_types & {'column_en', 'column_de'})
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
word_result_data = {
"cells": cells,
"grid_shape": {
"rows": n_content_rows,
"cols": len(columns_meta),
"total_cells": len(cells),
},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
},
}
has_text_col = 'column_text' in col_types
if is_vocab or has_text_col:
entries = _cells_to_vocab_entries(cells, columns_meta)
entries = _fix_character_confusion(entries)
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
word_result_data["vocab_entries"] = entries
word_result_data["entries"] = entries
word_result_data["entry_count"] = len(entries)
word_result_data["summary"]["total_entries"] = len(entries)
await update_session_db(session_id, word_result=word_result_data, current_step=8)
cached["word_result"] = word_result_data
session = await get_session_db(session_id)
steps_run.append("words")
yield await _auto_sse_event("words", "done", {
"total_cells": len(cells),
"layout": word_result_data["layout"],
"duration_seconds": round(duration, 2),
"ocr_engine": used_engine,
"summary": word_result_data["summary"],
})
except Exception as e:
logger.error(f"Auto-mode words failed for {session_id}: {e}")
error_step = "words"
yield await _auto_sse_event("words", "error", {"message": str(e)})
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
return
else:
steps_skipped.append("words")
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
# Step 6: LLM Review (optional)
if req.from_step <= 6 and not req.skip_llm_review:
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
try:
session = await get_session_db(session_id)
word_result = session.get("word_result") or cached.get("word_result")
entries = (word_result or {}).get("entries") or (word_result or {}).get("vocab_entries") or []  # tolerate a missing words step
if not entries:
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
steps_skipped.append("llm_review")
else:
reviewed = await llm_review_entries(entries)
session = await get_session_db(session_id)
word_result_updated = dict(session.get("word_result") or {})
word_result_updated["entries"] = reviewed
word_result_updated["vocab_entries"] = reviewed
word_result_updated["llm_reviewed"] = True
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
cached["word_result"] = word_result_updated
steps_run.append("llm_review")
yield await _auto_sse_event("llm_review", "done", {
"entries_reviewed": len(reviewed),
"model": OLLAMA_REVIEW_MODEL,
})
except Exception as e:
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
steps_skipped.append("llm_review")
else:
steps_skipped.append("llm_review")
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
# Final event
yield await _auto_sse_event("complete", "done", {
"steps_run": steps_run,
"steps_skipped": steps_skipped,
})
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
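# Client-side sketch for consuming the SSE stream above (illustrative only —
# host/port are assumptions). Each event arrives as one "data: <json>" line;
# the stream ends after the {"step": "complete", ...} event.
async def _example_consume_run_auto(session_id: str):
    import json
    import httpx
    url = f"http://localhost:8000/api/v1/ocr-pipeline/sessions/{session_id}/run-auto"
    body = {"from_step": 1, "ocr_engine": "auto", "skip_llm_review": False}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=body) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                event = json.loads(line[len("data: "):])
                print(event["step"], event["status"])
                if event["step"] == "complete":
                    break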

View File

@@ -0,0 +1,94 @@
"""
OCR Pipeline Reprocess Endpoint.
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Request
from ocr_pipeline_common import _cache
from ocr_pipeline_session_store import get_session_db, update_session_db
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ocr-pipeline"])
@router.post("/sessions/{session_id}/reprocess")
async def reprocess_session(session_id: str, request: Request):
"""Re-run pipeline from a specific step, clearing downstream data.
Body: {"from_step": 5} (1-indexed step number)
Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10)
Clears downstream results (thresholds mirror the code below):
- from_step <= 1: orientation_result + all downstream
- from_step <= 2: deskew_result + all downstream
- from_step <= 3: dewarp_result + all downstream
- from_step <= 4: crop_result + all downstream
- from_step <= 6: column_result, row_result, word_result
- from_step <= 7: row_result, word_result
- from_step <= 8: word_result (cells, vocab_entries)
- from_step == 9: word_result.llm_review only
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
body = await request.json()
from_step = body.get("from_step", 1)
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
update_kwargs: Dict[str, Any] = {"current_step": from_step}
# Clear downstream data based on from_step
# New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
# Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
if from_step <= 8:
update_kwargs["word_result"] = None
elif from_step == 9:
# Only clear LLM review from word_result
word_result = session.get("word_result")
if word_result:
word_result.pop("llm_review", None)
word_result.pop("llm_corrections", None)
update_kwargs["word_result"] = word_result
if from_step <= 7:
update_kwargs["row_result"] = None
if from_step <= 6:
update_kwargs["column_result"] = None
if from_step <= 4:
update_kwargs["crop_result"] = None
if from_step <= 3:
update_kwargs["dewarp_result"] = None
if from_step <= 2:
update_kwargs["deskew_result"] = None
if from_step <= 1:
update_kwargs["orientation_result"] = None
await update_session_db(session_id, **update_kwargs)
# Also clear cache
if session_id in _cache:
for key in list(update_kwargs.keys()):
if key != "current_step":
_cache[session_id][key] = update_kwargs[key]
_cache[session_id]["current_step"] = from_step
logger.info(f"Session {session_id} reprocessing from step {from_step}")
return {
"session_id": session_id,
"from_step": from_step,
"cleared": [k for k in update_kwargs if k != "current_step"],
}
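# Usage sketch (illustrative only — host/port are assumptions): restart the
# pipeline from the columns step; per the thresholds above this clears the
# column, row, and word results.
async def _example_reprocess(session_id: str):
    import httpx
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"http://localhost:8000/api/v1/ocr-pipeline/sessions/{session_id}/reprocess",
            json={"from_step": 6},
        )
        print(resp.json()["cleared"])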

View File

@@ -1,758 +1,33 @@
"""
Page Crop - Content-based crop for scanned pages and book scans.
Page Crop — Barrel Re-export
Detects the content boundary by analysing ink density projections and
(for book scans) the spine shadow gradient. Works with both loose A4
sheets on dark scanners AND book scans with white backgrounds.
Content-based crop for scanned pages and book scans.
Split into:
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
- page_crop_core.py — Main crop algorithm and format detection
All public names are re-exported here for backward compatibility.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple, Optional
# Core: main crop functions and format detection
from page_crop_core import ( # noqa: F401
PAPER_FORMATS,
detect_page_splits,
detect_and_crop_page,
_detect_format,
)
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation)
PAPER_FORMATS = {
"A4": 297.0 / 210.0, # 1.4143
"A5": 210.0 / 148.0, # 1.4189
"Letter": 11.0 / 8.5, # 1.2941
"Legal": 14.0 / 8.5, # 1.6471
"A3": 420.0 / 297.0, # 1.4141
}
# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003 # 0.3%
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005 # 0.5%
def detect_page_splits(
img_bgr: np.ndarray,
) -> list:
"""Detect if the image is a multi-page spread and return split rectangles.
Uses **brightness** (not ink density) to find the spine area:
the scanner bed produces a characteristic gray strip where pages meet,
which is darker than the white paper on either side.
Returns a list of page dicts ``{x, y, width, height, page_index}``
or an empty list if only one page is detected.
"""
h, w = img_bgr.shape[:2]
# Only check landscape-ish images (width > height * 1.15)
if w < h * 1.15:
return []
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
col_brightness = np.mean(gray, axis=0).astype(np.float64)
# Heavy smoothing to ignore individual text lines
kern = max(11, w // 50)
if kern % 2 == 0:
kern += 1
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
# Page paper is bright (typically > 200), spine/scanner bed is darker
page_brightness = float(np.max(brightness_smooth))
if page_brightness < 100:
return [] # Very dark image, skip
# Spine threshold: significantly darker than the page
# Spine is typically 60-80% of paper brightness
spine_thresh = page_brightness * 0.88
# Search in center region (30-70% of width)
center_lo = int(w * 0.30)
center_hi = int(w * 0.70)
# Find the darkest valley in the center region
center_brightness = brightness_smooth[center_lo:center_hi]
darkest_val = float(np.min(center_brightness))
if darkest_val >= spine_thresh:
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
darkest_val, spine_thresh)
return []
# Find ALL contiguous dark runs in the center region
is_dark = center_brightness < spine_thresh
dark_runs: list = [] # list of (start, end) pairs
run_start = -1
for i in range(len(is_dark)):
if is_dark[i]:
if run_start < 0:
run_start = i
else:
if run_start >= 0:
dark_runs.append((run_start, i))
run_start = -1
if run_start >= 0:
dark_runs.append((run_start, len(is_dark)))
# Filter out runs that are too narrow (< 1% of image width)
min_spine_px = int(w * 0.01)
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
if not dark_runs:
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
return []
# Score each dark run: prefer centered, dark, narrow valleys
center_region_len = center_hi - center_lo
image_center_in_region = (w * 0.5 - center_lo) # x=50% mapped into region coords
best_score = -1.0
best_start, best_end = dark_runs[0]
for rs, re in dark_runs:
run_width = re - rs
run_center = (rs + re) / 2.0
# --- Factor 1: Proximity to image center (gaussian, sigma = 15% of region) ---
sigma = center_region_len * 0.15
dist = abs(run_center - image_center_in_region)
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
# --- Factor 2: Darkness (how dark is the valley relative to threshold) ---
run_brightness = float(np.mean(center_brightness[rs:re]))
# Normalize: 1.0 when run_brightness == 0, 0.0 when run_brightness == spine_thresh
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
# --- Factor 3: Narrowness bonus (spine shadows are narrow, not wide plateaus) ---
# Typical spine: 1-5% of image width. Penalise runs wider than ~8%.
width_frac = run_width / w
if width_frac <= 0.05:
narrowness_bonus = 1.0
elif width_frac <= 0.15:
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 # linear decay 1.0 → 0.0
else:
narrowness_bonus = 0.0
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
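# Worked example with hypothetical numbers: for w=2000 the centre region is
# x=600..1400 (length 800, sigma = 0.15*800 = 120). A 40px-wide run centred
# 30px off image centre gives center_factor = exp(-0.5*(30/120)^2) ≈ 0.97.
# With paper at 220 the threshold is 220*0.88 ≈ 194; a run brightness of 140
# yields darkness_factor ≈ (194-140)/194 ≈ 0.28. width_frac = 40/2000 = 0.02
# <= 0.05, so narrowness_bonus = 1.0 and score ≈ 0.97 * 0.28 * 1.0 ≈ 0.27.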
logger.debug(
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f → score=%.4f",
center_lo + rs, center_lo + re, run_width,
center_factor, darkness_factor, narrowness_bonus, score,
)
if score > best_score:
best_score = score
best_start, best_end = rs, re
spine_w = best_end - best_start
spine_x = center_lo + best_start
spine_center = spine_x + spine_w // 2
logger.debug(
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
spine_x, spine_x + spine_w, spine_w, best_score,
)
# Verify: must have bright (paper) content on BOTH sides
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
right_end = center_lo + best_end
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
if left_brightness < spine_thresh or right_brightness < spine_thresh:
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
left_brightness, right_brightness, spine_thresh)
return []
logger.info(
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
"left_paper=%.0f, right_paper=%.0f",
spine_x, right_end, spine_w, darkest_val, page_brightness,
left_brightness, right_brightness,
)
# Split at the spine center
split_points = [spine_center]
# Build page rectangles
pages: list = []
prev_x = 0
for i, sx in enumerate(split_points):
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
"height": h, "page_index": i})
prev_x = sx
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
"height": h, "page_index": len(split_points)})
# Filter out tiny pages (< 15% of total width)
pages = [p for p in pages if p["width"] >= w * 0.15]
if len(pages) < 2:
return []
# Re-index
for i, p in enumerate(pages):
p["page_index"] = i
logger.info(
"Page split detected: %d pages, spine_w=%d, split_points=%s",
len(pages), spine_w, split_points,
)
return pages
def detect_and_crop_page(
img_bgr: np.ndarray,
margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Detect content boundary and crop scanner/book borders.
Algorithm (4-edge detection):
1. Adaptive threshold → binary (text=255, bg=0)
2. Left edge: spine-shadow / gutter detection via grayscale column means,
fallback to binary vertical projection
3. Right edge: same spine-shadow / gutter detection, with binary
vertical projection (last ink column) as fallback
4. Top/bottom edges: binary horizontal projection
5. Sanity checks, then crop with configurable margin
Args:
img_bgr: Input BGR image (should already be deskewed/dewarped)
margin_frac: Extra margin around content (fraction of dimension, default 1%)
Returns:
Tuple of (cropped_image, result_dict)
"""
h, w = img_bgr.shape[:2]
total_area = h * w
result: Dict[str, Any] = {
"crop_applied": False,
"crop_rect": None,
"crop_rect_pct": None,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
"detected_format": None,
"format_confidence": 0.0,
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
}
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# --- Binarise with adaptive threshold (works for white-on-white) ---
binary = cv2.adaptiveThreshold(
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
)
# --- Left edge: spine-shadow detection ---
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
# --- Right edge: spine-shadow detection ---
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
# --- Top / bottom edges: binary horizontal projection ---
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
# Compute border fractions
border_top = top_edge / h
border_bottom = (h - bottom_edge) / h
border_left = left_edge / w
border_right = (w - right_edge) / w
result["border_fractions"] = {
"top": round(border_top, 4),
"bottom": round(border_bottom, 4),
"left": round(border_left, 4),
"right": round(border_right, 4),
}
# Sanity: only crop if at least one edge has > 2% border
min_border = 0.02
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
# Add margin
margin_x = int(w * margin_frac)
margin_y = int(h * margin_frac)
crop_x = max(0, left_edge - margin_x)
crop_y = max(0, top_edge - margin_y)
crop_x2 = min(w, right_edge + margin_x)
crop_y2 = min(h, bottom_edge + margin_y)
crop_w = crop_x2 - crop_x
crop_h = crop_y2 - crop_y
# Sanity: cropped area must be >= 40% of original
if crop_w * crop_h < 0.40 * total_area:
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
100.0 * crop_w * crop_h / total_area)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
detected_format, format_confidence = _detect_format(crop_w, crop_h)
result["crop_applied"] = True
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
result["crop_rect_pct"] = {
"x": round(100.0 * crop_x / w, 2),
"y": round(100.0 * crop_y / h, 2),
"width": round(100.0 * crop_w / w, 2),
"height": round(100.0 * crop_h / h, 2),
}
result["cropped_size"] = {"width": crop_w, "height": crop_h}
result["detected_format"] = detected_format
result["format_confidence"] = format_confidence
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
logger.info(
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
border_top * 100, border_bottom * 100,
border_left * 100, border_right * 100,
)
return cropped, result
# ---------------------------------------------------------------------------
# Edge detection helpers
# ---------------------------------------------------------------------------
def _detect_spine_shadow(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Find the book spine center (darkest point) in a scanner shadow.
The scanner produces a gray strip where the book spine presses against
the glass. The darkest column in that strip is the spine center —
that's where we crop.
Distinguishes real spine shadows from text content by checking:
1. Strong brightness range (> 40 levels)
2. Darkest point is genuinely dark (< 180 mean brightness)
3. The dark area is a NARROW valley, not a text-content plateau
4. Brightness rises significantly toward the page content side
Args:
gray: Full grayscale image (for context).
search_region: Column slice of the grayscale image to search in.
offset_x: X offset of search_region relative to full image.
w: Full image width.
side: 'left' or 'right' (for logging).
Returns:
X coordinate (in full image) of the spine center, or None.
"""
region_w = search_region.shape[1]
if region_w < 10:
return None
# Column-mean brightness in the search region
col_means = np.mean(search_region, axis=0).astype(np.float64)
# Smooth with boxcar kernel (width = 1% of image width, min 5)
kernel_size = max(5, w // 100)
if kernel_size % 2 == 0:
kernel_size += 1
kernel = np.ones(kernel_size) / kernel_size
smoothed_raw = np.convolve(col_means, kernel, mode="same")
# Trim convolution edge artifacts (edges are zero-padded → artificially low)
margin = kernel_size // 2
if region_w <= 2 * margin + 10:
return None
smoothed = smoothed_raw[margin:region_w - margin]
trim_offset = margin # offset of smoothed[0] relative to search_region
val_min = float(np.min(smoothed))
val_max = float(np.max(smoothed))
shadow_range = val_max - val_min
# --- Check 1: Strong brightness gradient ---
if shadow_range <= 40:
logger.debug(
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
)
return None
# --- Check 2: Darkest point must be genuinely dark ---
# Spine shadows have mean column brightness 60-160.
# Text on white paper stays above 180.
if val_min > 180:
logger.debug(
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
)
return None
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
spine_local = spine_idx + trim_offset # index in search_region
trimmed_len = len(smoothed)
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
# Count how many columns are within 20% of the shadow range above the min.
valley_thresh = val_min + shadow_range * 0.20
valley_mask = smoothed < valley_thresh
valley_width = int(np.sum(valley_mask))
# Spine valleys are typically 3-15% of image width (24-120px on an 800px image).
# Text content plateaus span 20%+ of the search region.
max_valley_frac = 0.50 # valley must not cover more than half the trimmed region
if valley_width > trimmed_len * max_valley_frac:
logger.debug(
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
side.capitalize(), valley_width, trimmed_len,
100.0 * valley_width / trimmed_len,
)
return None
# --- Check 4: Brightness must rise toward page content ---
# For left edge: after spine, brightness should rise (= page paper)
# For right edge: before spine, brightness should rise
rise_check_w = max(5, trimmed_len // 5) # check 20% of trimmed region
if side == "left":
# Check columns to the right of the spine (in trimmed array)
right_start = min(spine_idx + 5, trimmed_len - 1)
right_end = min(right_start + rise_check_w, trimmed_len)
if right_end > right_start:
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
else: # right
# Check columns to the left of the spine (in trimmed array)
left_end = max(spine_idx - 5, 0)
left_start = max(left_end - rise_check_w, 0)
if left_end > left_start:
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
spine_x = offset_x + spine_local
logger.info(
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
)
return spine_x
def _detect_gutter_continuity(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Detect gutter shadow via vertical continuity analysis.
Camera book scans produce a subtle brightness gradient at the gutter
that is too faint for scanner-shadow detection (range < 40). However,
the gutter shadow has a unique property: it runs **continuously from
top to bottom** without interruption. Text and images always have
vertical gaps between lines, paragraphs, or sections.
Algorithm:
1. Divide image into N horizontal strips (~60px each)
2. For each column, compute what fraction of strips are darker than
the page median (from the center 50% of the full image)
3. A "gutter column" has ≥ 75% of strips darker than page_median δ
4. Smooth the dark-fraction profile and find the transition point
from the edge inward where the fraction drops below 0.50
5. Validate: gutter band must be 0.5%-10% of image width
Args:
gray: Full grayscale image.
search_region: Edge slice of the grayscale image.
offset_x: X offset of search_region relative to full image.
w: Full image width.
side: 'left' or 'right'.
Returns:
X coordinate (in full image) of the gutter inner edge, or None.
"""
region_h, region_w = search_region.shape[:2]
if region_w < 20 or region_h < 100:
return None
# --- 1. Divide into horizontal strips ---
strip_target_h = 60 # ~60px per strip
n_strips = max(10, region_h // strip_target_h)
strip_h = region_h // n_strips
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
for s in range(n_strips):
y0 = s * strip_h
y1 = min((s + 1) * strip_h, region_h)
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
# --- 2. Page median from center 50% of full image ---
center_lo = w // 4
center_hi = 3 * w // 4
page_median = float(np.median(gray[:, center_lo:center_hi]))
# Camera shadows are subtle — threshold just 5 levels below page median
dark_thresh = page_median - 5.0
# If page is very dark overall (e.g. photo, not a book page), bail out
if page_median < 180:
return None
# --- 3. Per-column dark fraction ---
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
dark_frac = dark_count / n_strips # shape: (region_w,)
# --- 4. Smooth and find transition ---
# Rolling mean (window = 1% of image width, min 5)
smooth_w = max(5, w // 100)
if smooth_w % 2 == 0:
smooth_w += 1
kernel = np.ones(smooth_w) / smooth_w
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
# Trim convolution edges
margin = smooth_w // 2
if region_w <= 2 * margin + 10:
return None
# Find the peak of dark fraction (gutter center).
# For right gutters the peak is near the edge; for left gutters
# (V-shaped spine shadow) the peak may be well inside the region.
transition_thresh = 0.50
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
if peak_frac < 0.70:
logger.debug(
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
)
return None
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
gutter_inner = None # local x in search_region
if side == "right":
# Scan from peak toward the page center (leftward)
for x in range(peak_x, margin, -1):
if frac_smooth[x] < transition_thresh:
gutter_inner = x + 1
break
else:
# Scan from peak toward the page center (rightward)
for x in range(peak_x, region_w - margin):
if frac_smooth[x] < transition_thresh:
gutter_inner = x - 1
break
if gutter_inner is None:
return None
# --- 5. Validate gutter width ---
if side == "right":
gutter_width = region_w - gutter_inner
else:
gutter_width = gutter_inner
min_gutter = max(3, int(w * 0.005)) # at least 0.5% of image
max_gutter = int(w * 0.10) # at most 10% of image
if gutter_width < min_gutter:
logger.debug(
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
gutter_width, min_gutter,
)
return None
if gutter_width > max_gutter:
logger.debug(
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
gutter_width, max_gutter,
)
return None
# Check that the gutter band is meaningfully darker than the page
if side == "right":
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
else:
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
brightness_drop = page_median - gutter_brightness
if brightness_drop < 3:
logger.debug(
"%s gutter: insufficient brightness drop (%.1f levels)",
side.capitalize(), brightness_drop,
)
return None
gutter_x = offset_x + gutter_inner
logger.info(
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
side.capitalize(), gutter_x, gutter_width,
100.0 * gutter_width / w, gutter_brightness, page_median,
brightness_drop, float(frac_smooth[gutter_inner]),
)
return gutter_x
def _detect_left_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect left content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (first ink column)
"""
search_w = max(1, w // 4)
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
if spine_x is not None:
return spine_x
# Fallback 1: vertical continuity (camera gutter shadow)
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
if gutter_x is not None:
return gutter_x
# Fallback 2: binary vertical projection
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
def _detect_right_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect right content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (last ink column)
"""
search_w = max(1, w // 4)
right_start = w - search_w
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
if spine_x is not None:
return spine_x
# Fallback 1: vertical continuity (camera gutter shadow)
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
if gutter_x is not None:
return gutter_x
# Fallback 2: binary vertical projection
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
"""Detect top and bottom content edges via binary horizontal projection."""
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
return top, bottom
def _detect_edge_projection(
binary: np.ndarray,
axis: int,
from_start: bool,
dim: int,
) -> int:
"""Find the first/last row or column with ink density above threshold.
axis=0 → project vertically (column densities) → returns x position
axis=1 → project horizontally (row densities) → returns y position
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
"""
# Compute density per row/column (mean of binary pixels / 255)
projection = np.mean(binary, axis=axis) / 255.0
# Create mask of "ink" positions
ink_mask = projection >= _INK_THRESHOLD
# Filter narrow runs (noise)
min_run = max(1, int(dim * _MIN_RUN_FRAC))
ink_mask = _filter_narrow_runs(ink_mask, min_run)
ink_positions = np.where(ink_mask)[0]
if len(ink_positions) == 0:
return 0 if from_start else dim
if from_start:
return int(ink_positions[0])
else:
return int(ink_positions[-1])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)
from page_crop_edges import ( # noqa: F401
_INK_THRESHOLD,
_MIN_RUN_FRAC,
_detect_spine_shadow,
_detect_gutter_continuity,
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
_detect_edge_projection,
_filter_narrow_runs,
)
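For orientation, a minimal usage sketch of the two entry points above — it assumes the refactored module is importable as page_crop and that the input path exists (sample_spread.jpg is purely illustrative):
import cv2
from page_crop import detect_page_splits, detect_and_crop_page

img = cv2.imread("sample_spread.jpg")  # hypothetical scan of a two-page spread
rects = detect_page_splits(img)  # [] means: treat as a single page
if not rects:
    rects = [{"x": 0, "y": 0, "width": img.shape[1], "height": img.shape[0]}]
for rect in rects:
    page = img[rect["y"]:rect["y"] + rect["height"], rect["x"]:rect["x"] + rect["width"]]
    cropped, info = detect_and_crop_page(page, margin_frac=0.01)
    print(info["detected_format"], info["border_fractions"])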

View File

@@ -0,0 +1,342 @@
"""
Page Crop - Core Crop and Format Detection
Content-based crop for scanned pages and book scans. Detects the content
boundary by analysing ink density projections and (for book scans) the
spine shadow gradient.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Dict, Any, Tuple
import cv2
import numpy as np
from page_crop_edges import (
_detect_left_edge_shadow,
_detect_right_edge_shadow,
_detect_top_bottom_edges,
)
logger = logging.getLogger(__name__)
# Known paper format aspect ratios (height / width, portrait orientation)
PAPER_FORMATS = {
"A4": 297.0 / 210.0, # 1.4143
"A5": 210.0 / 148.0, # 1.4189
"Letter": 11.0 / 8.5, # 1.2941
"Legal": 14.0 / 8.5, # 1.6471
"A3": 420.0 / 297.0, # 1.4141
}
def detect_page_splits(
img_bgr: np.ndarray,
) -> list:
"""Detect if the image is a multi-page spread and return split rectangles.
Uses **brightness** (not ink density) to find the spine area:
the scanner bed produces a characteristic gray strip where pages meet,
which is darker than the white paper on either side.
Returns a list of page dicts ``{x, y, width, height, page_index}``
or an empty list if only one page is detected.
"""
h, w = img_bgr.shape[:2]
# Only check landscape-ish images (width > height * 1.15)
if w < h * 1.15:
return []
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
col_brightness = np.mean(gray, axis=0).astype(np.float64)
# Heavy smoothing to ignore individual text lines
kern = max(11, w // 50)
if kern % 2 == 0:
kern += 1
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
# Page paper is bright (typically > 200), spine/scanner bed is darker
page_brightness = float(np.max(brightness_smooth))
if page_brightness < 100:
return [] # Very dark image, skip
# Spine threshold: significantly darker than the page
spine_thresh = page_brightness * 0.88
# Search in center region (30-70% of width)
center_lo = int(w * 0.30)
center_hi = int(w * 0.70)
# Find the darkest valley in the center region
center_brightness = brightness_smooth[center_lo:center_hi]
darkest_val = float(np.min(center_brightness))
if darkest_val >= spine_thresh:
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
darkest_val, spine_thresh)
return []
# Find ALL contiguous dark runs in the center region
is_dark = center_brightness < spine_thresh
dark_runs: list = []
run_start = -1
for i in range(len(is_dark)):
if is_dark[i]:
if run_start < 0:
run_start = i
else:
if run_start >= 0:
dark_runs.append((run_start, i))
run_start = -1
if run_start >= 0:
dark_runs.append((run_start, len(is_dark)))
# Filter out runs that are too narrow (< 1% of image width)
min_spine_px = int(w * 0.01)
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
if not dark_runs:
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
return []
# Score each dark run: prefer centered, dark, narrow valleys
center_region_len = center_hi - center_lo
image_center_in_region = (w * 0.5 - center_lo)
best_score = -1.0
best_start, best_end = dark_runs[0]
for rs, re in dark_runs:
run_width = re - rs
run_center = (rs + re) / 2.0
sigma = center_region_len * 0.15
dist = abs(run_center - image_center_in_region)
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
run_brightness = float(np.mean(center_brightness[rs:re]))
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
width_frac = run_width / w
if width_frac <= 0.05:
narrowness_bonus = 1.0
elif width_frac <= 0.15:
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
else:
narrowness_bonus = 0.0
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
logger.debug(
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
center_lo + rs, center_lo + re, run_width,
center_factor, darkness_factor, narrowness_bonus, score,
)
if score > best_score:
best_score = score
best_start, best_end = rs, re
spine_w = best_end - best_start
spine_x = center_lo + best_start
spine_center = spine_x + spine_w // 2
logger.debug(
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
spine_x, spine_x + spine_w, spine_w, best_score,
)
# Verify: must have bright (paper) content on BOTH sides
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
right_end = center_lo + best_end
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
if left_brightness < spine_thresh or right_brightness < spine_thresh:
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
left_brightness, right_brightness, spine_thresh)
return []
logger.info(
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
"left_paper=%.0f, right_paper=%.0f",
spine_x, right_end, spine_w, darkest_val, page_brightness,
left_brightness, right_brightness,
)
# Split at the spine center
split_points = [spine_center]
# Build page rectangles
pages: list = []
prev_x = 0
for i, sx in enumerate(split_points):
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
"height": h, "page_index": i})
prev_x = sx
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
"height": h, "page_index": len(split_points)})
# Filter out tiny pages (< 15% of total width)
pages = [p for p in pages if p["width"] >= w * 0.15]
if len(pages) < 2:
return []
# Re-index
for i, p in enumerate(pages):
p["page_index"] = i
logger.info(
"Page split detected: %d pages, spine_w=%d, split_points=%s",
len(pages), spine_w, split_points,
)
return pages
def detect_and_crop_page(
img_bgr: np.ndarray,
margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Detect content boundary and crop scanner/book borders.
Algorithm (4-edge detection):
1. Adaptive threshold -> binary (text=255, bg=0)
2. Left edge: spine-shadow / gutter detection via grayscale column means,
fallback to binary vertical projection
3. Right edge: same spine-shadow / gutter detection, with binary
vertical projection (last ink column) as fallback
4. Top/bottom edges: binary horizontal projection
5. Sanity checks, then crop with configurable margin
Args:
img_bgr: Input BGR image (should already be deskewed/dewarped)
margin_frac: Extra margin around content (fraction of dimension, default 1%)
Returns:
Tuple of (cropped_image, result_dict)
"""
h, w = img_bgr.shape[:2]
total_area = h * w
result: Dict[str, Any] = {
"crop_applied": False,
"crop_rect": None,
"crop_rect_pct": None,
"original_size": {"width": w, "height": h},
"cropped_size": {"width": w, "height": h},
"detected_format": None,
"format_confidence": 0.0,
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
}
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# --- Binarise with adaptive threshold ---
binary = cv2.adaptiveThreshold(
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
)
# --- Edge detection ---
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
# Compute border fractions
border_top = top_edge / h
border_bottom = (h - bottom_edge) / h
border_left = left_edge / w
border_right = (w - right_edge) / w
result["border_fractions"] = {
"top": round(border_top, 4),
"bottom": round(border_bottom, 4),
"left": round(border_left, 4),
"right": round(border_right, 4),
}
# Sanity: only crop if at least one edge has > 2% border
min_border = 0.02
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
# Add margin
margin_x = int(w * margin_frac)
margin_y = int(h * margin_frac)
crop_x = max(0, left_edge - margin_x)
crop_y = max(0, top_edge - margin_y)
crop_x2 = min(w, right_edge + margin_x)
crop_y2 = min(h, bottom_edge + margin_y)
crop_w = crop_x2 - crop_x
crop_h = crop_y2 - crop_y
# Sanity: cropped area must be >= 40% of original
if crop_w * crop_h < 0.40 * total_area:
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
100.0 * crop_w * crop_h / total_area)
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
return img_bgr, result
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
detected_format, format_confidence = _detect_format(crop_w, crop_h)
result["crop_applied"] = True
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
result["crop_rect_pct"] = {
"x": round(100.0 * crop_x / w, 2),
"y": round(100.0 * crop_y / h, 2),
"width": round(100.0 * crop_w / w, 2),
"height": round(100.0 * crop_h / h, 2),
}
result["cropped_size"] = {"width": crop_w, "height": crop_h}
result["detected_format"] = detected_format
result["format_confidence"] = format_confidence
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
logger.info(
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
border_top * 100, border_bottom * 100,
border_left * 100, border_right * 100,
)
return cropped, result
# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------
def _detect_format(width: int, height: int) -> Tuple[str, float]:
"""Detect paper format from dimensions by comparing aspect ratios."""
if width <= 0 or height <= 0:
return "unknown", 0.0
aspect = max(width, height) / min(width, height)
best_format = "unknown"
best_diff = float("inf")
for fmt, expected_ratio in PAPER_FORMATS.items():
diff = abs(aspect - expected_ratio)
if diff < best_diff:
best_diff = diff
best_format = fmt
confidence = max(0.0, 1.0 - best_diff * 5.0)
if confidence < 0.3:
return "unknown", 0.0
return best_format, round(confidence, 3)
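To make the confidence formula concrete, a worked check against a 300 dpi A4 scan — the module name page_crop_core is an assumption (the commit only names page_crop.py and page_crop_edges.py):
from page_crop_core import _detect_format  # module name is an assumption

# 300 dpi A4 scan: 2480 x 3508 px -> aspect ratio 3508/2480 ≈ 1.4145.
# Nearest ratio in PAPER_FORMATS is A4 (1.4143), diff ≈ 0.0002,
# so confidence ≈ 1.0 - 0.0002 * 5 ≈ 0.999.
print(_detect_format(2480, 3508))  # expected: ('A4', 0.999)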

View File

@@ -0,0 +1,388 @@
"""
Page Crop - Edge Detection Helpers
Spine shadow detection, gutter continuity analysis, projection-based
edge detection, and narrow-run filtering for content cropping.
Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""
import logging
from typing import Optional, Tuple
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# Minimum ink density (fraction of pixels) to count a row/column as "content"
_INK_THRESHOLD = 0.003 # 0.3%
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
_MIN_RUN_FRAC = 0.005 # 0.5%
def _detect_spine_shadow(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Find the book spine center (darkest point) in a scanner shadow.
The scanner produces a gray strip where the book spine presses against
the glass. The darkest column in that strip is the spine center —
that's where we crop.
Distinguishes real spine shadows from text content by checking:
1. Strong brightness range (> 40 levels)
2. Darkest point is genuinely dark (< 180 mean brightness)
3. The dark area is a NARROW valley, not a text-content plateau
4. Brightness rises significantly toward the page content side
Args:
gray: Full grayscale image (for context).
search_region: Column slice of the grayscale image to search in.
offset_x: X offset of search_region relative to full image.
w: Full image width.
side: 'left' or 'right' (for logging).
Returns:
X coordinate (in full image) of the spine center, or None.
"""
region_w = search_region.shape[1]
if region_w < 10:
return None
# Column-mean brightness in the search region
col_means = np.mean(search_region, axis=0).astype(np.float64)
# Smooth with boxcar kernel (width = 1% of image width, min 5)
kernel_size = max(5, w // 100)
if kernel_size % 2 == 0:
kernel_size += 1
kernel = np.ones(kernel_size) / kernel_size
smoothed_raw = np.convolve(col_means, kernel, mode="same")
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
margin = kernel_size // 2
if region_w <= 2 * margin + 10:
return None
smoothed = smoothed_raw[margin:region_w - margin]
trim_offset = margin # offset of smoothed[0] relative to search_region
val_min = float(np.min(smoothed))
val_max = float(np.max(smoothed))
shadow_range = val_max - val_min
# --- Check 1: Strong brightness gradient ---
if shadow_range <= 40:
logger.debug(
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
)
return None
# --- Check 2: Darkest point must be genuinely dark ---
if val_min > 180:
logger.debug(
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
)
return None
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
spine_local = spine_idx + trim_offset # index in search_region
trimmed_len = len(smoothed)
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
valley_thresh = val_min + shadow_range * 0.20
valley_mask = smoothed < valley_thresh
valley_width = int(np.sum(valley_mask))
max_valley_frac = 0.50
if valley_width > trimmed_len * max_valley_frac:
logger.debug(
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
side.capitalize(), valley_width, trimmed_len,
100.0 * valley_width / trimmed_len,
)
return None
# --- Check 4: Brightness must rise toward page content ---
rise_check_w = max(5, trimmed_len // 5)
if side == "left":
right_start = min(spine_idx + 5, trimmed_len - 1)
right_end = min(right_start + rise_check_w, trimmed_len)
if right_end > right_start:
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
else: # right
left_end = max(spine_idx - 5, 0)
left_start = max(left_end - rise_check_w, 0)
if left_end > left_start:
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
rise = rise_brightness - val_min
if rise < shadow_range * 0.3:
logger.debug(
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
side.capitalize(), rise, shadow_range * 0.3,
)
return None
spine_x = offset_x + spine_local
logger.info(
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
)
return spine_x
def _detect_gutter_continuity(
gray: np.ndarray,
search_region: np.ndarray,
offset_x: int,
w: int,
side: str,
) -> Optional[int]:
"""Detect gutter shadow via vertical continuity analysis.
Camera book scans produce a subtle brightness gradient at the gutter
that is too faint for scanner-shadow detection (range < 40). However,
the gutter shadow has a unique property: it runs **continuously from
top to bottom** without interruption.
Algorithm:
1. Divide image into N horizontal strips (~60px each)
2. For each column, compute what fraction of strips are darker than
the page median (from the center 50% of the full image)
3. A "gutter column" has >= 75% of strips darker than page_median - d
4. Smooth the dark-fraction profile and find the transition point
5. Validate: gutter band must be 0.5%-10% of image width
"""
region_h, region_w = search_region.shape[:2]
if region_w < 20 or region_h < 100:
return None
# --- 1. Divide into horizontal strips ---
strip_target_h = 60
n_strips = max(10, region_h // strip_target_h)
strip_h = region_h // n_strips
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
for s in range(n_strips):
y0 = s * strip_h
y1 = min((s + 1) * strip_h, region_h)
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
# --- 2. Page median from center 50% of full image ---
center_lo = w // 4
center_hi = 3 * w // 4
page_median = float(np.median(gray[:, center_lo:center_hi]))
dark_thresh = page_median - 5.0
if page_median < 180:
return None
# --- 3. Per-column dark fraction ---
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
dark_frac = dark_count / n_strips
# --- 4. Smooth and find transition ---
smooth_w = max(5, w // 100)
if smooth_w % 2 == 0:
smooth_w += 1
kernel = np.ones(smooth_w) / smooth_w
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
margin = smooth_w // 2
if region_w <= 2 * margin + 10:
return None
transition_thresh = 0.50
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
if peak_frac < 0.70:
logger.debug(
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
)
return None
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
gutter_inner = None
if side == "right":
for x in range(peak_x, margin, -1):
if frac_smooth[x] < transition_thresh:
gutter_inner = x + 1
break
else:
for x in range(peak_x, region_w - margin):
if frac_smooth[x] < transition_thresh:
gutter_inner = x - 1
break
if gutter_inner is None:
return None
# --- 5. Validate gutter width ---
if side == "right":
gutter_width = region_w - gutter_inner
else:
gutter_width = gutter_inner
min_gutter = max(3, int(w * 0.005))
max_gutter = int(w * 0.10)
if gutter_width < min_gutter:
logger.debug(
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
gutter_width, min_gutter,
)
return None
if gutter_width > max_gutter:
logger.debug(
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
gutter_width, max_gutter,
)
return None
if side == "right":
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
else:
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
brightness_drop = page_median - gutter_brightness
if brightness_drop < 3:
logger.debug(
"%s gutter: insufficient brightness drop (%.1f levels)",
side.capitalize(), brightness_drop,
)
return None
gutter_x = offset_x + gutter_inner
logger.info(
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
side.capitalize(), gutter_x, gutter_width,
100.0 * gutter_width / w, gutter_brightness, page_median,
brightness_drop, float(frac_smooth[gutter_inner]),
)
return gutter_x
def _detect_left_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect left content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (first ink column)
"""
search_w = max(1, w // 4)
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
def _detect_right_edge_shadow(
gray: np.ndarray,
binary: np.ndarray,
w: int,
h: int,
) -> int:
"""Detect right content edge, accounting for book-spine shadow.
Tries three methods in order:
1. Scanner spine-shadow (dark gradient, range > 40)
2. Camera gutter continuity (subtle shadow running top-to-bottom)
3. Binary projection fallback (last ink column)
"""
search_w = max(1, w // 4)
right_start = w - search_w
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
if spine_x is not None:
return spine_x
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
if gutter_x is not None:
return gutter_x
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
"""Detect top and bottom content edges via binary horizontal projection."""
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
return top, bottom
def _detect_edge_projection(
binary: np.ndarray,
axis: int,
from_start: bool,
dim: int,
) -> int:
"""Find the first/last row or column with ink density above threshold.
axis=0 -> project vertically (column densities) -> returns x position
axis=1 -> project horizontally (row densities) -> returns y position
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
"""
projection = np.mean(binary, axis=axis) / 255.0
ink_mask = projection >= _INK_THRESHOLD
min_run = max(1, int(dim * _MIN_RUN_FRAC))
ink_mask = _filter_narrow_runs(ink_mask, min_run)
ink_positions = np.where(ink_mask)[0]
if len(ink_positions) == 0:
return 0 if from_start else dim
if from_start:
return int(ink_positions[0])
else:
return int(ink_positions[-1])
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
"""Remove True-runs shorter than min_run pixels."""
if min_run <= 1:
return mask
result = mask.copy()
n = len(result)
i = 0
while i < n:
if result[i]:
start = i
while i < n and result[i]:
i += 1
if i - start < min_run:
result[start:i] = False
else:
i += 1
return result
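A quick sanity check of the run filter (page_crop_edges is the module name confirmed by the re-export in page_crop.py above):
import numpy as np
from page_crop_edges import _filter_narrow_runs

mask = np.array([0, 1, 1, 0, 1, 1, 1, 1, 0, 1], dtype=bool)
# With min_run=3 the 2-px and 1-px runs are dropped; the 4-px run survives.
print(_filter_narrow_runs(mask, min_run=3).astype(int))
# -> [0 0 0 0 1 1 1 1 0 0]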

View File

@@ -0,0 +1,160 @@
"""
TrOCR Batch Processing & Streaming
Batch OCR and SSE streaming for multiple images.
DATENSCHUTZ: All processing happens locally.
"""
import asyncio
import logging
import time
from typing import Optional, List, Dict, Any, Callable
from .trocr_models import OCRResult, BatchOCRResult
from .trocr_ocr import run_trocr_ocr_enhanced
logger = logging.getLogger(__name__)
async def run_trocr_batch(
images: List[bytes],
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True,
progress_callback: Optional[Callable[[int, int], None]] = None
) -> BatchOCRResult:
"""
Process multiple images in batch.
Args:
images: List of image data bytes
handwritten: Use handwritten model
split_lines: Whether to split images into lines
use_cache: Whether to use caching
progress_callback: Optional callback(current, total) for progress updates
Returns:
BatchOCRResult with all results
"""
start_time = time.time()
results = []
cached_count = 0
error_count = 0
for idx, image_data in enumerate(images):
try:
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
results.append(result)
if result.from_cache:
cached_count += 1
# Report progress
if progress_callback:
progress_callback(idx + 1, len(images))
except Exception as e:
logger.error(f"Batch OCR error for image {idx}: {e}")
error_count += 1
results.append(OCRResult(
text=f"Error: {str(e)}",
confidence=0.0,
processing_time_ms=0,
model="error",
has_lora_adapter=False
))
total_time_ms = int((time.time() - start_time) * 1000)
return BatchOCRResult(
results=results,
total_time_ms=total_time_ms,
processed_count=len(images),
cached_count=cached_count,
error_count=error_count
)
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
images: List[bytes],
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True
):
"""
Process images and yield progress updates for SSE streaming.
Yields:
dict with current progress and result
"""
start_time = time.time()
total = len(images)
for idx, image_data in enumerate(images):
try:
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
elapsed_ms = int((time.time() - start_time) * 1000)
avg_time_per_image = elapsed_ms / (idx + 1)
estimated_remaining = int(avg_time_per_image * (total - idx - 1))
yield {
"type": "progress",
"current": idx + 1,
"total": total,
"progress_percent": ((idx + 1) / total) * 100,
"elapsed_ms": elapsed_ms,
"estimated_remaining_ms": estimated_remaining,
"result": {
"text": result.text,
"confidence": result.confidence,
"processing_time_ms": result.processing_time_ms,
"from_cache": result.from_cache
}
}
except Exception as e:
logger.error(f"Stream OCR error for image {idx}: {e}")
yield {
"type": "error",
"current": idx + 1,
"total": total,
"error": str(e)
}
total_time_ms = int((time.time() - start_time) * 1000)
yield {
"type": "complete",
"total_time_ms": total_time_ms,
"processed_count": total
}
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
"""Test TrOCR on a local image file."""
from .trocr_ocr import run_trocr_ocr
with open(image_path, "rb") as f:
image_data = f.read()
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
print(f"\n=== TrOCR Test ===")
print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
print(f"Confidence: {confidence:.2f}")
print(f"Text:\n{text}")
return text, confidence
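A hedged sketch of how run_trocr_batch_stream could back an SSE endpoint — FastAPI, the route path, and the import path are assumptions, not part of this commit:
import json
from typing import List

from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse

from trocr_batch import run_trocr_batch_stream  # import path is an assumption

app = FastAPI()

@app.post("/ocr/batch/stream")  # hypothetical route
async def ocr_batch_stream(files: List[UploadFile]):
    images = [await f.read() for f in files]

    async def event_source():
        # Each progress/error/complete dict becomes one SSE "data:" frame.
        async for event in run_trocr_batch_stream(images, handwritten=True):
            yield f"data: {json.dumps(event)}\n\n"

    return StreamingResponse(event_source(), media_type="text/event-stream")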

View File

@@ -0,0 +1,278 @@
"""
TrOCR Models & Cache
Dataclasses, LRU cache, and model loading for TrOCR service.
DATENSCHUTZ: All processing happens locally.
"""
import io
import os
import hashlib
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {} # {model_name: (processor, model)}
_trocr_processor = None # backwards-compat alias -> base-printed
_trocr_model = None # backwards-compat alias -> base-printed
_trocr_available = None
_model_loaded_at = None
# Simple in-memory cache with LRU eviction
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600 # 1 hour
@dataclass
class OCRResult:
"""Enhanced OCR result with detailed information."""
text: str
confidence: float
processing_time_ms: int
model: str
has_lora_adapter: bool = False
char_confidences: List[float] = field(default_factory=list)
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
from_cache: bool = False
image_hash: str = ""
@dataclass
class BatchOCRResult:
"""Result for batch processing."""
results: List[OCRResult]
total_time_ms: int
processed_count: int
cached_count: int
error_count: int
def _compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data for caching."""
return hashlib.sha256(image_data).hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
"""Get cached OCR result if available and not expired."""
if image_hash in _ocr_cache:
entry = _ocr_cache[image_hash]
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
# Move to end (LRU)
_ocr_cache.move_to_end(image_hash)
return entry["result"]
else:
# Expired, remove
del _ocr_cache[image_hash]
return None
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
"""Store OCR result in cache."""
# Evict oldest if at capacity
while len(_ocr_cache) >= _cache_max_size:
_ocr_cache.popitem(last=False)
_ocr_cache[image_hash] = {
"result": result,
"cached_at": datetime.now()
}
def get_cache_stats() -> Dict[str, Any]:
"""Get cache statistics."""
return {
"size": len(_ocr_cache),
"max_size": _cache_max_size,
"ttl_seconds": _cache_ttl_seconds,
"hit_rate": 0 # Could track this with additional counters
}
def _check_trocr_available() -> bool:
"""Check if TrOCR dependencies are available."""
global _trocr_available
if _trocr_available is not None:
return _trocr_available
try:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
_trocr_available = True
except ImportError as e:
logger.warning(f"TrOCR dependencies not available: {e}")
_trocr_available = False
return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
"""
Lazy load TrOCR model and processor.
Args:
handwritten: Use handwritten model instead of printed model
size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
for exam HTR). Only applies to handwritten variant.
Returns tuple of (processor, model) or (None, None) if unavailable.
"""
global _trocr_processor, _trocr_model
if not _check_trocr_available():
return None, None
# Select model name
if size == "large" and handwritten:
model_name = "microsoft/trocr-large-handwritten"
elif handwritten:
model_name = "microsoft/trocr-base-handwritten"
else:
model_name = "microsoft/trocr-base-printed"
if model_name in _trocr_models:
return _trocr_models[model_name]
try:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
logger.info(f"Loading TrOCR model: {model_name}")
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
logger.info(f"TrOCR model loaded on device: {device}")
_trocr_models[model_name] = (processor, model)
# Keep backwards-compat globals pointing at base-printed
if model_name == "microsoft/trocr-base-printed":
_trocr_processor = processor
_trocr_model = model
return processor, model
except Exception as e:
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
"""
Preload TrOCR model at startup for faster first request.
Call this from your FastAPI startup event:
@app.on_event("startup")
async def startup():
preload_trocr_model()
"""
global _model_loaded_at
logger.info("Preloading TrOCR model...")
processor, model = get_trocr_model(handwritten=handwritten)
if processor is not None and model is not None:
_model_loaded_at = datetime.now()
logger.info("TrOCR model preloaded successfully")
return True
else:
logger.warning("TrOCR model preloading failed")
return False
def get_model_status() -> Dict[str, Any]:
"""Get current model status information."""
processor, model = get_trocr_model(handwritten=True)
is_loaded = processor is not None and model is not None
status = {
"status": "available" if is_loaded else "not_installed",
"is_loaded": is_loaded,
"model_name": "trocr-base-handwritten" if is_loaded else None,
"loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
}
if is_loaded:
import torch
device = next(model.parameters()).device
status["device"] = str(device)
return status
def get_active_backend() -> str:
"""
Return which TrOCR backend is configured.
Possible values: "auto", "pytorch", "onnx".
"""
return _trocr_backend
def _split_into_lines(image) -> list:
"""
Split an image into text lines using simple projection-based segmentation.
This is a basic implementation - for production use, consider using
a dedicated line detection model.
"""
import numpy as np
from PIL import Image
try:
# Convert to grayscale
gray = image.convert('L')
img_array = np.array(gray)
# Binarize (simple threshold)
threshold = 200
binary = img_array < threshold
# Horizontal projection (sum of dark pixels per row)
h_proj = np.sum(binary, axis=1)
# Find line boundaries (where projection drops below threshold)
line_threshold = img_array.shape[1] * 0.02 # 2% of width
in_line = False
line_start = 0
lines = []
for i, val in enumerate(h_proj):
if val > line_threshold and not in_line:
# Start of line
in_line = True
line_start = i
elif val <= line_threshold and in_line:
# End of line
in_line = False
# Add padding
start = max(0, line_start - 5)
end = min(img_array.shape[0], i + 5)
if end - start > 10: # Minimum line height
lines.append(image.crop((0, start, image.width, end)))
# Handle last line if still in_line
if in_line:
start = max(0, line_start - 5)
lines.append(image.crop((0, start, image.width, image.height)))
logger.info(f"Split image into {len(lines)} lines")
return lines
except Exception as e:
logger.warning(f"Line splitting failed: {e}")
return []
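The cache helpers can be exercised in isolation; a minimal round-trip with a toy payload (import path is an assumption):
from trocr_models import _compute_image_hash, _cache_get, _cache_set

data = b"fake image bytes"
key = _compute_image_hash(data)  # 16-hex-char SHA256 prefix
_cache_set(key, {"text": "Hallo", "confidence": 0.9, "model": "demo"})
print(_cache_get(key))  # hit, as long as the 1 h TTL has not expired
# Once _cache_max_size entries exist, the least recently used one is evicted.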

View File

@@ -0,0 +1,309 @@
"""
TrOCR OCR Execution
Core OCR inference routines (PyTorch, ONNX routing, enhanced mode).
DATENSCHUTZ: All processing happens locally.
"""
import io
import logging
import time
from typing import Tuple, Optional, List, Dict, Any
from .trocr_models import (
OCRResult,
_trocr_backend,
_compute_image_hash,
_cache_get,
_cache_set,
get_trocr_model,
_split_into_lines,
)
logger = logging.getLogger(__name__)
def _try_onnx_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
) -> Optional[Tuple[Optional[str], float]]:
"""
Resolve the ONNX backend. Returns the async ``run_trocr_onnx`` function
when ONNX is available, or None if it is not installed / fails to load.
The arguments are accepted for signature parity but are not consumed
here; the caller awaits the returned function with them.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
if not is_onnx_available(handwritten=handwritten):
return None
# run_trocr_onnx is async -- return the function itself;
# the caller (run_trocr_ocr) awaits it with the original arguments.
return run_trocr_onnx # sentinel: caller checks callable
except ImportError:
return None
async def _run_pytorch_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Original PyTorch inference path (extracted for routing).
"""
processor, model = get_trocr_model(handwritten=handwritten, size=size)
if processor is None or model is None:
logger.error("TrOCR PyTorch model not available")
return None, 0.0
try:
import torch
from PIL import Image
import numpy as np
# Load image
image = Image.open(io.BytesIO(image_data)).convert("RGB")
if split_lines:
lines = _split_into_lines(image)
if not lines:
lines = [image]
else:
lines = [image]
all_text = []
confidences = []
for line_image in lines:
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
device = next(model.parameters()).device
pixel_values = pixel_values.to(device)
with torch.no_grad():
generated_ids = model.generate(pixel_values, max_length=128)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
if generated_text.strip():
all_text.append(generated_text.strip())
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
text = "\n".join(all_text)
confidence = sum(confidences) / len(confidences) if confidences else 0.0
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
return text, confidence
except Exception as e:
logger.error(f"TrOCR PyTorch failed: {e}")
import traceback
logger.error(traceback.format_exc())
return None, 0.0
async def run_trocr_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Run TrOCR on an image.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
- "onnx" -- always use ONNX (raises RuntimeError if unavailable)
- "pytorch" -- always use PyTorch (original behaviour)
- "auto" -- try ONNX first, fall back to PyTorch
TrOCR is optimized for single-line text recognition, so for full-page
images we need to either:
1. Split into lines first (using line detection)
2. Process the whole image and get partial results
Args:
image_data: Raw image bytes
handwritten: Use handwritten model (slower but better for handwriting)
split_lines: Whether to split image into lines first
size: "base" or "large" (only for handwritten variant)
Returns:
Tuple of (extracted_text, confidence)
"""
backend = _trocr_backend
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is None or not callable(onnx_fn):
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
# --- PyTorch-only mode ---
if backend == "pytorch":
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
# --- Auto mode: try ONNX first, then PyTorch ---
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is not None and callable(onnx_fn):
try:
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
if result[0] is not None:
return result
logger.warning("ONNX returned None text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
def _try_onnx_enhanced(
handwritten: bool = True,
):
"""
Return the ONNX enhanced coroutine function, or None if unavailable.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
if not is_onnx_available(handwritten=handwritten):
return None
return run_trocr_onnx_enhanced
except ImportError:
return None
async def run_trocr_ocr_enhanced(
image_data: bytes,
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True
) -> OCRResult:
"""
Enhanced TrOCR OCR with caching and detailed results.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
Args:
image_data: Raw image bytes
handwritten: Use handwritten model
split_lines: Whether to split image into lines first
use_cache: Whether to use caching
Returns:
OCRResult with detailed information
"""
backend = _trocr_backend
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is None:
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
# --- Auto mode: try ONNX first ---
if backend == "auto":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is not None:
try:
result = await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
if result.text:
return result
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
start_time = time.time()
# Check cache first
image_hash = _compute_image_hash(image_data)
if use_cache:
cached = _cache_get(image_hash)
if cached:
return OCRResult(
text=cached["text"],
confidence=cached["confidence"],
processing_time_ms=0,
model=cached["model"],
has_lora_adapter=cached.get("has_lora_adapter", False),
char_confidences=cached.get("char_confidences", []),
word_boxes=cached.get("word_boxes", []),
from_cache=True,
image_hash=image_hash
)
# Run OCR via PyTorch
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
processing_time_ms = int((time.time() - start_time) * 1000)
# Generate word boxes with simulated confidences
word_boxes = []
if text:
words = text.split()
for idx, word in enumerate(words):
# Simulate word confidence (slightly varied around overall confidence)
word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
word_boxes.append({
"text": word,
"confidence": word_conf,
"bbox": [0, 0, 0, 0] # Would need actual bounding box detection
})
# Generate character confidences
char_confidences = []
if text:
for char in text:
# Simulate per-character confidence
char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
char_confidences.append(char_conf)
result = OCRResult(
text=text or "",
confidence=confidence,
processing_time_ms=processing_time_ms,
model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
has_lora_adapter=False, # Would check actual adapter status
char_confidences=char_confidences,
word_boxes=word_boxes,
from_cache=False,
image_hash=image_hash
)
# Cache result
if use_cache and text:
_cache_set(image_hash, {
"text": result.text,
"confidence": result.confidence,
"model": result.model,
"has_lora_adapter": result.has_lora_adapter,
"char_confidences": result.char_confidences,
"word_boxes": result.word_boxes
})
return result


@@ -1,720 +1,70 @@
"""
TrOCR Service
TrOCR Service — Barrel Re-export
Microsoft's Transformer-based OCR for text recognition.
Especially well suited for:
- Printed text
- Clean scans
- Fast processing
Model: microsoft/trocr-base-printed (or the handwritten variant)
Split into submodules:
- trocr_models.py — Dataclasses, cache, model loading, line splitting
- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced)
- trocr_batch.py — Batch processing and SSE streaming
DATA PROTECTION: All processing happens locally.
Phase 2 Enhancements:
- Batch processing for multiple images
- SHA256-based caching for repeated requests
- Model preloading for faster first request
- Word-level bounding boxes with confidence
"""
import io
import os
import hashlib
import logging
import time
import asyncio
from typing import Tuple, Optional, List, Dict, Any, Callable
from dataclasses import dataclass, field
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend routing: auto | pytorch | onnx
# ---------------------------------------------------------------------------
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
# Lazy loading for heavy dependencies
# Cache keyed by model_name to support base and large variants simultaneously
_trocr_models: dict = {} # {model_name: (processor, model)}
_trocr_processor = None # backwards-compat alias → base-printed
_trocr_model = None # backwards-compat alias → base-printed
_trocr_available = None
_model_loaded_at = None
# Simple in-memory cache with LRU eviction
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
_cache_max_size = 100
_cache_ttl_seconds = 3600 # 1 hour
@dataclass
class OCRResult:
"""Enhanced OCR result with detailed information."""
text: str
confidence: float
processing_time_ms: int
model: str
has_lora_adapter: bool = False
char_confidences: List[float] = field(default_factory=list)
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
from_cache: bool = False
image_hash: str = ""
@dataclass
class BatchOCRResult:
"""Result for batch processing."""
results: List[OCRResult]
total_time_ms: int
processed_count: int
cached_count: int
error_count: int
def _compute_image_hash(image_data: bytes) -> str:
"""Compute SHA256 hash of image data for caching."""
return hashlib.sha256(image_data).hexdigest()[:16]
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
"""Get cached OCR result if available and not expired."""
if image_hash in _ocr_cache:
entry = _ocr_cache[image_hash]
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
# Move to end (LRU)
_ocr_cache.move_to_end(image_hash)
return entry["result"]
else:
# Expired, remove
del _ocr_cache[image_hash]
return None
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
"""Store OCR result in cache."""
# Evict oldest if at capacity
while len(_ocr_cache) >= _cache_max_size:
_ocr_cache.popitem(last=False)
_ocr_cache[image_hash] = {
"result": result,
"cached_at": datetime.now()
}
def get_cache_stats() -> Dict[str, Any]:
"""Get cache statistics."""
return {
"size": len(_ocr_cache),
"max_size": _cache_max_size,
"ttl_seconds": _cache_ttl_seconds,
"hit_rate": 0 # Could track this with additional counters
}
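# Round-trip sketch for the cache helpers above (bytes and payload are
# illustrative; the keys mirror what run_trocr_ocr_enhanced stores):
def _demo_cache_roundtrip() -> None:
    key = _compute_image_hash(b"not-a-real-image")
    _cache_set(key, {"text": "Hallo", "confidence": 0.9,
                     "model": "trocr-base-handwritten"})
    assert _cache_get(key)["text"] == "Hallo"  # fresh entry: cache hit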
def _check_trocr_available() -> bool:
"""Check if TrOCR dependencies are available."""
global _trocr_available
if _trocr_available is not None:
return _trocr_available
try:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
_trocr_available = True
except ImportError as e:
logger.warning(f"TrOCR dependencies not available: {e}")
_trocr_available = False
return _trocr_available
def get_trocr_model(handwritten: bool = False, size: str = "base"):
"""
Lazy load TrOCR model and processor.
Args:
handwritten: Use handwritten model instead of printed model
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
for exam HTR). Only applies to handwritten variant.
Returns tuple of (processor, model) or (None, None) if unavailable.
"""
global _trocr_processor, _trocr_model
if not _check_trocr_available():
return None, None
# Select model name
if size == "large" and handwritten:
model_name = "microsoft/trocr-large-handwritten"
elif handwritten:
model_name = "microsoft/trocr-base-handwritten"
else:
model_name = "microsoft/trocr-base-printed"
if model_name in _trocr_models:
return _trocr_models[model_name]
try:
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
logger.info(f"Loading TrOCR model: {model_name}")
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
logger.info(f"TrOCR model loaded on device: {device}")
_trocr_models[model_name] = (processor, model)
# Keep backwards-compat globals pointing at base-printed
if model_name == "microsoft/trocr-base-printed":
_trocr_processor = processor
_trocr_model = model
return processor, model
except Exception as e:
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
return None, None
def preload_trocr_model(handwritten: bool = True) -> bool:
"""
Preload TrOCR model at startup for faster first request.
Call this from your FastAPI startup event:
@app.on_event("startup")
async def startup():
preload_trocr_model()
"""
global _model_loaded_at
logger.info("Preloading TrOCR model...")
processor, model = get_trocr_model(handwritten=handwritten)
if processor is not None and model is not None:
_model_loaded_at = datetime.now()
logger.info("TrOCR model preloaded successfully")
return True
else:
logger.warning("TrOCR model preloading failed")
return False
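# Newer FastAPI releases prefer the lifespan pattern over @app.on_event; a
# sketch of the same preload under that style (app wiring is illustrative):
def _demo_lifespan_app():
    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app):
        preload_trocr_model(handwritten=True)  # warm the model once at boot
        yield

    return FastAPI(lifespan=lifespan)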
def get_model_status() -> Dict[str, Any]:
"""Get current model status information."""
processor, model = get_trocr_model(handwritten=True)
is_loaded = processor is not None and model is not None
status = {
"status": "available" if is_loaded else "not_installed",
"is_loaded": is_loaded,
"model_name": "trocr-base-handwritten" if is_loaded else None,
"loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
}
    if is_loaded:
        device = next(model.parameters()).device
        status["device"] = str(device)
return status
def get_active_backend() -> str:
"""
Return which TrOCR backend is configured.
Possible values: "auto", "pytorch", "onnx".
"""
return _trocr_backend
def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[Callable]:
    """
    Return the async ONNX inference function if ONNX is available, or None
    otherwise. The caller awaits the returned coroutine function with the
    same arguments; image_data and split_lines are accepted here only to
    mirror the call-site signature.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
        if not is_onnx_available(handwritten=handwritten):
            return None
        # run_trocr_onnx is async; the caller checks callable() and awaits it.
        return run_trocr_onnx
    except ImportError:
        return None
async def _run_pytorch_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Original PyTorch inference path (extracted for routing).
"""
processor, model = get_trocr_model(handwritten=handwritten, size=size)
if processor is None or model is None:
logger.error("TrOCR PyTorch model not available")
return None, 0.0
try:
import torch
from PIL import Image
import numpy as np
# Load image
image = Image.open(io.BytesIO(image_data)).convert("RGB")
if split_lines:
lines = _split_into_lines(image)
if not lines:
lines = [image]
else:
lines = [image]
all_text = []
confidences = []
for line_image in lines:
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
device = next(model.parameters()).device
pixel_values = pixel_values.to(device)
with torch.no_grad():
generated_ids = model.generate(pixel_values, max_length=128)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            if generated_text.strip():
                all_text.append(generated_text.strip())
                # Heuristic confidence: generation scores are not requested,
                # so longer outputs are simply trusted more.
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)
text = "\n".join(all_text)
confidence = sum(confidences) / len(confidences) if confidences else 0.0
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
return text, confidence
except Exception as e:
logger.error(f"TrOCR PyTorch failed: {e}")
import traceback
logger.error(traceback.format_exc())
return None, 0.0
async def run_trocr_ocr(
image_data: bytes,
handwritten: bool = False,
split_lines: bool = True,
size: str = "base",
) -> Tuple[Optional[str], float]:
"""
Run TrOCR on an image.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
- "onnx" — always use ONNX (raises RuntimeError if unavailable)
- "pytorch" — always use PyTorch (original behaviour)
- "auto" — try ONNX first, fall back to PyTorch
TrOCR is optimized for single-line text recognition, so for full-page
images we need to either:
1. Split into lines first (using line detection)
2. Process the whole image and get partial results
Args:
image_data: Raw image bytes
handwritten: Use handwritten model (slower but better for handwriting)
split_lines: Whether to split image into lines first
size: "base" or "large" (only for handwritten variant)
Returns:
Tuple of (extracted_text, confidence)
"""
backend = _trocr_backend
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is None or not callable(onnx_fn):
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
# --- PyTorch-only mode ---
if backend == "pytorch":
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
# --- Auto mode: try ONNX first, then PyTorch ---
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
if onnx_fn is not None and callable(onnx_fn):
try:
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
if result[0] is not None:
return result
logger.warning("ONNX returned None text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
return await _run_pytorch_ocr(
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
)
def _split_into_lines(image) -> list:
"""
Split an image into text lines using simple projection-based segmentation.
    This is a basic implementation; for production use, consider a
    dedicated line-detection model.
"""
import numpy as np
from PIL import Image
try:
# Convert to grayscale
gray = image.convert('L')
img_array = np.array(gray)
# Binarize (simple threshold)
threshold = 200
binary = img_array < threshold
# Horizontal projection (sum of dark pixels per row)
h_proj = np.sum(binary, axis=1)
# Find line boundaries (where projection drops below threshold)
line_threshold = img_array.shape[1] * 0.02 # 2% of width
in_line = False
line_start = 0
lines = []
for i, val in enumerate(h_proj):
if val > line_threshold and not in_line:
# Start of line
in_line = True
line_start = i
elif val <= line_threshold and in_line:
# End of line
in_line = False
# Add padding
start = max(0, line_start - 5)
end = min(img_array.shape[0], i + 5)
if end - start > 10: # Minimum line height
lines.append(image.crop((0, start, image.width, end)))
# Handle last line if still in_line
if in_line:
start = max(0, line_start - 5)
lines.append(image.crop((0, start, image.width, image.height)))
logger.info(f"Split image into {len(lines)} lines")
return lines
except Exception as e:
logger.warning(f"Line splitting failed: {e}")
return []
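# Self-contained illustration of the projection idea above (synthetic array,
# not the service code path): rows containing "ink" produce peaks in the
# horizontal projection, and the gaps between peaks separate the lines.
def _demo_projection_profile() -> None:
    import numpy as np

    page = np.full((12, 20), 255, dtype=np.uint8)  # white page
    page[2:4, 2:18] = 0                            # first text band
    page[7:9, 2:18] = 0                            # second text band
    h_proj = (page < 200).sum(axis=1)              # dark pixels per row
    ink_rows = np.where(h_proj > page.shape[1] * 0.02)[0]
    print(ink_rows)  # [2 3 7 8]: two separable bands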
def _try_onnx_enhanced(
handwritten: bool = True,
):
"""
Return the ONNX enhanced coroutine function, or None if unavailable.
"""
try:
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
if not is_onnx_available(handwritten=handwritten):
return None
return run_trocr_onnx_enhanced
except ImportError:
return None
async def run_trocr_ocr_enhanced(
image_data: bytes,
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True
) -> OCRResult:
"""
Enhanced TrOCR OCR with caching and detailed results.
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
environment variable (default: "auto").
Args:
image_data: Raw image bytes
handwritten: Use handwritten model
split_lines: Whether to split image into lines first
use_cache: Whether to use caching
Returns:
OCRResult with detailed information
"""
backend = _trocr_backend
# --- ONNX-only mode ---
if backend == "onnx":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is None:
raise RuntimeError(
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
)
return await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
# --- Auto mode: try ONNX first ---
if backend == "auto":
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
if onnx_fn is not None:
try:
result = await onnx_fn(
image_data, handwritten=handwritten,
split_lines=split_lines, use_cache=use_cache,
)
if result.text:
return result
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
except Exception as e:
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
start_time = time.time()
# Check cache first
image_hash = _compute_image_hash(image_data)
if use_cache:
cached = _cache_get(image_hash)
if cached:
return OCRResult(
text=cached["text"],
confidence=cached["confidence"],
processing_time_ms=0,
model=cached["model"],
has_lora_adapter=cached.get("has_lora_adapter", False),
char_confidences=cached.get("char_confidences", []),
word_boxes=cached.get("word_boxes", []),
from_cache=True,
image_hash=image_hash
)
# Run OCR via PyTorch
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
processing_time_ms = int((time.time() - start_time) * 1000)
# Generate word boxes with simulated confidences
word_boxes = []
if text:
words = text.split()
for idx, word in enumerate(words):
# Simulate word confidence (slightly varied around overall confidence)
word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
word_boxes.append({
"text": word,
"confidence": word_conf,
"bbox": [0, 0, 0, 0] # Would need actual bounding box detection
})
# Generate character confidences
char_confidences = []
if text:
for char in text:
# Simulate per-character confidence
char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
char_confidences.append(char_conf)
result = OCRResult(
text=text or "",
confidence=confidence,
processing_time_ms=processing_time_ms,
model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
has_lora_adapter=False, # Would check actual adapter status
char_confidences=char_confidences,
word_boxes=word_boxes,
from_cache=False,
image_hash=image_hash
)
# Cache result
if use_cache and text:
_cache_set(image_hash, {
"text": result.text,
"confidence": result.confidence,
"model": result.model,
"has_lora_adapter": result.has_lora_adapter,
"char_confidences": result.char_confidences,
"word_boxes": result.word_boxes
})
return result
async def run_trocr_batch(
images: List[bytes],
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True,
progress_callback: Optional[callable] = None
) -> BatchOCRResult:
"""
Process multiple images in batch.
Args:
images: List of image data bytes
handwritten: Use handwritten model
split_lines: Whether to split images into lines
use_cache: Whether to use caching
progress_callback: Optional callback(current, total) for progress updates
Returns:
BatchOCRResult with all results
"""
start_time = time.time()
results = []
cached_count = 0
error_count = 0
for idx, image_data in enumerate(images):
try:
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
results.append(result)
if result.from_cache:
cached_count += 1
# Report progress
if progress_callback:
progress_callback(idx + 1, len(images))
except Exception as e:
logger.error(f"Batch OCR error for image {idx}: {e}")
error_count += 1
results.append(OCRResult(
text=f"Error: {str(e)}",
confidence=0.0,
processing_time_ms=0,
model="error",
has_lora_adapter=False
))
total_time_ms = int((time.time() - start_time) * 1000)
return BatchOCRResult(
results=results,
total_time_ms=total_time_ms,
processed_count=len(images),
cached_count=cached_count,
error_count=error_count
)
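# Caller sketch for the batch API (file paths are illustrative):
async def _demo_batch(paths: List[str]) -> None:
    images = []
    for p in paths:
        with open(p, "rb") as f:
            images.append(f.read())
    batch = await run_trocr_batch(
        images,
        progress_callback=lambda done, total: print(f"{done}/{total}"),
    )
    print(batch.processed_count, batch.cached_count, batch.error_count)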
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
images: List[bytes],
handwritten: bool = True,
split_lines: bool = True,
use_cache: bool = True
):
"""
Process images and yield progress updates for SSE streaming.
Yields:
dict with current progress and result
"""
start_time = time.time()
total = len(images)
for idx, image_data in enumerate(images):
try:
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
elapsed_ms = int((time.time() - start_time) * 1000)
avg_time_per_image = elapsed_ms / (idx + 1)
estimated_remaining = int(avg_time_per_image * (total - idx - 1))
yield {
"type": "progress",
"current": idx + 1,
"total": total,
"progress_percent": ((idx + 1) / total) * 100,
"elapsed_ms": elapsed_ms,
"estimated_remaining_ms": estimated_remaining,
"result": {
"text": result.text,
"confidence": result.confidence,
"processing_time_ms": result.processing_time_ms,
"from_cache": result.from_cache
}
}
except Exception as e:
logger.error(f"Stream OCR error for image {idx}: {e}")
yield {
"type": "error",
"current": idx + 1,
"total": total,
"error": str(e)
}
total_time_ms = int((time.time() - start_time) * 1000)
yield {
"type": "complete",
"total_time_ms": total_time_ms,
"processed_count": total
}
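# One way a route could expose this generator as Server-Sent Events; a
# sketch assuming FastAPI (endpoint path and wiring are illustrative):
def _demo_sse_app():
    import json

    from fastapi import FastAPI, UploadFile
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    @app.post("/ocr/batch/stream")
    async def ocr_batch_stream(files: list[UploadFile]):
        images = [await f.read() for f in files]

        async def events():
            async for update in run_trocr_batch_stream(images):
                yield f"data: {json.dumps(update)}\n\n"  # one SSE frame

        return StreamingResponse(events(), media_type="text/event-stream")

    return app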
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
"""Test TrOCR on a local image file."""
with open(image_path, "rb") as f:
image_data = f.read()
text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)
print(f"\n=== TrOCR Test ===")
print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
print(f"Confidence: {confidence:.2f}")
print(f"Text:\n{text}")
return text, confidence
# Models, cache, and model loading
from .trocr_models import (
OCRResult,
BatchOCRResult,
_compute_image_hash,
_cache_get,
_cache_set,
get_cache_stats,
_check_trocr_available,
get_trocr_model,
preload_trocr_model,
get_model_status,
get_active_backend,
_split_into_lines,
)
# Core OCR execution
from .trocr_ocr import (
run_trocr_ocr,
run_trocr_ocr_enhanced,
_run_pytorch_ocr,
)
# Batch processing & streaming
from .trocr_batch import (
run_trocr_batch,
run_trocr_batch_stream,
test_trocr_ocr,
)
__all__ = [
# Dataclasses
"OCRResult",
"BatchOCRResult",
# Cache
"_compute_image_hash",
"_cache_get",
"_cache_set",
"get_cache_stats",
# Model loading
"_check_trocr_available",
"get_trocr_model",
"preload_trocr_model",
"get_model_status",
"get_active_backend",
"_split_into_lines",
# OCR execution
"run_trocr_ocr",
"run_trocr_ocr_enhanced",
"_run_pytorch_ocr",
# Batch
"run_trocr_batch",
"run_trocr_batch_stream",
"test_trocr_ocr",
]
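# Net effect of the barrel: downstream call sites keep their historical
# import path, e.g. (module path illustrative)
#
#     from trocr_service import run_trocr_ocr, preload_trocr_model
#
# while the implementations now live in trocr_models, trocr_ocr and
# trocr_batch.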
if __name__ == "__main__":