[split-required] Split 700-870 LOC files across all services
backend-lehrer (11 files): - llm_gateway/routes/schools.py (867 → 5), recording_api.py (848 → 6) - messenger_api.py (840 → 5), print_generator.py (824 → 5) - unit_analytics_api.py (751 → 5), classroom/routes/context.py (726 → 4) - llm_gateway/routes/edu_search_seeds.py (710 → 4) klausur-service (12 files): - ocr_labeling_api.py (845 → 4), metrics_db.py (833 → 4) - legal_corpus_api.py (790 → 4), page_crop.py (758 → 3) - mail/ai_service.py (747 → 4), github_crawler.py (767 → 3) - trocr_service.py (730 → 4), full_compliance_pipeline.py (723 → 4) - dsfa_rag_api.py (715 → 4), ocr_pipeline_auto.py (705 → 4) website (6 pages): - audit-checklist (867 → 8), content (806 → 6) - screen-flow (790 → 4), scraper (789 → 5) - zeugnisse (776 → 5), modules (745 → 4) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
200
klausur-service/backend/compliance_extraction.py
Normal file
200
klausur-service/backend/compliance_extraction.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Compliance Extraction & Generation.
|
||||
|
||||
Functions for extracting checkpoints from legal text chunks,
|
||||
generating controls, and creating remediation measures.
|
||||
"""
|
||||
|
||||
import re
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
|
||||
"""
|
||||
Extract checkpoints/requirements from a chunk of text.
|
||||
|
||||
Uses pattern matching to find requirement-like statements.
|
||||
"""
|
||||
checkpoints = []
|
||||
regulation_code = payload.get("regulation_code", "UNKNOWN")
|
||||
regulation_name = payload.get("regulation_name", "Unknown")
|
||||
source_url = payload.get("source_url", "")
|
||||
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
|
||||
|
||||
# Patterns for different requirement types
|
||||
patterns = [
|
||||
# BSI-TR patterns
|
||||
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
|
||||
# Article patterns (GDPR, AI Act, etc.)
|
||||
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-\u2013:]\s*(.+?)(?=\n|$)', 'article'),
|
||||
# Numbered requirements
|
||||
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
|
||||
# "Der Verantwortliche muss" patterns
|
||||
(r'(?:Der Verantwortliche|Die Aufsichtsbeh\u00f6rde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
|
||||
# "Es ist erforderlich" patterns
|
||||
(r'(?:Es ist erforderlich|Es muss gew\u00e4hrleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
|
||||
]
|
||||
|
||||
for pattern, pattern_type in patterns:
|
||||
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
|
||||
for match in matches:
|
||||
if pattern_type == 'bsi_requirement':
|
||||
req_id = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
title = req_id
|
||||
elif pattern_type == 'article':
|
||||
article_num = match.group(1)
|
||||
paragraph = match.group(2) or ""
|
||||
title_text = match.group(3).strip()
|
||||
req_id = f"{regulation_code}-Art{article_num}"
|
||||
if paragraph:
|
||||
req_id += f"-{paragraph}"
|
||||
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
|
||||
description = title_text
|
||||
elif pattern_type == 'numbered':
|
||||
num = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
req_id = f"{regulation_code}-{num}"
|
||||
title = f"Anforderung {num}"
|
||||
else:
|
||||
# Generic requirement
|
||||
description = match.group(0).strip()
|
||||
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
|
||||
title = description[:50] + "..." if len(description) > 50 else description
|
||||
|
||||
# Skip very short matches
|
||||
if len(description) < 20:
|
||||
continue
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=req_id,
|
||||
regulation_code=regulation_code,
|
||||
regulation_name=regulation_name,
|
||||
article=title if 'Art' in title else None,
|
||||
title=title,
|
||||
description=description[:500],
|
||||
original_text=description,
|
||||
chunk_id=chunk_id,
|
||||
source_url=source_url
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
def generate_control_for_checkpoints(
|
||||
checkpoints: List[Checkpoint],
|
||||
domain_counts: Dict[str, int],
|
||||
) -> Optional[Control]:
|
||||
"""
|
||||
Generate a control that covers the given checkpoints.
|
||||
|
||||
This is a simplified version - in production this would use the AI assistant.
|
||||
"""
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Group by regulation
|
||||
regulation = checkpoints[0].regulation_code
|
||||
|
||||
# Determine domain based on content
|
||||
all_text = " ".join([cp.description for cp in checkpoints]).lower()
|
||||
|
||||
domain = "gov" # Default
|
||||
if any(kw in all_text for kw in ["verschl\u00fcssel", "krypto", "encrypt", "hash"]):
|
||||
domain = "crypto"
|
||||
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
|
||||
domain = "iam"
|
||||
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
|
||||
domain = "priv"
|
||||
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
|
||||
domain = "sdlc"
|
||||
elif any(kw in all_text for kw in ["\u00fcberwach", "monitor", "log", "audit"]):
|
||||
domain = "aud"
|
||||
elif any(kw in all_text for kw in ["ki", "k\u00fcnstlich", "ai", "machine learning", "model"]):
|
||||
domain = "ai"
|
||||
elif any(kw in all_text for kw in ["betrieb", "operation", "verf\u00fcgbar", "backup"]):
|
||||
domain = "ops"
|
||||
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
|
||||
domain = "cra"
|
||||
|
||||
# Generate control ID
|
||||
domain_count = domain_counts.get(domain, 0) + 1
|
||||
control_id = f"{domain.upper()}-{domain_count:03d}"
|
||||
|
||||
# Create title from first checkpoint
|
||||
title = checkpoints[0].title
|
||||
if len(title) > 100:
|
||||
title = title[:97] + "..."
|
||||
|
||||
# Create description
|
||||
description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
|
||||
|
||||
# Pass criteria
|
||||
pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
|
||||
|
||||
# Implementation guidance
|
||||
guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
|
||||
guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
|
||||
|
||||
# Determine if automated
|
||||
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
|
||||
|
||||
control = Control(
|
||||
id=control_id,
|
||||
domain=domain,
|
||||
title=title,
|
||||
description=description,
|
||||
checkpoints=[cp.id for cp in checkpoints],
|
||||
pass_criteria=pass_criteria,
|
||||
implementation_guidance=guidance,
|
||||
is_automated=is_automated,
|
||||
automation_tool="CI/CD Pipeline" if is_automated else None,
|
||||
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
def generate_measure_for_control(control: Control) -> Measure:
|
||||
"""Generate a remediation measure for a control."""
|
||||
measure_id = f"M-{control.id}"
|
||||
|
||||
# Determine deadline based on priority
|
||||
deadline_days = {
|
||||
"critical": 30,
|
||||
"high": 60,
|
||||
"medium": 90,
|
||||
"low": 180
|
||||
}.get(control.priority, 90)
|
||||
|
||||
# Determine responsible team
|
||||
responsible = {
|
||||
"priv": "Datenschutzbeauftragter",
|
||||
"iam": "IT-Security Team",
|
||||
"sdlc": "Entwicklungsteam",
|
||||
"crypto": "IT-Security Team",
|
||||
"ops": "Operations Team",
|
||||
"aud": "Compliance Team",
|
||||
"ai": "AI/ML Team",
|
||||
"cra": "IT-Security Team",
|
||||
"gov": "Management"
|
||||
}.get(control.domain, "Compliance Team")
|
||||
|
||||
measure = Measure(
|
||||
id=measure_id,
|
||||
control_id=control.id,
|
||||
title=f"Umsetzung: {control.title[:50]}",
|
||||
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
|
||||
responsible=responsible,
|
||||
deadline_days=deadline_days,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return measure
|
||||
49
klausur-service/backend/compliance_models.py
Normal file
49
klausur-service/backend/compliance_models.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Compliance Pipeline Data Models.
|
||||
|
||||
Dataclasses for checkpoints, controls, and measures.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
"""A requirement/checkpoint extracted from legal text."""
|
||||
id: str
|
||||
regulation_code: str
|
||||
regulation_name: str
|
||||
article: Optional[str]
|
||||
title: str
|
||||
description: str
|
||||
original_text: str
|
||||
chunk_id: str
|
||||
source_url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control:
|
||||
"""A control derived from checkpoints."""
|
||||
id: str
|
||||
domain: str
|
||||
title: str
|
||||
description: str
|
||||
checkpoints: List[str] # List of checkpoint IDs
|
||||
pass_criteria: str
|
||||
implementation_guidance: str
|
||||
is_automated: bool
|
||||
automation_tool: Optional[str]
|
||||
priority: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Measure:
|
||||
"""A remediation measure for a control."""
|
||||
id: str
|
||||
control_id: str
|
||||
title: str
|
||||
description: str
|
||||
responsible: str
|
||||
deadline_days: int
|
||||
status: str
|
||||
441
klausur-service/backend/compliance_pipeline.py
Normal file
441
klausur-service/backend/compliance_pipeline.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Compliance Pipeline Execution.
|
||||
|
||||
Pipeline phases (ingestion, extraction, control generation, measures)
|
||||
and orchestration logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
from dataclasses import asdict
|
||||
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
from compliance_extraction import (
|
||||
extract_checkpoints_from_chunk,
|
||||
generate_control_for_checkpoints,
|
||||
generate_measure_for_control,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint manager
|
||||
try:
|
||||
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
|
||||
except ImportError:
|
||||
logger.warning("Checkpoint manager not available, running without checkpoints")
|
||||
CheckpointManager = None
|
||||
EXPECTED_VALUES = {}
|
||||
ValidationStatus = None
|
||||
|
||||
# Set environment variables for Docker network
|
||||
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
|
||||
os.environ["QDRANT_HOST"] = "qdrant"
|
||||
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
|
||||
# Try to import from klausur-service
|
||||
try:
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
except ImportError:
|
||||
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class CompliancePipeline:
|
||||
"""Handles the full compliance pipeline."""
|
||||
|
||||
def __init__(self):
|
||||
# Support both QDRANT_URL and QDRANT_HOST/PORT
|
||||
qdrant_url = os.getenv("QDRANT_URL", "")
|
||||
if qdrant_url:
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(qdrant_url)
|
||||
qdrant_host = parsed.hostname or "qdrant"
|
||||
qdrant_port = parsed.port or 6333
|
||||
else:
|
||||
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
|
||||
qdrant_port = 6333
|
||||
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
|
||||
self.checkpoints: List[Checkpoint] = []
|
||||
self.controls: List[Control] = []
|
||||
self.measures: List[Measure] = []
|
||||
self.stats = {
|
||||
"chunks_processed": 0,
|
||||
"checkpoints_extracted": 0,
|
||||
"controls_created": 0,
|
||||
"measures_defined": 0,
|
||||
"by_regulation": {},
|
||||
"by_domain": {},
|
||||
}
|
||||
# Initialize checkpoint manager
|
||||
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
|
||||
|
||||
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
|
||||
"""Phase 1: Ingest documents (incremental - only missing ones)."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
try:
|
||||
# Check existing chunks per regulation
|
||||
existing_chunks = {}
|
||||
try:
|
||||
for regulation in REGULATIONS:
|
||||
count_result = self.qdrant.count(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
|
||||
)
|
||||
)
|
||||
existing_chunks[regulation.code] = count_result.count
|
||||
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check existing chunks: {e}")
|
||||
|
||||
# Determine which regulations need ingestion
|
||||
regulations_to_ingest = []
|
||||
for regulation in REGULATIONS:
|
||||
existing = existing_chunks.get(regulation.code, 0)
|
||||
if force_reindex or existing == 0:
|
||||
regulations_to_ingest.append(regulation)
|
||||
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
|
||||
else:
|
||||
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
|
||||
self.stats["by_regulation"][regulation.code] = existing
|
||||
|
||||
if not regulations_to_ingest:
|
||||
logger.info("All regulations already indexed. Skipping ingestion phase.")
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("skipped", True)
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
return total_chunks
|
||||
|
||||
# Ingest only missing regulations
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
for i, regulation in enumerate(regulations_to_ingest, 1):
|
||||
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
|
||||
try:
|
||||
count = await ingestion.ingest_regulation(regulation)
|
||||
total_chunks += count
|
||||
self.stats["by_regulation"][regulation.code] = count
|
||||
logger.info(f" -> {count} chunks")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" -> FAILED: {e}")
|
||||
self.stats["by_regulation"][regulation.code] = 0
|
||||
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
logger.info(f"\nTotal chunks in collection: {total_chunks}")
|
||||
|
||||
# Validate ingestion results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
|
||||
|
||||
expected = EXPECTED_VALUES.get("ingestion", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_chunks",
|
||||
expected=expected.get("total_chunks", 8000),
|
||||
actual=total_chunks,
|
||||
min_value=expected.get("min_chunks", 7000)
|
||||
)
|
||||
|
||||
reg_expected = expected.get("regulations", {})
|
||||
for reg_code, reg_exp in reg_expected.items():
|
||||
actual = self.stats["by_regulation"].get(reg_code, 0)
|
||||
self.checkpoint_mgr.validate(
|
||||
f"chunks_{reg_code}",
|
||||
expected=reg_exp.get("expected", 0),
|
||||
actual=actual,
|
||||
min_value=reg_exp.get("min", 0)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return total_chunks
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
await ingestion.close()
|
||||
|
||||
async def run_extraction_phase(self) -> int:
|
||||
"""Phase 2: Extract checkpoints from chunks."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
|
||||
|
||||
try:
|
||||
offset = None
|
||||
total_checkpoints = 0
|
||||
|
||||
while True:
|
||||
result = self.qdrant.scroll(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
limit=100,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
points, next_offset = result
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
text = payload.get("text", "")
|
||||
|
||||
cps = extract_checkpoints_from_chunk(text, payload)
|
||||
self.checkpoints.extend(cps)
|
||||
total_checkpoints += len(cps)
|
||||
|
||||
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
self.stats["checkpoints_extracted"] = len(self.checkpoints)
|
||||
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
|
||||
|
||||
by_reg = {}
|
||||
for cp in self.checkpoints:
|
||||
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
|
||||
for reg, count in sorted(by_reg.items()):
|
||||
logger.info(f" {reg}: {count} checkpoints")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
|
||||
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
|
||||
|
||||
expected = EXPECTED_VALUES.get("extraction", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_checkpoints",
|
||||
expected=expected.get("total_checkpoints", 3500),
|
||||
actual=len(self.checkpoints),
|
||||
min_value=expected.get("min_checkpoints", 3000)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.checkpoints)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_control_generation_phase(self) -> int:
|
||||
"""Phase 3: Generate controls from checkpoints."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 3: CONTROL GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
|
||||
|
||||
try:
|
||||
# Group checkpoints by regulation
|
||||
by_regulation: Dict[str, List[Checkpoint]] = {}
|
||||
for cp in self.checkpoints:
|
||||
reg = cp.regulation_code
|
||||
if reg not in by_regulation:
|
||||
by_regulation[reg] = []
|
||||
by_regulation[reg].append(cp)
|
||||
|
||||
# Generate controls per regulation (group every 3-5 checkpoints)
|
||||
for regulation, checkpoints in by_regulation.items():
|
||||
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
|
||||
|
||||
batch_size = 4
|
||||
for i in range(0, len(checkpoints), batch_size):
|
||||
batch = checkpoints[i:i + batch_size]
|
||||
control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
|
||||
|
||||
if control:
|
||||
self.controls.append(control)
|
||||
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
|
||||
|
||||
self.stats["controls_created"] = len(self.controls)
|
||||
logger.info(f"\nTotal controls created: {len(self.controls)}")
|
||||
|
||||
for domain, count in sorted(self.stats["by_domain"].items()):
|
||||
logger.info(f" {domain}: {count} controls")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
|
||||
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
|
||||
|
||||
expected = EXPECTED_VALUES.get("controls", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_controls",
|
||||
expected=expected.get("total_controls", 900),
|
||||
actual=len(self.controls),
|
||||
min_value=expected.get("min_controls", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.controls)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_measure_generation_phase(self) -> int:
|
||||
"""Phase 4: Generate measures for controls."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 4: MEASURE GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
|
||||
|
||||
try:
|
||||
for control in self.controls:
|
||||
measure = generate_measure_for_control(control)
|
||||
self.measures.append(measure)
|
||||
|
||||
self.stats["measures_defined"] = len(self.measures)
|
||||
logger.info(f"\nTotal measures defined: {len(self.measures)}")
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
|
||||
|
||||
expected = EXPECTED_VALUES.get("measures", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_measures",
|
||||
expected=expected.get("total_measures", 900),
|
||||
actual=len(self.measures),
|
||||
min_value=expected.get("min_measures", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.measures)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
def save_results(self, output_dir: str = "/tmp/compliance_output"):
|
||||
"""Save results to JSON files."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("SAVING RESULTS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
|
||||
with open(checkpoints_file, "w") as f:
|
||||
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
|
||||
|
||||
controls_file = os.path.join(output_dir, "controls.json")
|
||||
with open(controls_file, "w") as f:
|
||||
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
|
||||
|
||||
measures_file = os.path.join(output_dir, "measures.json")
|
||||
with open(measures_file, "w") as f:
|
||||
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
|
||||
|
||||
stats_file = os.path.join(output_dir, "statistics.json")
|
||||
self.stats["generated_at"] = datetime.now().isoformat()
|
||||
with open(stats_file, "w") as f:
|
||||
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved statistics to {stats_file}")
|
||||
|
||||
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
|
||||
"""Run the complete pipeline.
|
||||
|
||||
Args:
|
||||
force_reindex: If True, re-ingest all documents even if they exist
|
||||
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
logger.info(f"Force reindex: {force_reindex}")
|
||||
logger.info(f"Skip ingestion: {skip_ingestion}")
|
||||
if self.checkpoint_mgr:
|
||||
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
if skip_ingestion:
|
||||
logger.info("Skipping ingestion phase as requested...")
|
||||
try:
|
||||
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
|
||||
self.stats["chunks_processed"] = collection_info.points_count
|
||||
except Exception:
|
||||
self.stats["chunks_processed"] = 0
|
||||
else:
|
||||
await self.run_ingestion_phase(force_reindex=force_reindex)
|
||||
|
||||
await self.run_extraction_phase()
|
||||
await self.run_control_generation_phase()
|
||||
await self.run_measure_generation_phase()
|
||||
self.save_results()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PIPELINE COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Duration: {elapsed:.1f} seconds")
|
||||
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
|
||||
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
|
||||
logger.info(f"Controls created: {self.stats['controls_created']}")
|
||||
logger.info(f"Measures defined: {self.stats['measures_defined']}")
|
||||
logger.info(f"\nResults saved to: /tmp/compliance_output/")
|
||||
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.complete_pipeline({
|
||||
"duration_seconds": elapsed,
|
||||
"chunks_processed": self.stats['chunks_processed'],
|
||||
"checkpoints_extracted": self.stats['checkpoints_extracted'],
|
||||
"controls_created": self.stats['controls_created'],
|
||||
"measures_defined": self.stats['measures_defined'],
|
||||
"by_regulation": self.stats['by_regulation'],
|
||||
"by_domain": self.stats['by_domain'],
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline failed: {e}")
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.state.status = "failed"
|
||||
self.checkpoint_mgr._save()
|
||||
raise
|
||||
@@ -1,7 +1,10 @@
|
||||
"""
|
||||
DSFA RAG API Endpoints.
|
||||
DSFA RAG API Endpoints — Barrel Re-export.
|
||||
|
||||
Provides REST API for searching DSFA corpus with full source attribution.
|
||||
Split into submodules:
|
||||
- dsfa_rag_models.py — Pydantic request/response models
|
||||
- dsfa_rag_embedding.py — Embedding service integration & text extraction
|
||||
- dsfa_rag_routes.py — Route handlers (search, sources, ingest, stats)
|
||||
|
||||
Endpoints:
|
||||
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
|
||||
@@ -11,705 +14,54 @@ Endpoints:
|
||||
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Embedding service configuration
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||
|
||||
# Import from ingestion module
|
||||
from dsfa_corpus_ingestion import (
|
||||
DSFACorpusStore,
|
||||
DSFAQdrantService,
|
||||
DSFASearchResult,
|
||||
LICENSE_REGISTRY,
|
||||
DSFA_SOURCES,
|
||||
generate_attribution_notice,
|
||||
get_license_label,
|
||||
DSFA_COLLECTION,
|
||||
chunk_document
|
||||
# Models
|
||||
from dsfa_rag_models import (
|
||||
DSFASourceResponse,
|
||||
DSFAChunkResponse,
|
||||
DSFASearchResultResponse,
|
||||
DSFASearchResponse,
|
||||
DSFASourceStatsResponse,
|
||||
DSFACorpusStatsResponse,
|
||||
IngestRequest,
|
||||
IngestResponse,
|
||||
LicenseInfo,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class DSFASourceResponse(BaseModel):
|
||||
"""Response model for DSFA source."""
|
||||
id: str
|
||||
source_code: str
|
||||
name: str
|
||||
full_name: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
attribution_text: str
|
||||
document_type: Optional[str] = None
|
||||
language: str = "de"
|
||||
|
||||
|
||||
class DSFAChunkResponse(BaseModel):
|
||||
"""Response model for a single chunk with attribution."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
# Document info
|
||||
document_id: str
|
||||
document_title: Optional[str] = None
|
||||
|
||||
# Attribution (always included)
|
||||
source_id: str
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
document_type: Optional[str] = None
|
||||
|
||||
|
||||
class DSFASearchResultResponse(BaseModel):
|
||||
"""Response model for search result."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
# Attribution
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
|
||||
# Metadata
|
||||
document_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
|
||||
|
||||
class DSFASearchResponse(BaseModel):
|
||||
"""Response model for search endpoint."""
|
||||
query: str
|
||||
results: List[DSFASearchResultResponse]
|
||||
total_results: int
|
||||
|
||||
# Aggregated licenses for footer
|
||||
licenses_used: List[str]
|
||||
attribution_notice: str
|
||||
|
||||
|
||||
class DSFASourceStatsResponse(BaseModel):
|
||||
"""Response model for source statistics."""
|
||||
source_id: str
|
||||
source_code: str
|
||||
name: str
|
||||
organization: Optional[str] = None
|
||||
license_code: str
|
||||
document_type: Optional[str] = None
|
||||
document_count: int
|
||||
chunk_count: int
|
||||
last_indexed_at: Optional[str] = None
|
||||
|
||||
|
||||
class DSFACorpusStatsResponse(BaseModel):
|
||||
"""Response model for corpus statistics."""
|
||||
sources: List[DSFASourceStatsResponse]
|
||||
total_sources: int
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
qdrant_collection: str
|
||||
qdrant_points_count: int
|
||||
qdrant_status: str
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
"""Request model for ingestion."""
|
||||
document_url: Optional[str] = None
|
||||
document_text: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
"""Response model for ingestion."""
|
||||
source_code: str
|
||||
document_id: Optional[str] = None
|
||||
chunks_created: int
|
||||
message: str
|
||||
|
||||
|
||||
class LicenseInfo(BaseModel):
|
||||
"""License information."""
|
||||
code: str
|
||||
name: str
|
||||
url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
modification_allowed: bool
|
||||
commercial_use: bool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
# Database pool (will be set from main.py)
|
||||
_db_pool = None
|
||||
|
||||
|
||||
def set_db_pool(pool):
|
||||
"""Set the database pool for API endpoints."""
|
||||
global _db_pool
|
||||
_db_pool = pool
|
||||
|
||||
|
||||
async def get_store() -> DSFACorpusStore:
|
||||
"""Get DSFA corpus store."""
|
||||
if _db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return DSFACorpusStore(_db_pool)
|
||||
|
||||
|
||||
async def get_qdrant() -> DSFAQdrantService:
|
||||
"""Get Qdrant service."""
|
||||
return DSFAQdrantService()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Service Integration
|
||||
# =============================================================================
|
||||
|
||||
async def get_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Get embedding for text using the embedding-service.
|
||||
|
||||
Uses BGE-M3 model which produces 1024-dimensional vectors.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed-single",
|
||||
json={"text": text}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embedding", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service error: {e}")
|
||||
# Fallback to hash-based pseudo-embedding for development
|
||||
return _generate_fallback_embedding(text)
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Get embeddings for multiple texts in batch.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": texts}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service batch error: {e}")
|
||||
# Fallback
|
||||
return [_generate_fallback_embedding(t) for t in texts]
|
||||
|
||||
|
||||
async def extract_text_from_url(url: str) -> str:
|
||||
"""
|
||||
Extract text from a document URL (PDF, HTML, etc.).
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
# First try to use the embedding-service's extract-pdf endpoint
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
|
||||
json={"url": url}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"PDF extraction error for {url}: {e}")
|
||||
# Fallback: try to fetch HTML content directly
|
||||
try:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "html" in content_type:
|
||||
# Simple HTML text extraction
|
||||
import re
|
||||
html = response.text
|
||||
# Remove scripts and styles
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
# Remove tags
|
||||
text = re.sub(r'<[^>]+>', ' ', html)
|
||||
# Clean whitespace
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
else:
|
||||
return ""
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
|
||||
return ""
|
||||
|
||||
|
||||
def _generate_fallback_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Generate deterministic pseudo-embedding from text hash.
|
||||
Used as fallback when embedding service is unavailable.
|
||||
"""
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
embedding = []
|
||||
for i in range(0, min(len(hash_bytes), 128), 4):
|
||||
val = struct.unpack('f', hash_bytes[i:i+4])[0]
|
||||
embedding.append(val % 1.0)
|
||||
|
||||
# Pad to 1024 dimensions
|
||||
while len(embedding) < 1024:
|
||||
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
|
||||
|
||||
return embedding[:1024]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/search", response_model=DSFASearchResponse)
|
||||
async def search_dsfa_corpus(
|
||||
query: str = Query(..., min_length=3, description="Search query"),
|
||||
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
|
||||
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
|
||||
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
|
||||
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
|
||||
include_attribution: bool = Query(True, description="Include attribution in results"),
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Search DSFA corpus with full attribution.
|
||||
|
||||
Returns matching chunks with source/license information for compliance.
|
||||
"""
|
||||
# Get query embedding
|
||||
query_embedding = await get_embedding(query)
|
||||
|
||||
# Search Qdrant
|
||||
raw_results = await qdrant.search(
|
||||
query_embedding=query_embedding,
|
||||
source_codes=source_codes,
|
||||
document_types=document_types,
|
||||
categories=categories,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Transform results
|
||||
results = []
|
||||
licenses_used = set()
|
||||
|
||||
for r in raw_results:
|
||||
license_code = r.get("license_code", "")
|
||||
license_info = LICENSE_REGISTRY.get(license_code, {})
|
||||
|
||||
result = DSFASearchResultResponse(
|
||||
chunk_id=r.get("chunk_id", ""),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
source_code=r.get("source_code", ""),
|
||||
source_name=r.get("source_name", ""),
|
||||
attribution_text=r.get("attribution_text", ""),
|
||||
license_code=license_code,
|
||||
license_name=license_info.get("name", license_code),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=r.get("attribution_required", True),
|
||||
source_url=r.get("source_url"),
|
||||
document_type=r.get("document_type"),
|
||||
category=r.get("category"),
|
||||
section_title=r.get("section_title"),
|
||||
page_number=r.get("page_number")
|
||||
)
|
||||
results.append(result)
|
||||
licenses_used.add(license_code)
|
||||
|
||||
# Generate attribution notice
|
||||
search_results = [
|
||||
DSFASearchResult(
|
||||
chunk_id=r.chunk_id,
|
||||
content=r.content,
|
||||
score=r.score,
|
||||
source_code=r.source_code,
|
||||
source_name=r.source_name,
|
||||
attribution_text=r.attribution_text,
|
||||
license_code=r.license_code,
|
||||
license_url=r.license_url,
|
||||
attribution_required=r.attribution_required,
|
||||
source_url=r.source_url,
|
||||
document_type=r.document_type or "",
|
||||
category=r.category or "",
|
||||
section_title=r.section_title,
|
||||
page_number=r.page_number
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
|
||||
|
||||
return DSFASearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total_results=len(results),
|
||||
licenses_used=list(licenses_used),
|
||||
attribution_notice=attribution_notice
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources", response_model=List[DSFASourceResponse])
|
||||
async def list_dsfa_sources(
|
||||
document_type: Optional[str] = Query(None, description="Filter by document type"),
|
||||
license_code: Optional[str] = Query(None, description="Filter by license"),
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""List all registered DSFA sources with license info."""
|
||||
sources = await store.list_sources()
|
||||
|
||||
result = []
|
||||
for s in sources:
|
||||
# Apply filters
|
||||
if document_type and s.get("document_type") != document_type:
|
||||
continue
|
||||
if license_code and s.get("license_code") != license_code:
|
||||
continue
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
|
||||
|
||||
result.append(DSFASourceResponse(
|
||||
id=str(s["id"]),
|
||||
source_code=s["source_code"],
|
||||
name=s["name"],
|
||||
full_name=s.get("full_name"),
|
||||
organization=s.get("organization"),
|
||||
source_url=s.get("source_url"),
|
||||
license_code=s.get("license_code", ""),
|
||||
license_name=license_info.get("name", s.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=s.get("attribution_required", True),
|
||||
attribution_text=s.get("attribution_text", ""),
|
||||
document_type=s.get("document_type"),
|
||||
language=s.get("language", "de")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sources/available")
|
||||
async def list_available_sources():
|
||||
"""List all available source definitions (from DSFA_SOURCES constant)."""
|
||||
return [
|
||||
{
|
||||
"source_code": s["source_code"],
|
||||
"name": s["name"],
|
||||
"organization": s.get("organization"),
|
||||
"license_code": s["license_code"],
|
||||
"document_type": s.get("document_type")
|
||||
}
|
||||
for s in DSFA_SOURCES
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
|
||||
async def get_dsfa_source(
|
||||
source_code: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get details for a specific source."""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
|
||||
|
||||
return DSFASourceResponse(
|
||||
id=str(source["id"]),
|
||||
source_code=source["source_code"],
|
||||
name=source["name"],
|
||||
full_name=source.get("full_name"),
|
||||
organization=source.get("organization"),
|
||||
source_url=source.get("source_url"),
|
||||
license_code=source.get("license_code", ""),
|
||||
license_name=license_info.get("name", source.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=source.get("attribution_required", True),
|
||||
attribution_text=source.get("attribution_text", ""),
|
||||
document_type=source.get("document_type"),
|
||||
language=source.get("language", "de")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
|
||||
async def ingest_dsfa_source(
|
||||
source_code: str,
|
||||
request: IngestRequest,
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Trigger ingestion for a specific source.
|
||||
|
||||
Can provide document via URL or direct text.
|
||||
"""
|
||||
# Get source
|
||||
source = await store.get_source_by_code(source_code)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
# Need either URL or text
|
||||
if not request.document_text and not request.document_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either document_text or document_url must be provided"
|
||||
)
|
||||
|
||||
# Ensure Qdrant collection exists
|
||||
await qdrant.ensure_collection()
|
||||
|
||||
# Get text content
|
||||
text_content = request.document_text
|
||||
if request.document_url and not text_content:
|
||||
# Download and extract text from URL
|
||||
logger.info(f"Extracting text from URL: {request.document_url}")
|
||||
text_content = await extract_text_from_url(request.document_url)
|
||||
if not text_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Could not extract text from URL: {request.document_url}"
|
||||
)
|
||||
|
||||
if not text_content or len(text_content.strip()) < 50:
|
||||
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
|
||||
|
||||
# Create document record
|
||||
doc_title = request.title or f"Document for {source_code}"
|
||||
document_id = await store.create_document(
|
||||
source_id=str(source["id"]),
|
||||
title=doc_title,
|
||||
file_type="text",
|
||||
metadata={"ingested_via": "api", "source_code": source_code}
|
||||
)
|
||||
|
||||
# Chunk the document
|
||||
chunks = chunk_document(text_content, source_code)
|
||||
|
||||
if not chunks:
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=0,
|
||||
message="Document created but no chunks generated"
|
||||
)
|
||||
|
||||
# Generate embeddings in batch for efficiency
|
||||
chunk_texts = [chunk["content"] for chunk in chunks]
|
||||
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
|
||||
embeddings = await get_embeddings_batch(chunk_texts)
|
||||
|
||||
# Create chunk records in PostgreSQL and prepare for Qdrant
|
||||
chunk_records = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
# Create chunk in PostgreSQL
|
||||
chunk_id = await store.create_chunk(
|
||||
document_id=document_id,
|
||||
source_id=str(source["id"]),
|
||||
content=chunk["content"],
|
||||
chunk_index=i,
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category")
|
||||
)
|
||||
|
||||
chunk_records.append({
|
||||
"chunk_id": chunk_id,
|
||||
"document_id": document_id,
|
||||
"source_id": str(source["id"]),
|
||||
"content": chunk["content"],
|
||||
"section_title": chunk.get("section_title"),
|
||||
"source_code": source_code,
|
||||
"source_name": source["name"],
|
||||
"attribution_text": source["attribution_text"],
|
||||
"license_code": source["license_code"],
|
||||
"attribution_required": source.get("attribution_required", True),
|
||||
"document_type": source.get("document_type", ""),
|
||||
"category": chunk.get("category", ""),
|
||||
"language": source.get("language", "de"),
|
||||
"page_number": chunk.get("page_number")
|
||||
})
|
||||
|
||||
# Index in Qdrant
|
||||
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
|
||||
|
||||
# Update document record
|
||||
await store.update_document_indexed(document_id, len(chunks))
|
||||
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=indexed_count,
|
||||
message=f"Successfully ingested {indexed_count} chunks from document"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
|
||||
async def get_chunk_with_attribution(
|
||||
chunk_id: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get single chunk with full source attribution."""
|
||||
chunk = await store.get_chunk_with_attribution(chunk_id)
|
||||
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
|
||||
|
||||
return DSFAChunkResponse(
|
||||
chunk_id=str(chunk["chunk_id"]),
|
||||
content=chunk.get("content", ""),
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category"),
|
||||
document_id=str(chunk.get("document_id", "")),
|
||||
document_title=chunk.get("document_title"),
|
||||
source_id=str(chunk.get("source_id", "")),
|
||||
source_code=chunk.get("source_code", ""),
|
||||
source_name=chunk.get("source_name", ""),
|
||||
attribution_text=chunk.get("attribution_text", ""),
|
||||
license_code=chunk.get("license_code", ""),
|
||||
license_name=license_info.get("name", chunk.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=chunk.get("attribution_required", True),
|
||||
source_url=chunk.get("source_url"),
|
||||
document_type=chunk.get("document_type")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DSFACorpusStatsResponse)
|
||||
async def get_corpus_stats(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""Get corpus statistics for dashboard."""
|
||||
# Get PostgreSQL stats
|
||||
source_stats = await store.get_source_stats()
|
||||
|
||||
total_docs = 0
|
||||
total_chunks = 0
|
||||
stats_response = []
|
||||
|
||||
for s in source_stats:
|
||||
doc_count = s.get("document_count", 0) or 0
|
||||
chunk_count = s.get("chunk_count", 0) or 0
|
||||
total_docs += doc_count
|
||||
total_chunks += chunk_count
|
||||
|
||||
last_indexed = s.get("last_indexed_at")
|
||||
|
||||
stats_response.append(DSFASourceStatsResponse(
|
||||
source_id=str(s.get("source_id", "")),
|
||||
source_code=s.get("source_code", ""),
|
||||
name=s.get("name", ""),
|
||||
organization=s.get("organization"),
|
||||
license_code=s.get("license_code", ""),
|
||||
document_type=s.get("document_type"),
|
||||
document_count=doc_count,
|
||||
chunk_count=chunk_count,
|
||||
last_indexed_at=last_indexed.isoformat() if last_indexed else None
|
||||
))
|
||||
|
||||
# Get Qdrant stats
|
||||
qdrant_stats = await qdrant.get_stats()
|
||||
|
||||
return DSFACorpusStatsResponse(
|
||||
sources=stats_response,
|
||||
total_sources=len(source_stats),
|
||||
total_documents=total_docs,
|
||||
total_chunks=total_chunks,
|
||||
qdrant_collection=DSFA_COLLECTION,
|
||||
qdrant_points_count=qdrant_stats.get("points_count", 0),
|
||||
qdrant_status=qdrant_stats.get("status", "unknown")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
async def list_licenses():
|
||||
"""List all supported licenses with their terms."""
|
||||
return [
|
||||
LicenseInfo(
|
||||
code=code,
|
||||
name=info.get("name", code),
|
||||
url=info.get("url"),
|
||||
attribution_required=info.get("attribution_required", True),
|
||||
modification_allowed=info.get("modification_allowed", True),
|
||||
commercial_use=info.get("commercial_use", True)
|
||||
)
|
||||
for code, info in LICENSE_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def initialize_dsfa_corpus(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Initialize DSFA corpus.
|
||||
|
||||
- Creates Qdrant collection
|
||||
- Registers all predefined sources
|
||||
"""
|
||||
# Ensure Qdrant collection exists
|
||||
qdrant_ok = await qdrant.ensure_collection()
|
||||
|
||||
# Register all sources
|
||||
registered = 0
|
||||
for source in DSFA_SOURCES:
|
||||
try:
|
||||
await store.register_source(source)
|
||||
registered += 1
|
||||
except Exception as e:
|
||||
print(f"Error registering source {source['source_code']}: {e}")
|
||||
|
||||
return {
|
||||
"qdrant_collection_created": qdrant_ok,
|
||||
"sources_registered": registered,
|
||||
"total_sources": len(DSFA_SOURCES)
|
||||
}
|
||||
# Embedding utilities
|
||||
from dsfa_rag_embedding import (
|
||||
get_embedding,
|
||||
get_embeddings_batch,
|
||||
extract_text_from_url,
|
||||
EMBEDDING_SERVICE_URL,
|
||||
)
|
||||
|
||||
# Routes (router + set_db_pool)
|
||||
from dsfa_rag_routes import (
|
||||
router,
|
||||
set_db_pool,
|
||||
get_store,
|
||||
get_qdrant,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Router
|
||||
"router",
|
||||
"set_db_pool",
|
||||
"get_store",
|
||||
"get_qdrant",
|
||||
# Models
|
||||
"DSFASourceResponse",
|
||||
"DSFAChunkResponse",
|
||||
"DSFASearchResultResponse",
|
||||
"DSFASearchResponse",
|
||||
"DSFASourceStatsResponse",
|
||||
"DSFACorpusStatsResponse",
|
||||
"IngestRequest",
|
||||
"IngestResponse",
|
||||
"LicenseInfo",
|
||||
# Embedding
|
||||
"get_embedding",
|
||||
"get_embeddings_batch",
|
||||
"extract_text_from_url",
|
||||
"EMBEDDING_SERVICE_URL",
|
||||
]
|
||||
|
||||
116
klausur-service/backend/dsfa_rag_embedding.py
Normal file
116
klausur-service/backend/dsfa_rag_embedding.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
DSFA RAG Embedding Service Integration.
|
||||
|
||||
Handles embedding generation, text extraction, and fallback logic.
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import struct
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Embedding service configuration
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||
|
||||
|
||||
async def get_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Get embedding for text using the embedding-service.
|
||||
|
||||
Uses BGE-M3 model which produces 1024-dimensional vectors.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed-single",
|
||||
json={"text": text}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embedding", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service error: {e}")
|
||||
# Fallback to hash-based pseudo-embedding for development
|
||||
return _generate_fallback_embedding(text)
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Get embeddings for multiple texts in batch.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": texts}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service batch error: {e}")
|
||||
# Fallback
|
||||
return [_generate_fallback_embedding(t) for t in texts]
|
||||
|
||||
|
||||
async def extract_text_from_url(url: str) -> str:
|
||||
"""
|
||||
Extract text from a document URL (PDF, HTML, etc.).
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
# First try to use the embedding-service's extract-pdf endpoint
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
|
||||
json={"url": url}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"PDF extraction error for {url}: {e}")
|
||||
# Fallback: try to fetch HTML content directly
|
||||
try:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "html" in content_type:
|
||||
# Simple HTML text extraction
|
||||
html = response.text
|
||||
# Remove scripts and styles
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
# Remove tags
|
||||
text = re.sub(r'<[^>]+>', ' ', html)
|
||||
# Clean whitespace
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
else:
|
||||
return ""
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
|
||||
return ""
|
||||
|
||||
|
||||
def _generate_fallback_embedding(text: str) -> List[float]:
    """
    Generate deterministic pseudo-embedding from text hash.
    Used as fallback when embedding service is unavailable.
    """
    hash_bytes = hashlib.sha256(text.encode()).digest()
    embedding = []
    for i in range(0, min(len(hash_bytes), 128), 4):
        # Unpack as an unsigned 32-bit integer and normalize to [0, 1).
        # Unpacking raw hash bytes as IEEE floats ('f') can yield NaN/inf,
        # which would break downstream similarity calculations.
        val = struct.unpack('I', hash_bytes[i:i+4])[0]
        embedding.append(val / 0xFFFFFFFF)

    # Pad to 1024 dimensions (BGE-M3 vector size) by repeating values
    while len(embedding) < 1024:
        embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])

    return embedding[:1024]
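
A minimal usage sketch for this module (illustrative only; the event loop setup and sample texts are assumptions, not part of this diff):

import asyncio
from dsfa_rag_embedding import get_embedding, get_embeddings_batch

async def demo():
    # Falls back to the deterministic pseudo-embedding when the
    # embedding-service at EMBEDDING_SERVICE_URL is unreachable.
    vec = await get_embedding("Datenschutz-Folgenabschätzung für Schul-IT")
    print(len(vec))  # expected: 1024
    batch = await get_embeddings_batch(["Risikoanalyse", "Schwellwertanalyse"])
    print(len(batch), len(batch[0]))

asyncio.run(demo())
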
137
klausur-service/backend/dsfa_rag_models.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
DSFA RAG Pydantic Models.
|
||||
|
||||
Request/Response models for the DSFA RAG API.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Response Models
|
||||
# =============================================================================
|
||||
|
||||
class DSFASourceResponse(BaseModel):
|
||||
"""Response model for DSFA source."""
|
||||
id: str
|
||||
source_code: str
|
||||
name: str
|
||||
full_name: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
attribution_text: str
|
||||
document_type: Optional[str] = None
|
||||
language: str = "de"
|
||||
|
||||
|
||||
class DSFAChunkResponse(BaseModel):
|
||||
"""Response model for a single chunk with attribution."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
# Document info
|
||||
document_id: str
|
||||
document_title: Optional[str] = None
|
||||
|
||||
# Attribution (always included)
|
||||
source_id: str
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
document_type: Optional[str] = None
|
||||
|
||||
|
||||
class DSFASearchResultResponse(BaseModel):
|
||||
"""Response model for search result."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
# Attribution
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
|
||||
# Metadata
|
||||
document_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
|
||||
|
||||
class DSFASearchResponse(BaseModel):
|
||||
"""Response model for search endpoint."""
|
||||
query: str
|
||||
results: List[DSFASearchResultResponse]
|
||||
total_results: int
|
||||
|
||||
# Aggregated licenses for footer
|
||||
licenses_used: List[str]
|
||||
attribution_notice: str
|
||||
|
||||
|
||||
class DSFASourceStatsResponse(BaseModel):
|
||||
"""Response model for source statistics."""
|
||||
source_id: str
|
||||
source_code: str
|
||||
name: str
|
||||
organization: Optional[str] = None
|
||||
license_code: str
|
||||
document_type: Optional[str] = None
|
||||
document_count: int
|
||||
chunk_count: int
|
||||
last_indexed_at: Optional[str] = None
|
||||
|
||||
|
||||
class DSFACorpusStatsResponse(BaseModel):
|
||||
"""Response model for corpus statistics."""
|
||||
sources: List[DSFASourceStatsResponse]
|
||||
total_sources: int
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
qdrant_collection: str
|
||||
qdrant_points_count: int
|
||||
qdrant_status: str
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
"""Request model for ingestion."""
|
||||
document_url: Optional[str] = None
|
||||
document_text: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
"""Response model for ingestion."""
|
||||
source_code: str
|
||||
document_id: Optional[str] = None
|
||||
chunks_created: int
|
||||
message: str
|
||||
|
||||
|
||||
class LicenseInfo(BaseModel):
|
||||
"""License information."""
|
||||
code: str
|
||||
name: str
|
||||
url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
modification_allowed: bool
|
||||
commercial_use: bool
|
||||
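
A short sketch of how these models are used by clients (the values are hypothetical, shown only to illustrate the schema):

from dsfa_rag_models import IngestRequest

# Either document_url or document_text must be provided; the ingest route
# rejects requests that contain neither.
req = IngestRequest(
    document_url="https://example.org/dsfa-leitfaden.pdf",
    title="DSFA Leitfaden",
)
print(req)
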
461
klausur-service/backend/dsfa_rag_routes.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
DSFA RAG API Route Handlers.
|
||||
|
||||
Endpoint implementations for search, sources, ingestion, stats, and init.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
|
||||
from dsfa_corpus_ingestion import (
|
||||
DSFACorpusStore,
|
||||
DSFAQdrantService,
|
||||
DSFASearchResult,
|
||||
LICENSE_REGISTRY,
|
||||
DSFA_SOURCES,
|
||||
generate_attribution_notice,
|
||||
get_license_label,
|
||||
DSFA_COLLECTION,
|
||||
chunk_document,
|
||||
)
|
||||
|
||||
from dsfa_rag_models import (
|
||||
DSFASourceResponse,
|
||||
DSFAChunkResponse,
|
||||
DSFASearchResultResponse,
|
||||
DSFASearchResponse,
|
||||
DSFASourceStatsResponse,
|
||||
DSFACorpusStatsResponse,
|
||||
IngestRequest,
|
||||
IngestResponse,
|
||||
LicenseInfo,
|
||||
)
|
||||
|
||||
from dsfa_rag_embedding import (
|
||||
get_embedding,
|
||||
get_embeddings_batch,
|
||||
extract_text_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
_db_pool = None
|
||||
|
||||
|
||||
def set_db_pool(pool):
|
||||
"""Set the database pool for API endpoints."""
|
||||
global _db_pool
|
||||
_db_pool = pool
|
||||
|
||||
|
||||
async def get_store() -> DSFACorpusStore:
|
||||
"""Get DSFA corpus store."""
|
||||
if _db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return DSFACorpusStore(_db_pool)
|
||||
|
||||
|
||||
async def get_qdrant() -> DSFAQdrantService:
|
||||
"""Get Qdrant service."""
|
||||
return DSFAQdrantService()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/search", response_model=DSFASearchResponse)
|
||||
async def search_dsfa_corpus(
|
||||
query: str = Query(..., min_length=3, description="Search query"),
|
||||
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
|
||||
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
|
||||
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
|
||||
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
|
||||
include_attribution: bool = Query(True, description="Include attribution in results"),
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Search DSFA corpus with full attribution.
|
||||
|
||||
Returns matching chunks with source/license information for compliance.
|
||||
"""
|
||||
query_embedding = await get_embedding(query)
|
||||
|
||||
raw_results = await qdrant.search(
|
||||
query_embedding=query_embedding,
|
||||
source_codes=source_codes,
|
||||
document_types=document_types,
|
||||
categories=categories,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
results = []
|
||||
licenses_used = set()
|
||||
|
||||
for r in raw_results:
|
||||
license_code = r.get("license_code", "")
|
||||
license_info = LICENSE_REGISTRY.get(license_code, {})
|
||||
|
||||
result = DSFASearchResultResponse(
|
||||
chunk_id=r.get("chunk_id", ""),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
source_code=r.get("source_code", ""),
|
||||
source_name=r.get("source_name", ""),
|
||||
attribution_text=r.get("attribution_text", ""),
|
||||
license_code=license_code,
|
||||
license_name=license_info.get("name", license_code),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=r.get("attribution_required", True),
|
||||
source_url=r.get("source_url"),
|
||||
document_type=r.get("document_type"),
|
||||
category=r.get("category"),
|
||||
section_title=r.get("section_title"),
|
||||
page_number=r.get("page_number")
|
||||
)
|
||||
results.append(result)
|
||||
licenses_used.add(license_code)
|
||||
|
||||
# Generate attribution notice
|
||||
search_results = [
|
||||
DSFASearchResult(
|
||||
chunk_id=r.chunk_id,
|
||||
content=r.content,
|
||||
score=r.score,
|
||||
source_code=r.source_code,
|
||||
source_name=r.source_name,
|
||||
attribution_text=r.attribution_text,
|
||||
license_code=r.license_code,
|
||||
license_url=r.license_url,
|
||||
attribution_required=r.attribution_required,
|
||||
source_url=r.source_url,
|
||||
document_type=r.document_type or "",
|
||||
category=r.category or "",
|
||||
section_title=r.section_title,
|
||||
page_number=r.page_number
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
|
||||
|
||||
return DSFASearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total_results=len(results),
|
||||
licenses_used=list(licenses_used),
|
||||
attribution_notice=attribution_notice
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources", response_model=List[DSFASourceResponse])
|
||||
async def list_dsfa_sources(
|
||||
document_type: Optional[str] = Query(None, description="Filter by document type"),
|
||||
license_code: Optional[str] = Query(None, description="Filter by license"),
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""List all registered DSFA sources with license info."""
|
||||
sources = await store.list_sources()
|
||||
|
||||
result = []
|
||||
for s in sources:
|
||||
if document_type and s.get("document_type") != document_type:
|
||||
continue
|
||||
if license_code and s.get("license_code") != license_code:
|
||||
continue
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
|
||||
|
||||
result.append(DSFASourceResponse(
|
||||
id=str(s["id"]),
|
||||
source_code=s["source_code"],
|
||||
name=s["name"],
|
||||
full_name=s.get("full_name"),
|
||||
organization=s.get("organization"),
|
||||
source_url=s.get("source_url"),
|
||||
license_code=s.get("license_code", ""),
|
||||
license_name=license_info.get("name", s.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=s.get("attribution_required", True),
|
||||
attribution_text=s.get("attribution_text", ""),
|
||||
document_type=s.get("document_type"),
|
||||
language=s.get("language", "de")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sources/available")
|
||||
async def list_available_sources():
|
||||
"""List all available source definitions (from DSFA_SOURCES constant)."""
|
||||
return [
|
||||
{
|
||||
"source_code": s["source_code"],
|
||||
"name": s["name"],
|
||||
"organization": s.get("organization"),
|
||||
"license_code": s["license_code"],
|
||||
"document_type": s.get("document_type")
|
||||
}
|
||||
for s in DSFA_SOURCES
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
|
||||
async def get_dsfa_source(
|
||||
source_code: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get details for a specific source."""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
|
||||
|
||||
return DSFASourceResponse(
|
||||
id=str(source["id"]),
|
||||
source_code=source["source_code"],
|
||||
name=source["name"],
|
||||
full_name=source.get("full_name"),
|
||||
organization=source.get("organization"),
|
||||
source_url=source.get("source_url"),
|
||||
license_code=source.get("license_code", ""),
|
||||
license_name=license_info.get("name", source.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=source.get("attribution_required", True),
|
||||
attribution_text=source.get("attribution_text", ""),
|
||||
document_type=source.get("document_type"),
|
||||
language=source.get("language", "de")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
|
||||
async def ingest_dsfa_source(
|
||||
source_code: str,
|
||||
request: IngestRequest,
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Trigger ingestion for a specific source.
|
||||
|
||||
Can provide document via URL or direct text.
|
||||
"""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
if not request.document_text and not request.document_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either document_text or document_url must be provided"
|
||||
)
|
||||
|
||||
await qdrant.ensure_collection()
|
||||
|
||||
text_content = request.document_text
|
||||
if request.document_url and not text_content:
|
||||
logger.info(f"Extracting text from URL: {request.document_url}")
|
||||
text_content = await extract_text_from_url(request.document_url)
|
||||
if not text_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Could not extract text from URL: {request.document_url}"
|
||||
)
|
||||
|
||||
if not text_content or len(text_content.strip()) < 50:
|
||||
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
|
||||
|
||||
doc_title = request.title or f"Document for {source_code}"
|
||||
document_id = await store.create_document(
|
||||
source_id=str(source["id"]),
|
||||
title=doc_title,
|
||||
file_type="text",
|
||||
metadata={"ingested_via": "api", "source_code": source_code}
|
||||
)
|
||||
|
||||
chunks = chunk_document(text_content, source_code)
|
||||
|
||||
if not chunks:
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=0,
|
||||
message="Document created but no chunks generated"
|
||||
)
|
||||
|
||||
chunk_texts = [chunk["content"] for chunk in chunks]
|
||||
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
|
||||
embeddings = await get_embeddings_batch(chunk_texts)
|
||||
|
||||
chunk_records = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
chunk_id = await store.create_chunk(
|
||||
document_id=document_id,
|
||||
source_id=str(source["id"]),
|
||||
content=chunk["content"],
|
||||
chunk_index=i,
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category")
|
||||
)
|
||||
|
||||
chunk_records.append({
|
||||
"chunk_id": chunk_id,
|
||||
"document_id": document_id,
|
||||
"source_id": str(source["id"]),
|
||||
"content": chunk["content"],
|
||||
"section_title": chunk.get("section_title"),
|
||||
"source_code": source_code,
|
||||
"source_name": source["name"],
|
||||
"attribution_text": source["attribution_text"],
|
||||
"license_code": source["license_code"],
|
||||
"attribution_required": source.get("attribution_required", True),
|
||||
"document_type": source.get("document_type", ""),
|
||||
"category": chunk.get("category", ""),
|
||||
"language": source.get("language", "de"),
|
||||
"page_number": chunk.get("page_number")
|
||||
})
|
||||
|
||||
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
|
||||
await store.update_document_indexed(document_id, len(chunks))
|
||||
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=indexed_count,
|
||||
message=f"Successfully ingested {indexed_count} chunks from document"
|
||||
)
|
||||
|
||||
|
||||
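
# Example (illustration only, not part of this module): a typical ingest call.
# The source code "BFDI_DSFA" is a placeholder; use a code returned by
# GET /sources/available.
#
#   POST /api/v1/dsfa-rag/sources/BFDI_DSFA/ingest
#   {"document_url": "https://example.org/leitfaden.pdf", "title": "DSFA Leitfaden"}
#
# The handler extracts text, chunks it, embeds every chunk in batch, and
# indexes the chunks in Qdrant before reporting chunks_created.
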
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
|
||||
async def get_chunk_with_attribution(
|
||||
chunk_id: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get single chunk with full source attribution."""
|
||||
chunk = await store.get_chunk_with_attribution(chunk_id)
|
||||
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
|
||||
|
||||
return DSFAChunkResponse(
|
||||
chunk_id=str(chunk["chunk_id"]),
|
||||
content=chunk.get("content", ""),
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category"),
|
||||
document_id=str(chunk.get("document_id", "")),
|
||||
document_title=chunk.get("document_title"),
|
||||
source_id=str(chunk.get("source_id", "")),
|
||||
source_code=chunk.get("source_code", ""),
|
||||
source_name=chunk.get("source_name", ""),
|
||||
attribution_text=chunk.get("attribution_text", ""),
|
||||
license_code=chunk.get("license_code", ""),
|
||||
license_name=license_info.get("name", chunk.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=chunk.get("attribution_required", True),
|
||||
source_url=chunk.get("source_url"),
|
||||
document_type=chunk.get("document_type")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DSFACorpusStatsResponse)
|
||||
async def get_corpus_stats(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""Get corpus statistics for dashboard."""
|
||||
source_stats = await store.get_source_stats()
|
||||
|
||||
total_docs = 0
|
||||
total_chunks = 0
|
||||
stats_response = []
|
||||
|
||||
for s in source_stats:
|
||||
doc_count = s.get("document_count", 0) or 0
|
||||
chunk_count = s.get("chunk_count", 0) or 0
|
||||
total_docs += doc_count
|
||||
total_chunks += chunk_count
|
||||
|
||||
last_indexed = s.get("last_indexed_at")
|
||||
|
||||
stats_response.append(DSFASourceStatsResponse(
|
||||
source_id=str(s.get("source_id", "")),
|
||||
source_code=s.get("source_code", ""),
|
||||
name=s.get("name", ""),
|
||||
organization=s.get("organization"),
|
||||
license_code=s.get("license_code", ""),
|
||||
document_type=s.get("document_type"),
|
||||
document_count=doc_count,
|
||||
chunk_count=chunk_count,
|
||||
last_indexed_at=last_indexed.isoformat() if last_indexed else None
|
||||
))
|
||||
|
||||
qdrant_stats = await qdrant.get_stats()
|
||||
|
||||
return DSFACorpusStatsResponse(
|
||||
sources=stats_response,
|
||||
total_sources=len(source_stats),
|
||||
total_documents=total_docs,
|
||||
total_chunks=total_chunks,
|
||||
qdrant_collection=DSFA_COLLECTION,
|
||||
qdrant_points_count=qdrant_stats.get("points_count", 0),
|
||||
qdrant_status=qdrant_stats.get("status", "unknown")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
async def list_licenses():
|
||||
"""List all supported licenses with their terms."""
|
||||
return [
|
||||
LicenseInfo(
|
||||
code=code,
|
||||
name=info.get("name", code),
|
||||
url=info.get("url"),
|
||||
attribution_required=info.get("attribution_required", True),
|
||||
modification_allowed=info.get("modification_allowed", True),
|
||||
commercial_use=info.get("commercial_use", True)
|
||||
)
|
||||
for code, info in LICENSE_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def initialize_dsfa_corpus(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Initialize DSFA corpus.
|
||||
|
||||
- Creates Qdrant collection
|
||||
- Registers all predefined sources
|
||||
"""
|
||||
qdrant_ok = await qdrant.ensure_collection()
|
||||
|
||||
registered = 0
|
||||
for source in DSFA_SOURCES:
|
||||
try:
|
||||
await store.register_source(source)
|
||||
registered += 1
|
||||
except Exception as e:
|
||||
print(f"Error registering source {source['source_code']}: {e}")
|
||||
|
||||
return {
|
||||
"qdrant_collection_created": qdrant_ok,
|
||||
"sources_registered": registered,
|
||||
"total_sources": len(DSFA_SOURCES)
|
||||
}
|
||||
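
A client-side sketch of the typical call sequence against these routes (the host/port and sample query are assumptions for illustration):

import asyncio
import httpx

BASE = "http://localhost:8000/api/v1/dsfa-rag"

async def demo():
    async with httpx.AsyncClient(timeout=60.0) as client:
        # One-time setup: create the Qdrant collection and register all sources.
        init = await client.post(f"{BASE}/init")
        print(init.json())

        # Semantic search with license attribution in the response.
        resp = await client.get(
            f"{BASE}/search",
            params={"query": "Schwellwertanalyse DSFA", "limit": 5},
        )
        data = resp.json()
        for hit in data["results"]:
            print(hit["score"], hit["source_code"], hit["attribution_text"])
        print(data["attribution_notice"])

asyncio.run(demo())
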
@@ -1,31 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Full Compliance Pipeline for Legal Corpus.
|
||||
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
|
||||
|
||||
This script runs the complete pipeline:
|
||||
1. Re-ingest all legal documents with improved chunking
|
||||
2. Extract requirements/checkpoints from chunks
|
||||
3. Generate controls using AI
|
||||
4. Define remediation measures
|
||||
5. Update statistics
|
||||
Split into submodules:
|
||||
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
|
||||
- compliance_extraction.py — Pattern extraction & control/measure generation
|
||||
- compliance_pipeline.py — Pipeline phases & orchestrator
|
||||
|
||||
Run on Mac Mini:
|
||||
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
|
||||
|
||||
Checkpoints are saved to /tmp/pipeline_checkpoints.json and can be viewed in admin-v2.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
import re
|
||||
import hashlib
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -36,671 +24,25 @@ logging.basicConfig(
|
||||
logging.FileHandler('/tmp/compliance_pipeline.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import checkpoint manager
|
||||
try:
|
||||
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
|
||||
except ImportError:
|
||||
logger.warning("Checkpoint manager not available, running without checkpoints")
|
||||
CheckpointManager = None
|
||||
EXPECTED_VALUES = {}
|
||||
ValidationStatus = None
|
||||
|
||||
# Set environment variables for Docker network
|
||||
# Support both QDRANT_URL and QDRANT_HOST
|
||||
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
|
||||
os.environ["QDRANT_HOST"] = "qdrant"
|
||||
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
|
||||
# Try to import from klausur-service
|
||||
try:
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||
except ImportError:
|
||||
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
"""A requirement/checkpoint extracted from legal text."""
|
||||
id: str
|
||||
regulation_code: str
|
||||
regulation_name: str
|
||||
article: Optional[str]
|
||||
title: str
|
||||
description: str
|
||||
original_text: str
|
||||
chunk_id: str
|
||||
source_url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Control:
|
||||
"""A control derived from checkpoints."""
|
||||
id: str
|
||||
domain: str
|
||||
title: str
|
||||
description: str
|
||||
checkpoints: List[str] # List of checkpoint IDs
|
||||
pass_criteria: str
|
||||
implementation_guidance: str
|
||||
is_automated: bool
|
||||
automation_tool: Optional[str]
|
||||
priority: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Measure:
|
||||
"""A remediation measure for a control."""
|
||||
id: str
|
||||
control_id: str
|
||||
title: str
|
||||
description: str
|
||||
responsible: str
|
||||
deadline_days: int
|
||||
status: str
|
||||
|
||||
|
||||
class CompliancePipeline:
|
||||
"""Handles the full compliance pipeline."""
|
||||
|
||||
def __init__(self):
|
||||
# Support both QDRANT_URL and QDRANT_HOST/PORT
|
||||
qdrant_url = os.getenv("QDRANT_URL", "")
|
||||
if qdrant_url:
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(qdrant_url)
|
||||
qdrant_host = parsed.hostname or "qdrant"
|
||||
qdrant_port = parsed.port or 6333
|
||||
else:
|
||||
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
|
||||
qdrant_port = 6333
|
||||
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
|
||||
self.checkpoints: List[Checkpoint] = []
|
||||
self.controls: List[Control] = []
|
||||
self.measures: List[Measure] = []
|
||||
self.stats = {
|
||||
"chunks_processed": 0,
|
||||
"checkpoints_extracted": 0,
|
||||
"controls_created": 0,
|
||||
"measures_defined": 0,
|
||||
"by_regulation": {},
|
||||
"by_domain": {},
|
||||
}
|
||||
# Initialize checkpoint manager
|
||||
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
|
||||
|
||||
def extract_checkpoints_from_chunk(self, chunk_text: str, payload: Dict) -> List[Checkpoint]:
|
||||
"""
|
||||
Extract checkpoints/requirements from a chunk of text.
|
||||
|
||||
Uses pattern matching to find requirement-like statements.
|
||||
"""
|
||||
checkpoints = []
|
||||
regulation_code = payload.get("regulation_code", "UNKNOWN")
|
||||
regulation_name = payload.get("regulation_name", "Unknown")
|
||||
source_url = payload.get("source_url", "")
|
||||
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
|
||||
|
||||
# Patterns for different requirement types
|
||||
patterns = [
|
||||
# BSI-TR patterns
|
||||
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
|
||||
# Article patterns (GDPR, AI Act, etc.)
|
||||
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'),
|
||||
# Numbered requirements
|
||||
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
|
||||
# "Der Verantwortliche muss" patterns
|
||||
(r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
|
||||
# "Es ist erforderlich" patterns
|
||||
(r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
|
||||
]
|
||||
|
||||
for pattern, pattern_type in patterns:
|
||||
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
|
||||
for match in matches:
|
||||
if pattern_type == 'bsi_requirement':
|
||||
req_id = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
title = req_id
|
||||
elif pattern_type == 'article':
|
||||
article_num = match.group(1)
|
||||
paragraph = match.group(2) or ""
|
||||
title_text = match.group(3).strip()
|
||||
req_id = f"{regulation_code}-Art{article_num}"
|
||||
if paragraph:
|
||||
req_id += f"-{paragraph}"
|
||||
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
|
||||
description = title_text
|
||||
elif pattern_type == 'numbered':
|
||||
num = match.group(1)
|
||||
description = match.group(2).strip()
|
||||
req_id = f"{regulation_code}-{num}"
|
||||
title = f"Anforderung {num}"
|
||||
else:
|
||||
# Generic requirement
|
||||
description = match.group(0).strip()
|
||||
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
|
||||
title = description[:50] + "..." if len(description) > 50 else description
|
||||
|
||||
# Skip very short matches
|
||||
if len(description) < 20:
|
||||
continue
|
||||
|
||||
checkpoint = Checkpoint(
|
||||
id=req_id,
|
||||
regulation_code=regulation_code,
|
||||
regulation_name=regulation_name,
|
||||
article=title if 'Art' in title else None,
|
||||
title=title,
|
||||
description=description[:500],
|
||||
original_text=description,
|
||||
chunk_id=chunk_id,
|
||||
source_url=source_url
|
||||
)
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
return checkpoints
|
||||
|
||||
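
# Worked example (illustration, not part of the pipeline): applying the
# 'article' pattern to the line
#   "Artikel 35 Abs. 1 - Datenschutz-Folgenabschätzung ist durchzuführen"
# with regulation_code "DSGVO" yields req_id "DSGVO-Art35-1",
# title "Art. 35 Abs. 1" and the trailing text as the description.
# Matches whose description is shorter than 20 characters are skipped.
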
def generate_control_for_checkpoints(self, checkpoints: List[Checkpoint]) -> Optional[Control]:
|
||||
"""
|
||||
Generate a control that covers the given checkpoints.
|
||||
|
||||
This is a simplified version - in production this would use the AI assistant.
|
||||
"""
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Group by regulation
|
||||
regulation = checkpoints[0].regulation_code
|
||||
|
||||
# Determine domain based on content
|
||||
all_text = " ".join([cp.description for cp in checkpoints]).lower()
|
||||
|
||||
domain = "gov" # Default
|
||||
if any(kw in all_text for kw in ["verschlüssel", "krypto", "encrypt", "hash"]):
|
||||
domain = "crypto"
|
||||
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
|
||||
domain = "iam"
|
||||
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
|
||||
domain = "priv"
|
||||
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
|
||||
domain = "sdlc"
|
||||
elif any(kw in all_text for kw in ["überwach", "monitor", "log", "audit"]):
|
||||
domain = "aud"
|
||||
elif any(kw in all_text for kw in ["ki", "künstlich", "ai", "machine learning", "model"]):
|
||||
domain = "ai"
|
||||
elif any(kw in all_text for kw in ["betrieb", "operation", "verfügbar", "backup"]):
|
||||
domain = "ops"
|
||||
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
|
||||
domain = "cra"
|
||||
|
||||
# Generate control ID
|
||||
domain_counts = self.stats.get("by_domain", {})
|
||||
domain_count = domain_counts.get(domain, 0) + 1
|
||||
control_id = f"{domain.upper()}-{domain_count:03d}"
|
||||
|
||||
# Create title from first checkpoint
|
||||
title = checkpoints[0].title
|
||||
if len(title) > 100:
|
||||
title = title[:97] + "..."
|
||||
|
||||
# Create description
|
||||
description = f"Control für {regulation}: " + checkpoints[0].description[:200]
|
||||
|
||||
# Pass criteria
|
||||
pass_criteria = f"Alle {len(checkpoints)} zugehörigen Anforderungen sind erfüllt und dokumentiert."
|
||||
|
||||
# Implementation guidance
|
||||
guidance = f"Implementiere Maßnahmen zur Erfüllung der Anforderungen aus {regulation}. "
|
||||
guidance += f"Dokumentiere die Umsetzung und führe regelmäßige Reviews durch."
|
||||
|
||||
# Determine if automated
|
||||
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
|
||||
|
||||
control = Control(
|
||||
id=control_id,
|
||||
domain=domain,
|
||||
title=title,
|
||||
description=description,
|
||||
checkpoints=[cp.id for cp in checkpoints],
|
||||
pass_criteria=pass_criteria,
|
||||
implementation_guidance=guidance,
|
||||
is_automated=is_automated,
|
||||
automation_tool="CI/CD Pipeline" if is_automated else None,
|
||||
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
def generate_measure_for_control(self, control: Control) -> Measure:
|
||||
"""Generate a remediation measure for a control."""
|
||||
measure_id = f"M-{control.id}"
|
||||
|
||||
# Determine deadline based on priority
|
||||
deadline_days = {
|
||||
"critical": 30,
|
||||
"high": 60,
|
||||
"medium": 90,
|
||||
"low": 180
|
||||
}.get(control.priority, 90)
|
||||
|
||||
# Determine responsible team
|
||||
responsible = {
|
||||
"priv": "Datenschutzbeauftragter",
|
||||
"iam": "IT-Security Team",
|
||||
"sdlc": "Entwicklungsteam",
|
||||
"crypto": "IT-Security Team",
|
||||
"ops": "Operations Team",
|
||||
"aud": "Compliance Team",
|
||||
"ai": "AI/ML Team",
|
||||
"cra": "IT-Security Team",
|
||||
"gov": "Management"
|
||||
}.get(control.domain, "Compliance Team")
|
||||
|
||||
measure = Measure(
|
||||
id=measure_id,
|
||||
control_id=control.id,
|
||||
title=f"Umsetzung: {control.title[:50]}",
|
||||
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
|
||||
responsible=responsible,
|
||||
deadline_days=deadline_days,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return measure
|
||||
|
||||
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
|
||||
"""Phase 1: Ingest documents (incremental - only missing ones)."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
try:
|
||||
# Check existing chunks per regulation
|
||||
existing_chunks = {}
|
||||
try:
|
||||
for regulation in REGULATIONS:
|
||||
count_result = self.qdrant.count(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
|
||||
)
|
||||
)
|
||||
existing_chunks[regulation.code] = count_result.count
|
||||
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check existing chunks: {e}")
|
||||
# Collection might not exist, that's OK
|
||||
|
||||
# Determine which regulations need ingestion
|
||||
regulations_to_ingest = []
|
||||
for regulation in REGULATIONS:
|
||||
existing = existing_chunks.get(regulation.code, 0)
|
||||
if force_reindex or existing == 0:
|
||||
regulations_to_ingest.append(regulation)
|
||||
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
|
||||
else:
|
||||
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
|
||||
self.stats["by_regulation"][regulation.code] = existing
|
||||
|
||||
if not regulations_to_ingest:
|
||||
logger.info("All regulations already indexed. Skipping ingestion phase.")
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("skipped", True)
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
return total_chunks
|
||||
|
||||
# Ingest only missing regulations
|
||||
total_chunks = sum(existing_chunks.values())
|
||||
for i, regulation in enumerate(regulations_to_ingest, 1):
|
||||
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
|
||||
try:
|
||||
count = await ingestion.ingest_regulation(regulation)
|
||||
total_chunks += count
|
||||
self.stats["by_regulation"][regulation.code] = count
|
||||
logger.info(f" -> {count} chunks")
|
||||
|
||||
# Add metric for this regulation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" -> FAILED: {e}")
|
||||
self.stats["by_regulation"][regulation.code] = 0
|
||||
|
||||
self.stats["chunks_processed"] = total_chunks
|
||||
logger.info(f"\nTotal chunks in collection: {total_chunks}")
|
||||
|
||||
# Validate ingestion results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
|
||||
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
|
||||
|
||||
# Validate total chunks
|
||||
expected = EXPECTED_VALUES.get("ingestion", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_chunks",
|
||||
expected=expected.get("total_chunks", 8000),
|
||||
actual=total_chunks,
|
||||
min_value=expected.get("min_chunks", 7000)
|
||||
)
|
||||
|
||||
# Validate key regulations
|
||||
reg_expected = expected.get("regulations", {})
|
||||
for reg_code, reg_exp in reg_expected.items():
|
||||
actual = self.stats["by_regulation"].get(reg_code, 0)
|
||||
self.checkpoint_mgr.validate(
|
||||
f"chunks_{reg_code}",
|
||||
expected=reg_exp.get("expected", 0),
|
||||
actual=actual,
|
||||
min_value=reg_exp.get("min", 0)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return total_chunks
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
await ingestion.close()
|
||||
|
||||
async def run_extraction_phase(self) -> int:
|
||||
"""Phase 2: Extract checkpoints from chunks."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
|
||||
|
||||
try:
|
||||
# Scroll through all chunks
|
||||
offset = None
|
||||
total_checkpoints = 0
|
||||
|
||||
while True:
|
||||
result = self.qdrant.scroll(
|
||||
collection_name=LEGAL_CORPUS_COLLECTION,
|
||||
limit=100,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
points, next_offset = result
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
text = payload.get("text", "")
|
||||
|
||||
checkpoints = self.extract_checkpoints_from_chunk(text, payload)
|
||||
self.checkpoints.extend(checkpoints)
|
||||
total_checkpoints += len(checkpoints)
|
||||
|
||||
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
self.stats["checkpoints_extracted"] = len(self.checkpoints)
|
||||
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
|
||||
|
||||
# Log per regulation
|
||||
by_reg = {}
|
||||
for cp in self.checkpoints:
|
||||
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
|
||||
for reg, count in sorted(by_reg.items()):
|
||||
logger.info(f" {reg}: {count} checkpoints")
|
||||
|
||||
# Validate extraction results
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
|
||||
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
|
||||
|
||||
expected = EXPECTED_VALUES.get("extraction", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_checkpoints",
|
||||
expected=expected.get("total_checkpoints", 3500),
|
||||
actual=len(self.checkpoints),
|
||||
min_value=expected.get("min_checkpoints", 3000)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.checkpoints)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_control_generation_phase(self) -> int:
|
||||
"""Phase 3: Generate controls from checkpoints."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 3: CONTROL GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
|
||||
|
||||
try:
|
||||
# Group checkpoints by regulation
|
||||
by_regulation: Dict[str, List[Checkpoint]] = {}
|
||||
for cp in self.checkpoints:
|
||||
reg = cp.regulation_code
|
||||
if reg not in by_regulation:
|
||||
by_regulation[reg] = []
|
||||
by_regulation[reg].append(cp)
|
||||
|
||||
# Generate controls per regulation (group every 3-5 checkpoints)
|
||||
for regulation, checkpoints in by_regulation.items():
|
||||
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
|
||||
|
||||
# Group checkpoints into batches of 3-5
|
||||
batch_size = 4
|
||||
for i in range(0, len(checkpoints), batch_size):
|
||||
batch = checkpoints[i:i + batch_size]
|
||||
control = self.generate_control_for_checkpoints(batch)
|
||||
|
||||
if control:
|
||||
self.controls.append(control)
|
||||
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
|
||||
|
||||
self.stats["controls_created"] = len(self.controls)
|
||||
logger.info(f"\nTotal controls created: {len(self.controls)}")
|
||||
|
||||
# Log per domain
|
||||
for domain, count in sorted(self.stats["by_domain"].items()):
|
||||
logger.info(f" {domain}: {count} controls")
|
||||
|
||||
# Validate control generation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
|
||||
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
|
||||
|
||||
expected = EXPECTED_VALUES.get("controls", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_controls",
|
||||
expected=expected.get("total_controls", 900),
|
||||
actual=len(self.controls),
|
||||
min_value=expected.get("min_controls", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.controls)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
async def run_measure_generation_phase(self) -> int:
|
||||
"""Phase 4: Generate measures for controls."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PHASE 4: MEASURE GENERATION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
|
||||
|
||||
try:
|
||||
for control in self.controls:
|
||||
measure = self.generate_measure_for_control(control)
|
||||
self.measures.append(measure)
|
||||
|
||||
self.stats["measures_defined"] = len(self.measures)
|
||||
logger.info(f"\nTotal measures defined: {len(self.measures)}")
|
||||
|
||||
# Validate measure generation
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
|
||||
|
||||
expected = EXPECTED_VALUES.get("measures", {})
|
||||
self.checkpoint_mgr.validate(
|
||||
"total_measures",
|
||||
expected=expected.get("total_measures", 900),
|
||||
actual=len(self.measures),
|
||||
min_value=expected.get("min_measures", 800)
|
||||
)
|
||||
|
||||
self.checkpoint_mgr.complete_checkpoint(success=True)
|
||||
|
||||
return len(self.measures)
|
||||
|
||||
except Exception as e:
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.fail_checkpoint(str(e))
|
||||
raise
|
||||
|
||||
def save_results(self, output_dir: str = "/tmp/compliance_output"):
|
||||
"""Save results to JSON files."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("SAVING RESULTS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save checkpoints
|
||||
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
|
||||
with open(checkpoints_file, "w") as f:
|
||||
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
|
||||
|
||||
# Save controls
|
||||
controls_file = os.path.join(output_dir, "controls.json")
|
||||
with open(controls_file, "w") as f:
|
||||
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
|
||||
|
||||
# Save measures
|
||||
measures_file = os.path.join(output_dir, "measures.json")
|
||||
with open(measures_file, "w") as f:
|
||||
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
|
||||
|
||||
# Save statistics
|
||||
stats_file = os.path.join(output_dir, "statistics.json")
|
||||
self.stats["generated_at"] = datetime.now().isoformat()
|
||||
with open(stats_file, "w") as f:
|
||||
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved statistics to {stats_file}")
|
||||
|
||||
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
|
||||
"""Run the complete pipeline.
|
||||
|
||||
Args:
|
||||
force_reindex: If True, re-ingest all documents even if they exist
|
||||
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
|
||||
logger.info(f"Started at: {datetime.now().isoformat()}")
|
||||
logger.info(f"Force reindex: {force_reindex}")
|
||||
logger.info(f"Skip ingestion: {skip_ingestion}")
|
||||
if self.checkpoint_mgr:
|
||||
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Phase 1: Ingestion (skip if requested or run incrementally)
|
||||
if skip_ingestion:
|
||||
logger.info("Skipping ingestion phase as requested...")
|
||||
# Still get the chunk count
|
||||
try:
|
||||
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
|
||||
self.stats["chunks_processed"] = collection_info.points_count
|
||||
except Exception:
|
||||
self.stats["chunks_processed"] = 0
|
||||
else:
|
||||
await self.run_ingestion_phase(force_reindex=force_reindex)
|
||||
|
||||
# Phase 2: Extraction
|
||||
await self.run_extraction_phase()
|
||||
|
||||
# Phase 3: Control Generation
|
||||
await self.run_control_generation_phase()
|
||||
|
||||
# Phase 4: Measure Generation
|
||||
await self.run_measure_generation_phase()
|
||||
|
||||
# Save results
|
||||
self.save_results()
|
||||
|
||||
# Final summary
|
||||
elapsed = time.time() - start_time
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("PIPELINE COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Duration: {elapsed:.1f} seconds")
|
||||
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
|
||||
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
|
||||
logger.info(f"Controls created: {self.stats['controls_created']}")
|
||||
logger.info(f"Measures defined: {self.stats['measures_defined']}")
|
||||
logger.info(f"\nResults saved to: /tmp/compliance_output/")
|
||||
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Complete pipeline checkpoint
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.complete_pipeline({
|
||||
"duration_seconds": elapsed,
|
||||
"chunks_processed": self.stats['chunks_processed'],
|
||||
"checkpoints_extracted": self.stats['checkpoints_extracted'],
|
||||
"controls_created": self.stats['controls_created'],
|
||||
"measures_defined": self.stats['measures_defined'],
|
||||
"by_regulation": self.stats['by_regulation'],
|
||||
"by_domain": self.stats['by_domain'],
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pipeline failed: {e}")
|
||||
if self.checkpoint_mgr:
|
||||
self.checkpoint_mgr.state.status = "failed"
|
||||
self.checkpoint_mgr._save()
|
||||
raise
|
||||
# Re-export all public symbols
|
||||
from compliance_models import Checkpoint, Control, Measure
|
||||
from compliance_extraction import (
|
||||
extract_checkpoints_from_chunk,
|
||||
generate_control_for_checkpoints,
|
||||
generate_measure_for_control,
|
||||
)
|
||||
from compliance_pipeline import CompliancePipeline
|
||||
|
||||
__all__ = [
|
||||
"Checkpoint",
|
||||
"Control",
|
||||
"Measure",
|
||||
"extract_checkpoints_from_chunk",
|
||||
"generate_control_for_checkpoints",
|
||||
"generate_measure_for_control",
|
||||
"CompliancePipeline",
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
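
A sketch of how downstream code continues to work against the barrel module after the split (import path and arguments taken from the re-exports and signatures above; the call site itself is hypothetical):

import asyncio
from full_compliance_pipeline import CompliancePipeline, Checkpoint

async def run():
    pipeline = CompliancePipeline()
    # Reuse already-indexed chunks instead of re-ingesting everything.
    await pipeline.run_full_pipeline(skip_ingestion=True)

asyncio.run(run())
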
@@ -1,767 +1,35 @@
|
||||
"""
|
||||
GitHub Repository Crawler for Legal Templates.
|
||||
GitHub Repository Crawler — Barrel Re-export
|
||||
|
||||
Crawls GitHub and GitLab repositories to extract legal template documents
|
||||
(Markdown, HTML, JSON, etc.) for ingestion into the RAG system.
|
||||
Split into:
|
||||
- github_crawler_parsers.py — ExtractedDocument, MarkdownParser, HTMLParser, JSONParser
|
||||
- github_crawler_core.py — GitHubCrawler, RepositoryDownloader, crawl_source
|
||||
|
||||
Features:
|
||||
- Clone repositories via Git or download as ZIP
|
||||
- Parse Markdown, HTML, JSON, and plain text files
|
||||
- Extract structured content with metadata
|
||||
- Track git commit hashes for reproducibility
|
||||
- Handle rate limiting and errors gracefully
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import LicenseType, SourceConfig, LICENSES
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "") # Optional for higher rate limits
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0 # Delay between requests to avoid rate limiting
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
# Extract title from first heading or filename
|
||||
title = cls._extract_title(content, filename)
|
||||
|
||||
# Extract sections
|
||||
sections = cls._extract_sections(content)
|
||||
|
||||
# Find placeholders
|
||||
placeholders = cls._find_placeholders(content)
|
||||
|
||||
# Detect language
|
||||
language = cls._detect_language(content)
|
||||
|
||||
# Clean content for indexing
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="", # Will be set by caller
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
# Look for first h1 heading
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
# Look for YAML frontmatter title
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
# Fall back to filename
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
# Convert kebab-case or snake_case to title case
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
# Save previous section if it has content
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
# Start new section
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
# Add final section
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
# Look for German-specific words
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
# Remove YAML frontmatter
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
|
||||
# Remove HTML comments
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
|
||||
# Remove inline HTML tags but keep content
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
|
||||
# Convert markdown formatting
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content) # Bold
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content) # Italic
|
||||
content = re.sub(r'`(.+?)`', r'\1', content) # Inline code
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content) # Strikethrough
|
||||
|
||||
# Remove link syntax but keep text
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
|
||||
# Remove image syntax
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
|
||||
# Clean up whitespace
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
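
# Example (illustration only, not part of this module): parsing a small
# template with MarkdownParser yields title, placeholders, and language.
#
#   doc = MarkdownParser.parse(
#       "# Datenschutzerklärung\n\nVerantwortlicher ist die [COMPANY_NAME].",
#       filename="privacy-policy.md",
#   )
#   doc.title         -> "Datenschutzerklärung"
#   doc.placeholders  -> ["[COMPANY_NAME]"]
#   doc.language      -> "de"  (at least three German indicator words found)
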
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
# Extract title
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
# Convert to text
|
||||
text = cls._html_to_text(content)
|
||||
|
||||
# Find placeholders
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
# Detect language
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
# Remove script and style tags
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
|
||||
# Remove comments
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
# Replace common entities
html = html.replace('&nbsp;', ' ')
html = html.replace('&amp;', '&')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&#39;', "'")
|
||||
|
||||
# Add line breaks for block elements
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
# Remove remaining tags
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
# Clean whitespace
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Handle different JSON structures
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
# Look for text content in common keys
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
# Try to find main text content
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
# Check for nested structures (like webflorist format)
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
# Treat long strings as content
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
# Found text content
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
# Extract metadata
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
# Check exclude patterns
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include patterns
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
# Skip large files
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
# Use raw content URL for simplicity
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
# Get default branch and latest commit
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
# List files matching patterns
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
# Parse based on file type
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
# Plain text file
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Save to temp file
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
# Extract ZIP
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
# The extracted directory is usually named repo-branch
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
# Check exclude patterns
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Skip large files
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
# Test with github-site-policy
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
# Parsers
|
||||
from github_crawler_parsers import ( # noqa: F401
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
# Crawler and downloader
|
||||
from github_crawler_core import ( # noqa: F401
|
||||
GITHUB_API_URL,
|
||||
GITLAB_API_URL,
|
||||
GITHUB_TOKEN,
|
||||
MAX_FILE_SIZE,
|
||||
REQUEST_TIMEOUT,
|
||||
RATE_LIMIT_DELAY,
|
||||
GitHubCrawler,
|
||||
RepositoryDownloader,
|
||||
crawl_source,
|
||||
main,
|
||||
)
|
||||
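# Backward-compatibility sketch: existing call sites that do
#     from github_crawler import GitHubCrawler, MarkdownParser, crawl_source
# keep working unchanged, because those names are re-exported above.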
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
411
klausur-service/backend/github_crawler_core.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
GitHub Crawler - Core Crawler and Downloader
|
||||
|
||||
GitHubCrawler for API-based repository crawling and RepositoryDownloader
|
||||
for ZIP-based local extraction.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from template_sources import SourceConfig
|
||||
from github_crawler_parsers import (
|
||||
ExtractedDocument,
|
||||
MarkdownParser,
|
||||
HTMLParser,
|
||||
JSONParser,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
GITLAB_API_URL = "https://gitlab.com/api/v4"
|
||||
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN", "")
|
||||
MAX_FILE_SIZE = 1024 * 1024 # 1 MB max file size
|
||||
REQUEST_TIMEOUT = 60.0
|
||||
RATE_LIMIT_DELAY = 1.0
|
||||
|
||||
|
||||
class GitHubCrawler:
|
||||
"""Crawl GitHub repositories for legal templates."""
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token or GITHUB_TOKEN
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "LegalTemplatesCrawler/1.0",
|
||||
}
|
||||
if self.token:
|
||||
self.headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=self.headers,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
def _parse_repo_url(self, url: str) -> Tuple[str, str, str]:
|
||||
"""Parse repository URL into owner, repo, and host."""
|
||||
parsed = urlparse(url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"Invalid repository URL: {url}")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
if 'gitlab' in parsed.netloc:
|
||||
host = 'gitlab'
|
||||
else:
|
||||
host = 'github'
|
||||
|
||||
return owner, repo, host
|
||||
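# Example (illustrative):
#     _parse_repo_url("https://github.com/octocat/Hello-World.git")
#     -> ("octocat", "Hello-World", "github")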
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch of a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("default_branch", "main")
|
||||
|
||||
async def get_latest_commit(self, owner: str, repo: str, branch: str = "main") -> str:
|
||||
"""Get the latest commit SHA for a branch."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/commits/{branch}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("sha", "")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str = "",
|
||||
branch: str = "main",
|
||||
patterns: Optional[List[str]] = None,
exclude_patterns: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files in a repository matching the given patterns."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
patterns = patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = exclude_patterns or []
|
||||
|
||||
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("tree", []):
|
||||
if item["type"] != "blob":
|
||||
continue
|
||||
|
||||
file_path = item["path"]
|
||||
|
||||
excluded = any(fnmatch(file_path, pattern) for pattern in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
matched = any(fnmatch(file_path, pattern) for pattern in patterns)
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
if item.get("size", 0) > MAX_FILE_SIZE:
|
||||
logger.warning(f"Skipping large file: {file_path} ({item['size']} bytes)")
|
||||
continue
|
||||
|
||||
files.append({
|
||||
"path": file_path,
|
||||
"sha": item["sha"],
|
||||
"size": item.get("size", 0),
|
||||
"url": item.get("url", ""),
|
||||
})
|
||||
|
||||
return files
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, branch: str = "main") -> str:
|
||||
"""Get the content of a file from a repository."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Crawler not initialized. Use 'async with' context.")
|
||||
|
||||
url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
response = await self.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def crawl_repository(
|
||||
self,
|
||||
source: SourceConfig,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a repository and yield extracted documents."""
|
||||
if not source.repo_url:
|
||||
logger.warning(f"No repo URL for source: {source.name}")
|
||||
return
|
||||
|
||||
try:
|
||||
owner, repo, host = self._parse_repo_url(source.repo_url)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse repo URL for {source.name}: {e}")
|
||||
return
|
||||
|
||||
if host == "gitlab":
|
||||
logger.info(f"GitLab repos not yet supported: {source.name}")
|
||||
return
|
||||
|
||||
logger.info(f"Crawling repository: {owner}/{repo}")
|
||||
|
||||
try:
|
||||
branch = await self.get_default_branch(owner, repo)
|
||||
commit_sha = await self.get_latest_commit(owner, repo, branch)
|
||||
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
files = await self.list_files(
|
||||
owner, repo,
|
||||
branch=branch,
|
||||
patterns=source.file_patterns,
|
||||
exclude_patterns=source.exclude_patterns,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(files)} matching files in {source.name}")
|
||||
|
||||
for file_info in files:
|
||||
await asyncio.sleep(RATE_LIMIT_DELAY)
|
||||
|
||||
try:
|
||||
content = await self.get_file_content(
|
||||
owner, repo, file_info["path"], branch
|
||||
)
|
||||
|
||||
file_path = file_info["path"]
|
||||
source_url = f"https://github.com/{owner}/{repo}/blob/{branch}/{file_path}"
|
||||
|
||||
if file_path.endswith('.md'):
|
||||
doc = MarkdownParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.html') or file_path.endswith('.htm'):
|
||||
doc = HTMLParser.parse(content, file_path)
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.json'):
|
||||
docs = JSONParser.parse(content, file_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
doc.source_commit = commit_sha
|
||||
yield doc
|
||||
|
||||
elif file_path.endswith('.txt'):
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=Path(file_path).stem,
|
||||
file_path=file_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
source_commit=commit_sha,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch {file_path}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
continue
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error crawling {source.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {source.name}: {e}")
|
||||
|
||||
|
||||
class RepositoryDownloader:
|
||||
"""Download and extract repository archives."""
|
||||
|
||||
def __init__(self):
|
||||
self.http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.http_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
follow_redirects=True,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def download_zip(self, repo_url: str, branch: str = "main") -> Path:
|
||||
"""Download repository as ZIP and extract to temp directory."""
|
||||
if not self.http_client:
|
||||
raise RuntimeError("Downloader not initialized. Use 'async with' context.")
|
||||
|
||||
parsed = urlparse(repo_url)
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1].replace('.git', '')
|
||||
|
||||
zip_url = f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
|
||||
|
||||
logger.info(f"Downloading ZIP from {zip_url}")
|
||||
|
||||
response = await self.http_client.get(zip_url)
|
||||
response.raise_for_status()
|
||||
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
zip_path = temp_dir / f"{repo}.zip"
|
||||
|
||||
with open(zip_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
extract_dir = temp_dir / repo
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
|
||||
extracted_dirs = list(temp_dir.glob(f"{repo}-*"))
|
||||
if extracted_dirs:
|
||||
return extracted_dirs[0]
|
||||
|
||||
return extract_dir
|
||||
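# Usage sketch (network access and a reachable repository assumed):
#     async with RepositoryDownloader() as dl:
#         repo_dir = await dl.download_zip("https://github.com/octocat/Hello-World", "master")
#         ...  # walk repo_dir, then dl.cleanup(repo_dir.parent) to drop the temp dir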
|
||||
async def crawl_local_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
source: SourceConfig,
|
||||
base_url: str,
|
||||
) -> AsyncGenerator[ExtractedDocument, None]:
|
||||
"""Crawl a local directory for documents."""
|
||||
patterns = source.file_patterns or ["*.md", "*.txt", "*.html"]
|
||||
exclude_patterns = source.exclude_patterns or []
|
||||
|
||||
for pattern in patterns:
|
||||
for file_path in directory.rglob(pattern.replace("**/", "")):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
rel_path = str(file_path.relative_to(directory))
|
||||
|
||||
excluded = any(fnmatch(rel_path, ep) for ep in exclude_patterns)
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
if file_path.stat().st_size > MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = file_path.read_text(encoding='latin-1')
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
source_url = f"{base_url}/{rel_path}"
|
||||
|
||||
if file_path.suffix == '.md':
|
||||
doc = MarkdownParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix in ['.html', '.htm']:
|
||||
doc = HTMLParser.parse(content, rel_path)
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.json':
|
||||
docs = JSONParser.parse(content, rel_path)
|
||||
for doc in docs:
|
||||
doc.source_url = source_url
|
||||
yield doc
|
||||
|
||||
elif file_path.suffix == '.txt':
|
||||
yield ExtractedDocument(
|
||||
text=content,
|
||||
title=file_path.stem,
|
||||
file_path=rel_path,
|
||||
file_type="text",
|
||||
source_url=source_url,
|
||||
language=MarkdownParser._detect_language(content),
|
||||
placeholders=MarkdownParser._find_placeholders(content),
|
||||
)
|
||||
|
||||
def cleanup(self, directory: Path):
|
||||
"""Clean up temporary directory."""
|
||||
if directory.exists():
|
||||
shutil.rmtree(directory, ignore_errors=True)
|
||||
|
||||
|
||||
async def crawl_source(source: SourceConfig) -> List[ExtractedDocument]:
|
||||
"""Crawl a source configuration and return all extracted documents."""
|
||||
documents = []
|
||||
|
||||
if source.repo_url:
|
||||
async with GitHubCrawler() as crawler:
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
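# Usage sketch: assumes SourceConfig accepts name/repo_url/file_patterns/exclude_patterns
# as constructor fields (only the attribute reads are visible here).
#     source = SourceConfig(name="example", repo_url="https://github.com/example/legal-md",
#                           file_patterns=["*.md"], exclude_patterns=[])
#     docs = asyncio.run(crawl_source(source))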
|
||||
|
||||
# CLI for testing
|
||||
async def main():
|
||||
"""Test crawler with a sample source."""
|
||||
from template_sources import TEMPLATE_SOURCES
|
||||
|
||||
source = next(s for s in TEMPLATE_SOURCES if s.name == "github-site-policy")
|
||||
|
||||
async with GitHubCrawler() as crawler:
|
||||
count = 0
|
||||
async for doc in crawler.crawl_repository(source):
|
||||
count += 1
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Path: {doc.file_path}")
|
||||
print(f"Type: {doc.file_type}")
|
||||
print(f"Language: {doc.language}")
|
||||
print(f"URL: {doc.source_url}")
|
||||
print(f"Placeholders: {doc.placeholders[:5] if doc.placeholders else 'None'}")
|
||||
print(f"Text preview: {doc.text[:200]}...")
|
||||
|
||||
print(f"\n\nTotal documents: {count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
303
klausur-service/backend/github_crawler_parsers.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
GitHub Crawler - Document Parsers
|
||||
|
||||
Markdown, HTML, and JSON parsers for extracting structured content
|
||||
from legal template documents.
|
||||
|
||||
Extracted from github_crawler.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedDocument:
|
||||
"""A document extracted from a repository."""
|
||||
text: str
|
||||
title: str
|
||||
file_path: str
|
||||
file_type: str # "markdown", "html", "json", "text"
|
||||
source_url: str
|
||||
source_commit: Optional[str] = None
|
||||
source_hash: str = "" # SHA256 of original content
|
||||
sections: List[Dict[str, Any]] = field(default_factory=list)
|
||||
placeholders: List[str] = field(default_factory=list)
|
||||
language: str = "en"
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.source_hash and self.text:
|
||||
self.source_hash = hashlib.sha256(self.text.encode()).hexdigest()
|
||||
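# Minimal usage sketch:
#     doc = ExtractedDocument(text="# Terms\nHello", title="Terms",
#                             file_path="terms.md", file_type="markdown", source_url="")
#     doc.source_hash  # filled automatically with the SHA-256 of the text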
|
||||
|
||||
class MarkdownParser:
|
||||
"""Parse Markdown files into structured content."""
|
||||
|
||||
# Common placeholder patterns
|
||||
PLACEHOLDER_PATTERNS = [
|
||||
r'\[([A-Z_]+)\]', # [COMPANY_NAME]
|
||||
r'\{([a-z_]+)\}', # {company_name}
|
||||
r'\{\{([a-z_]+)\}\}', # {{company_name}}
|
||||
r'__([A-Z_]+)__', # __COMPANY_NAME__
|
||||
r'<([A-Z_]+)>', # <COMPANY_NAME>
|
||||
]
|
||||
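# Note: the patterns can overlap. "{{company_name}}" is matched by both the single-
# and the double-brace pattern, so _find_placeholders may report both forms.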
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse markdown content into an ExtractedDocument."""
|
||||
title = cls._extract_title(content, filename)
|
||||
sections = cls._extract_sections(content)
|
||||
placeholders = cls._find_placeholders(content)
|
||||
language = cls._detect_language(content)
|
||||
clean_text = cls._clean_for_indexing(content)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=clean_text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="markdown",
|
||||
source_url="",
|
||||
sections=sections,
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_title(cls, content: str, filename: str) -> str:
|
||||
"""Extract title from markdown heading or filename."""
|
||||
h1_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if h1_match:
|
||||
return h1_match.group(1).strip()
|
||||
|
||||
frontmatter_match = re.search(
|
||||
r'^---\s*\n.*?title:\s*["\']?(.+?)["\']?\s*\n.*?---',
|
||||
content, re.DOTALL
|
||||
)
|
||||
if frontmatter_match:
|
||||
return frontmatter_match.group(1).strip()
|
||||
|
||||
if filename:
|
||||
name = Path(filename).stem
|
||||
return name.replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
return "Untitled"
|
||||
|
||||
@classmethod
|
||||
def _extract_sections(cls, content: str) -> List[Dict[str, Any]]:
|
||||
"""Extract sections from markdown content."""
|
||||
sections = []
|
||||
current_section = {"heading": "", "level": 0, "content": "", "start": 0}
|
||||
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = current_section["content"].strip()
|
||||
sections.append(current_section.copy())
|
||||
|
||||
level = len(match.group(1))
|
||||
heading = match.group(2).strip()
|
||||
current_section = {
|
||||
"heading": heading,
|
||||
"level": level,
|
||||
"content": "",
|
||||
"start": match.end(),
|
||||
}
|
||||
|
||||
if current_section["heading"] or current_section["content"].strip():
|
||||
current_section["content"] = content[current_section["start"]:].strip()
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
@classmethod
|
||||
def _find_placeholders(cls, content: str) -> List[str]:
|
||||
"""Find placeholder patterns in content."""
|
||||
placeholders = set()
|
||||
for pattern in cls.PLACEHOLDER_PATTERNS:
|
||||
for match in re.finditer(pattern, content):
|
||||
placeholder = match.group(0)
|
||||
placeholders.add(placeholder)
|
||||
return sorted(list(placeholders))
|
||||
|
||||
@classmethod
|
||||
def _detect_language(cls, content: str) -> str:
|
||||
"""Detect language from content."""
|
||||
german_indicators = [
|
||||
'Datenschutz', 'Impressum', 'Nutzungsbedingungen', 'Haftung',
|
||||
'Widerruf', 'Verantwortlicher', 'personenbezogene', 'Verarbeitung',
|
||||
'und', 'der', 'die', 'das', 'ist', 'wird', 'werden', 'sind',
|
||||
]
|
||||
|
||||
lower_content = content.lower()
|
||||
german_count = sum(1 for word in german_indicators if word.lower() in lower_content)
|
||||
|
||||
if german_count >= 3:
|
||||
return "de"
|
||||
return "en"
|
||||
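# Example (illustrative): a snippet like "Impressum, Datenschutz und Verarbeitung
# personenbezogener Daten" hits at least three indicators and returns "de";
# a short English paragraph falls through to "en".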
|
||||
@classmethod
|
||||
def _clean_for_indexing(cls, content: str) -> str:
|
||||
"""Clean markdown content for text indexing."""
|
||||
content = re.sub(r'^---\s*\n.*?---\s*\n', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<!--.*?-->', '', content, flags=re.DOTALL)
|
||||
content = re.sub(r'<[^>]+>', '', content)
|
||||
content = re.sub(r'\*\*(.+?)\*\*', r'\1', content)
|
||||
content = re.sub(r'\*(.+?)\*', r'\1', content)
|
||||
content = re.sub(r'`(.+?)`', r'\1', content)
|
||||
content = re.sub(r'~~(.+?)~~', r'\1', content)
|
||||
content = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'\1', content)
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' +', ' ', content)
|
||||
|
||||
return content.strip()
|
||||
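# Example (illustrative): "**Wichtig:** siehe [Impressum](https://example.de/impressum)"
# is reduced to "Wichtig: siehe Impressum".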
|
||||
|
||||
class HTMLParser:
|
||||
"""Parse HTML files into structured content."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> ExtractedDocument:
|
||||
"""Parse HTML content into an ExtractedDocument."""
|
||||
title_match = re.search(r'<title>(.+?)</title>', content, re.IGNORECASE)
|
||||
title = title_match.group(1) if title_match else Path(filename).stem
|
||||
|
||||
text = cls._html_to_text(content)
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
|
||||
lang_match = re.search(r'<html[^>]*lang=["\']([a-z]{2})["\']', content, re.IGNORECASE)
|
||||
language = lang_match.group(1) if lang_match else MarkdownParser._detect_language(text)
|
||||
|
||||
return ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="html",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _html_to_text(cls, html: str) -> str:
|
||||
"""Convert HTML to clean text."""
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<!--.*?-->', '', html, flags=re.DOTALL)
|
||||
|
||||
html = html.replace('&nbsp;', ' ')
html = html.replace('&amp;', '&')
html = html.replace('&lt;', '<')
html = html.replace('&gt;', '>')
html = html.replace('&quot;', '"')
html = html.replace('&#39;', "'")
|
||||
|
||||
html = re.sub(r'<br\s*/?>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</p>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</div>', '\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</h[1-6]>', '\n\n', html, flags=re.IGNORECASE)
|
||||
html = re.sub(r'</li>', '\n', html, flags=re.IGNORECASE)
|
||||
|
||||
html = re.sub(r'<[^>]+>', '', html)
|
||||
|
||||
html = re.sub(r'[ \t]+', ' ', html)
|
||||
html = re.sub(r'\n[ \t]+', '\n', html)
|
||||
html = re.sub(r'[ \t]+\n', '\n', html)
|
||||
html = re.sub(r'\n{3,}', '\n\n', html)
|
||||
|
||||
return html.strip()
|
||||
|
||||
|
||||
class JSONParser:
|
||||
"""Parse JSON files containing legal template data."""
|
||||
|
||||
@classmethod
|
||||
def parse(cls, content: str, filename: str = "") -> List[ExtractedDocument]:
|
||||
"""Parse JSON content into ExtractedDocuments."""
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to parse JSON from {filename}: {e}")
|
||||
return []
|
||||
|
||||
documents = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
documents.extend(cls._parse_dict(data, filename))
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
docs = cls._parse_dict(item, f"{filename}[{i}]")
|
||||
documents.extend(docs)
|
||||
|
||||
return documents
|
||||
|
||||
@classmethod
|
||||
def _parse_dict(cls, data: dict, filename: str) -> List[ExtractedDocument]:
|
||||
"""Parse a dictionary into documents."""
|
||||
documents = []
|
||||
|
||||
text_keys = ['text', 'content', 'body', 'description', 'value']
|
||||
title_keys = ['title', 'name', 'heading', 'label', 'key']
|
||||
|
||||
text = ""
|
||||
for key in text_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
text = data[key]
|
||||
break
|
||||
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
nested_docs = cls._parse_dict(value, f"{filename}.{key}")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
nested_docs = cls._parse_dict(item, f"{filename}.{key}[{i}]")
|
||||
documents.extend(nested_docs)
|
||||
elif isinstance(item, str) and len(item) > 50:
|
||||
documents.append(ExtractedDocument(
|
||||
text=item,
|
||||
title=f"{key} {i+1}",
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
language=MarkdownParser._detect_language(item),
|
||||
))
|
||||
return documents
|
||||
|
||||
title = ""
|
||||
for key in title_keys:
|
||||
if key in data and isinstance(data[key], str):
|
||||
title = data[key]
|
||||
break
|
||||
|
||||
if not title:
|
||||
title = Path(filename).stem
|
||||
|
||||
metadata = {}
|
||||
for key, value in data.items():
|
||||
if key not in text_keys + title_keys and not isinstance(value, (dict, list)):
|
||||
metadata[key] = value
|
||||
|
||||
placeholders = MarkdownParser._find_placeholders(text)
|
||||
language = data.get('lang', data.get('language', MarkdownParser._detect_language(text)))
|
||||
|
||||
documents.append(ExtractedDocument(
|
||||
text=text,
|
||||
title=title,
|
||||
file_path=filename,
|
||||
file_type="json",
|
||||
source_url="",
|
||||
placeholders=placeholders,
|
||||
language=language,
|
||||
metadata=metadata,
|
||||
))
|
||||
|
||||
return documents
|
||||
@@ -1,790 +1,30 @@
|
||||
"""
|
||||
Legal Corpus API - Endpoints for RAG page in admin-v2
|
||||
Legal Corpus API — Barrel Re-export
|
||||
|
||||
Provides endpoints for:
|
||||
- GET /api/v1/admin/legal-corpus/status - Collection status with chunk counts
|
||||
- GET /api/v1/admin/legal-corpus/search - Semantic search
|
||||
- POST /api/v1/admin/legal-corpus/ingest - Trigger ingestion
|
||||
- GET /api/v1/admin/legal-corpus/ingestion-status - Ingestion status
|
||||
- POST /api/v1/admin/legal-corpus/upload - Upload document
|
||||
- POST /api/v1/admin/legal-corpus/add-link - Add link for ingestion
|
||||
- POST /api/v1/admin/pipeline/start - Start compliance pipeline
|
||||
Split into:
|
||||
- legal_corpus_routes.py — Corpus endpoints (status, search, ingest, upload)
|
||||
- legal_corpus_pipeline.py — Pipeline endpoints (checkpoints, start, status)
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import httpx
|
||||
import uuid
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
|
||||
|
||||
# Configuration
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
COLLECTION_NAME = "bp_legal_corpus"
|
||||
|
||||
# All regulations for status endpoint
|
||||
REGULATIONS = [
|
||||
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
|
||||
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
|
||||
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
|
||||
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
|
||||
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
|
||||
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
|
||||
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
|
||||
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
|
||||
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
|
||||
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
|
||||
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
|
||||
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
|
||||
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
|
||||
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
|
||||
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
|
||||
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
|
||||
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
|
||||
]
|
||||
|
||||
# Ingestion state (in-memory for now)
|
||||
ingestion_state = {
|
||||
"running": False,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
regulations: Optional[List[str]] = None
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
force: bool = False
|
||||
regulations: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AddLinkRequest(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
code: str # Regulation code (e.g. "CUSTOM-1")
|
||||
document_type: str = "custom" # custom, eu_regulation, eu_directive, de_law, bsi_standard
|
||||
|
||||
|
||||
class StartPipelineRequest(BaseModel):
|
||||
force_reindex: bool = False
|
||||
skip_ingestion: bool = False
|
||||
|
||||
|
||||
# Store for custom documents (in-memory for now, should be persisted)
|
||||
custom_documents: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
async def get_qdrant_client():
|
||||
"""Get async HTTP client for Qdrant."""
|
||||
return httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_legal_corpus_status():
|
||||
"""
|
||||
Get status of the legal corpus collection including chunk counts per regulation.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
# Get collection info
|
||||
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
|
||||
if collection_res.status_code != 200:
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": 0,
|
||||
"vectorSize": 1024,
|
||||
"status": "not_found",
|
||||
"regulations": {},
|
||||
}
|
||||
|
||||
collection_data = collection_res.json()
|
||||
result = collection_data.get("result", {})
|
||||
|
||||
# Get chunk counts per regulation
|
||||
regulation_counts = {}
|
||||
for reg in REGULATIONS:
|
||||
count_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
|
||||
json={
|
||||
"filter": {
|
||||
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
|
||||
}
|
||||
},
|
||||
)
|
||||
if count_res.status_code == 200:
|
||||
count_data = count_res.json()
|
||||
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
|
||||
else:
|
||||
regulation_counts[reg["code"]] = 0
|
||||
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": result.get("points_count", 0),
|
||||
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
|
||||
"status": result.get("status", "unknown"),
|
||||
"regulations": regulation_counts,
|
||||
}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get Qdrant status: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_legal_corpus(
|
||||
query: str = Query(..., description="Search query"),
|
||||
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
|
||||
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
|
||||
):
|
||||
"""
|
||||
Semantic search in legal corpus using BGE-M3 embeddings.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
# Generate embedding for query
|
||||
embed_res = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": [query]},
|
||||
)
|
||||
if embed_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Embedding service error")
|
||||
|
||||
embed_data = embed_res.json()
|
||||
query_vector = embed_data["embeddings"][0]
|
||||
|
||||
# Build Qdrant search request
|
||||
search_request = {
|
||||
"vector": query_vector,
|
||||
"limit": top_k,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
# Add regulation filter if specified
|
||||
if regulations:
|
||||
reg_codes = [r.strip() for r in regulations.split(",")]
|
||||
search_request["filter"] = {
|
||||
"should": [
|
||||
{"key": "regulation_code", "match": {"value": code}}
|
||||
for code in reg_codes
|
||||
]
|
||||
}
|
||||
|
||||
# Search Qdrant
|
||||
search_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
|
||||
json=search_request,
|
||||
)
|
||||
|
||||
if search_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Search failed")
|
||||
|
||||
search_data = search_res.json()
|
||||
results = []
|
||||
for point in search_data.get("result", []):
|
||||
payload = point.get("payload", {})
|
||||
results.append({
|
||||
"text": payload.get("text", ""),
|
||||
"regulation_code": payload.get("regulation_code", ""),
|
||||
"regulation_name": payload.get("regulation_name", ""),
|
||||
"article": payload.get("article"),
|
||||
"paragraph": payload.get("paragraph"),
|
||||
"source_url": payload.get("source_url", ""),
|
||||
"score": point.get("score", 0),
|
||||
})
|
||||
|
||||
return {"results": results, "query": query, "count": len(results)}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
|
||||
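# Example request (illustrative):
#     GET /api/v1/admin/legal-corpus/search?query=Auftragsverarbeitung&top_k=3&regulations=GDPR,TDDDG
# returns {"results": [...], "query": "...", "count": N}, one entry per matching chunk.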
|
||||
|
||||
@router.post("/ingest")
|
||||
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Trigger legal corpus ingestion in background.
|
||||
"""
|
||||
global ingestion_state
|
||||
|
||||
if ingestion_state["running"]:
|
||||
raise HTTPException(status_code=409, detail="Ingestion already running")
|
||||
|
||||
# Reset state
|
||||
ingestion_state = {
|
||||
"running": True,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# Start ingestion in background
|
||||
background_tasks.add_task(run_ingestion, request.force, request.regulations)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"job_id": "manual-trigger",
|
||||
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
|
||||
}
|
||||
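# Example request (illustrative):
#     POST /api/v1/admin/legal-corpus/ingest   {"force": true, "regulations": ["GDPR", "AIACT"]}
# responds with {"status": "started", ...}; progress is polled via GET .../ingestion-status.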
|
||||
|
||||
async def run_ingestion(force: bool, regulations: Optional[List[str]]):
|
||||
"""Background task for running ingestion."""
|
||||
global ingestion_state
|
||||
|
||||
try:
|
||||
# Import ingestion module
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Filter regulations if specified
|
||||
regs_to_process = regulations or [r["code"] for r in REGULATIONS]
|
||||
|
||||
for i, reg_code in enumerate(regs_to_process):
|
||||
ingestion_state["current_regulation"] = reg_code
|
||||
ingestion_state["processed"] = i
|
||||
|
||||
try:
|
||||
await ingestion.ingest_single(reg_code, force=force)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest {reg_code}: {e}")
|
||||
|
||||
ingestion_state["completed"] = True
|
||||
ingestion_state["processed"] = len(regs_to_process)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ingestion failed: {e}")
|
||||
ingestion_state["error"] = str(e)
|
||||
|
||||
finally:
|
||||
ingestion_state["running"] = False
|
||||
|
||||
|
||||
@router.get("/ingestion-status")
|
||||
async def get_ingestion_status():
|
||||
"""
|
||||
Get current ingestion status.
|
||||
"""
|
||||
return ingestion_state
|
||||
|
||||
|
||||
@router.get("/regulations")
|
||||
async def get_regulations():
|
||||
"""
|
||||
Get list of all supported regulations.
|
||||
"""
|
||||
return {"regulations": REGULATIONS}
|
||||
|
||||
|
||||
@router.get("/custom-documents")
|
||||
async def get_custom_documents():
|
||||
"""
|
||||
Get list of custom documents added by user.
|
||||
"""
|
||||
return {"documents": custom_documents}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_document(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
title: str = Form(...),
|
||||
code: str = Form(...),
|
||||
document_type: str = Form("custom"),
|
||||
):
|
||||
"""
|
||||
Upload a document (PDF) for ingestion into the legal corpus.
|
||||
|
||||
The document will be saved and queued for processing.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
# Validate file type
|
||||
if not file.filename.endswith(('.pdf', '.PDF')):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files are supported")
|
||||
|
||||
# Create upload directory if needed
|
||||
upload_dir = "/tmp/legal_corpus_uploads"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
# Save file with unique name
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
safe_filename = f"{doc_id}_{file.filename}"
|
||||
file_path = os.path.join(upload_dir, safe_filename)
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
|
||||
|
||||
# Create document record
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": code,
|
||||
"title": title,
|
||||
"filename": file.filename,
|
||||
"file_path": file_path,
|
||||
"document_type": document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "uploaded",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
|
||||
# Queue for background ingestion
|
||||
background_tasks.add_task(ingest_uploaded_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "uploaded",
|
||||
"document_id": doc_id,
|
||||
"message": f"Document '{title}' uploaded and queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
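# Example upload (illustrative; host and port depend on the deployment):
#     curl -F "file=@dsgvo_auszug.pdf" -F "title=DSGVO Auszug" -F "code=CUSTOM-1" \
#          http://localhost:8000/api/v1/admin/legal-corpus/upload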
|
||||
|
||||
async def ingest_uploaded_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest an uploaded document."""
|
||||
global custom_documents
|
||||
|
||||
try:
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Read PDF and extract text
|
||||
import fitz # PyMuPDF
|
||||
|
||||
doc = fitz.open(doc_record["file_path"])
|
||||
full_text = ""
|
||||
for page in doc:
|
||||
full_text += page.get_text()
|
||||
doc.close()
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from PDF"
|
||||
return
|
||||
|
||||
# Chunk the text
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
# Add metadata
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = f"upload://{doc_record['filename']}"
|
||||
|
||||
# Generate embeddings and upsert to Qdrant
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from document"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest uploaded document: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
@router.post("/add-link")
|
||||
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Add a URL/link for ingestion into the legal corpus.
|
||||
|
||||
The content will be fetched, extracted, and indexed.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
# Create document record
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": request.code,
|
||||
"title": request.title,
|
||||
"url": request.url,
|
||||
"document_type": request.document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "queued",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
|
||||
# Queue for background ingestion
|
||||
background_tasks.add_task(ingest_link_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
"document_id": doc_id,
|
||||
"message": f"Link '{request.title}' queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
|
||||
|
||||
|
||||
async def ingest_link_document(doc_record: Dict[str, Any]):
|
||||
"""Background task to ingest content from a URL."""
|
||||
global custom_documents
|
||||
|
||||
try:
|
||||
doc_record["status"] = "fetching"
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Fetch the URL
|
||||
response = await client.get(doc_record["url"], follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "application/pdf" in content_type:
|
||||
# Save PDF and process
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(response.content)
|
||||
pdf_path = f.name
|
||||
|
||||
import fitz
|
||||
pdf_doc = fitz.open(pdf_path)
|
||||
full_text = ""
|
||||
for page in pdf_doc:
|
||||
full_text += page.get_text()
|
||||
pdf_doc.close()
|
||||
os.unlink(pdf_path)
|
||||
|
||||
elif "text/html" in content_type:
|
||||
# Extract text from HTML
|
||||
from bs4 import BeautifulSoup
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# Remove script and style elements
|
||||
for script in soup(["script", "style", "nav", "footer", "header"]):
|
||||
script.decompose()
|
||||
|
||||
full_text = soup.get_text(separator="\n", strip=True)
|
||||
|
||||
else:
|
||||
# Try to use as plain text
|
||||
full_text = response.text
|
||||
|
||||
if not full_text.strip():
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No text could be extracted from URL"
|
||||
return
|
||||
|
||||
doc_record["status"] = "processing"
|
||||
|
||||
from legal_corpus_ingestion import LegalCorpusIngestion
|
||||
ingestion = LegalCorpusIngestion()
|
||||
|
||||
# Chunk the text
|
||||
chunks = ingestion.chunk_text(full_text, doc_record["code"])
|
||||
|
||||
# Add metadata
|
||||
for chunk in chunks:
|
||||
chunk["regulation_code"] = doc_record["code"]
|
||||
chunk["regulation_name"] = doc_record["title"]
|
||||
chunk["document_type"] = doc_record["document_type"]
|
||||
chunk["source_url"] = doc_record["url"]
|
||||
|
||||
# Generate embeddings and upsert to Qdrant
|
||||
if chunks:
|
||||
await ingestion.embed_and_upsert(chunks)
|
||||
doc_record["chunk_count"] = len(chunks)
|
||||
doc_record["status"] = "indexed"
|
||||
logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
|
||||
else:
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = "No chunks generated from content"
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to fetch URL: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = f"Failed to fetch URL: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ingest URL content: {e}")
|
||||
doc_record["status"] = "error"
|
||||
doc_record["error"] = str(e)
|
||||
|
||||
|
||||
@router.delete("/custom-documents/{doc_id}")
|
||||
async def delete_custom_document(doc_id: str):
|
||||
"""
|
||||
Delete a custom document from the list.
|
||||
Note: This does not remove the chunks from Qdrant yet.
|
||||
"""
|
||||
global custom_documents
|
||||
|
||||
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
|
||||
|
||||
# TODO: Also remove chunks from Qdrant by filtering on code
|
||||
|
||||
return {"status": "deleted", "document_id": doc_id}
|
||||
|
||||
|
||||
# ========== Pipeline Checkpoints ==========
|
||||
|
||||
# Create a separate router for pipeline-related endpoints
|
||||
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints")
|
||||
async def get_pipeline_checkpoints():
|
||||
"""
|
||||
Get current pipeline checkpoint state.
|
||||
|
||||
Returns the current state of the compliance pipeline including:
|
||||
- Pipeline ID and overall status
|
||||
- Start and completion times
|
||||
- All checkpoints with their validations and metrics
|
||||
- Summary data
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {
|
||||
"status": "no_data",
|
||||
"message": "No pipeline run data available yet.",
|
||||
"pipeline_id": None,
|
||||
"checkpoints": [],
|
||||
"summary": {}
|
||||
}
|
||||
|
||||
# Enrich with validation summary
|
||||
validation_summary = {
|
||||
"passed": 0,
|
||||
"warning": 0,
|
||||
"failed": 0,
|
||||
"total": 0
|
||||
}
|
||||
|
||||
for checkpoint in state.get("checkpoints", []):
|
||||
for validation in checkpoint.get("validations", []):
|
||||
validation_summary["total"] += 1
|
||||
status = validation.get("status", "not_run")
|
||||
if status in validation_summary:
|
||||
validation_summary[status] += 1
|
||||
|
||||
state["validation_summary"] = validation_summary
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@pipeline_router.get("/checkpoints/history")
|
||||
async def get_pipeline_history():
|
||||
"""
|
||||
Get list of previous pipeline runs (if stored).
|
||||
For now, returns only current run.
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state is None:
|
||||
return {"runs": []}
|
||||
|
||||
return {
|
||||
"runs": [{
|
||||
"pipeline_id": state.get("pipeline_id"),
|
||||
"status": state.get("status"),
|
||||
"started_at": state.get("started_at"),
|
||||
"completed_at": state.get("completed_at"),
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
# Pipeline state for start/stop
|
||||
pipeline_process_state = {
|
||||
"running": False,
|
||||
"pid": None,
|
||||
"started_at": None,
|
||||
}
|
||||
|
||||
|
||||
@pipeline_router.post("/start")
|
||||
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Start the compliance pipeline in the background.
|
||||
|
||||
This runs the full_compliance_pipeline.py script which:
|
||||
1. Ingests all legal documents (unless skip_ingestion=True)
|
||||
2. Extracts requirements and controls
|
||||
3. Generates compliance measures
|
||||
4. Creates checkpoint data for monitoring
|
||||
"""
|
||||
global pipeline_process_state
|
||||
|
||||
# Check if already running
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
state = CheckpointManager.load_state()
|
||||
|
||||
if state and state.get("status") == "running":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline is already running"
|
||||
)
|
||||
|
||||
if pipeline_process_state["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Pipeline start already in progress"
|
||||
)
|
||||
|
||||
pipeline_process_state["running"] = True
|
||||
pipeline_process_state["started_at"] = datetime.now().isoformat()
|
||||
|
||||
# Start pipeline in background
|
||||
background_tasks.add_task(
|
||||
run_pipeline_background,
|
||||
request.force_reindex,
|
||||
request.skip_ingestion
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "starting",
|
||||
"message": "Compliance pipeline is starting in background",
|
||||
"started_at": pipeline_process_state["started_at"],
|
||||
}
|
||||
|
||||
|
||||
async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
|
||||
"""Background task to run the compliance pipeline."""
|
||||
global pipeline_process_state
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Build command
|
||||
cmd = [sys.executable, "full_compliance_pipeline.py"]
|
||||
if force_reindex:
|
||||
cmd.append("--force-reindex")
|
||||
if skip_ingestion:
|
||||
cmd.append("--skip-ingestion")
|
||||
|
||||
# Run as subprocess
|
||||
logger.info(f"Starting pipeline: {' '.join(cmd)}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
pipeline_process_state["pid"] = process.pid
|
||||
|
||||
# Wait for completion (non-blocking via asyncio)
|
||||
import asyncio
|
||||
while process.poll() is None:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
return_code = process.returncode
|
||||
|
||||
if return_code != 0:
|
||||
output = process.stdout.read() if process.stdout else ""
|
||||
logger.error(f"Pipeline failed with code {return_code}: {output}")
|
||||
else:
|
||||
logger.info("Pipeline completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run pipeline: {e}")
|
||||
|
||||
finally:
|
||||
pipeline_process_state["running"] = False
|
||||
pipeline_process_state["pid"] = None
|
||||
|
||||
|
||||
@pipeline_router.get("/status")
|
||||
async def get_pipeline_status():
|
||||
"""
|
||||
Get current pipeline running status.
|
||||
"""
|
||||
from pipeline_checkpoints import CheckpointManager
|
||||
|
||||
state = CheckpointManager.load_state()
|
||||
checkpoint_status = state.get("status") if state else "no_data"
|
||||
|
||||
return {
|
||||
"process_running": pipeline_process_state["running"],
|
||||
"process_pid": pipeline_process_state["pid"],
|
||||
"process_started_at": pipeline_process_state["started_at"],
|
||||
"checkpoint_status": checkpoint_status,
|
||||
"current_phase": state.get("current_phase") if state else None,
|
||||
}
|
||||
|
||||
|
||||
# ========== Traceability / Quality Endpoints ==========
|
||||
|
||||
@router.get("/traceability")
|
||||
async def get_traceability(
|
||||
chunk_id: str = Query(..., description="Chunk ID or identifier"),
|
||||
regulation: str = Query(..., description="Regulation code"),
|
||||
):
|
||||
"""
|
||||
Get traceability information for a specific chunk.
|
||||
|
||||
Returns:
|
||||
- The chunk details
|
||||
- Requirements extracted from this chunk
|
||||
- Controls derived from those requirements
|
||||
|
||||
Note: This is a placeholder that will be enhanced once the
|
||||
requirements extraction pipeline is fully implemented.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
# Try to find the chunk by scrolling through points with the regulation filter
|
||||
# In a production system, we would have proper IDs and indexing
|
||||
|
||||
# For now, return placeholder structure
|
||||
# The actual implementation will query:
|
||||
# 1. The chunk from Qdrant
|
||||
# 2. Requirements from a requirements collection/table
|
||||
# 3. Controls from a controls collection/table
|
||||
|
||||
return {
|
||||
"chunk_id": chunk_id,
|
||||
"regulation": regulation,
|
||||
"requirements": [],
|
||||
"controls": [],
|
||||
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get traceability: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")
|
||||
# Corpus routes and state
from legal_corpus_routes import (  # noqa: F401
    router,
    REGULATIONS,
    COLLECTION_NAME,
    QDRANT_URL,
    EMBEDDING_SERVICE_URL,
    ingestion_state,
    custom_documents,
    SearchRequest,
    IngestRequest,
    AddLinkRequest,
)

# Pipeline routes and state
from legal_corpus_pipeline import (  # noqa: F401
    pipeline_router,
    pipeline_process_state,
    StartPipelineRequest,
)
166
klausur-service/backend/legal_corpus_ingest_tasks.py
Normal file
@@ -0,0 +1,166 @@
"""
Legal Corpus API - Background Ingestion Tasks

Background tasks for ingesting uploaded documents and URL links
into the legal corpus vector database.

Extracted from legal_corpus_routes.py to keep files under 500 LOC.
"""

import os
import logging
from typing import Dict, Any, Optional, List

import httpx

logger = logging.getLogger(__name__)


async def ingest_uploaded_document(doc_record: Dict[str, Any]):
    """Background task to ingest an uploaded document."""
    try:
        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        import fitz
        doc = fitz.open(doc_record["file_path"])
        full_text = ""
        for page in doc:
            full_text += page.get_text()
        doc.close()

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from PDF"
            return

        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = f"upload://{doc_record['filename']}"

        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from uploaded document {doc_record['code']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from document"

    except Exception as e:
        logger.error(f"Failed to ingest uploaded document: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)


async def ingest_link_document(doc_record: Dict[str, Any]):
    """Background task to ingest content from a URL."""
    try:
        doc_record["status"] = "fetching"

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.get(doc_record["url"], follow_redirects=True)
            response.raise_for_status()

            content_type = response.headers.get("content-type", "")

            if "application/pdf" in content_type:
                import tempfile
                with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
                    f.write(response.content)
                    pdf_path = f.name

                import fitz
                pdf_doc = fitz.open(pdf_path)
                full_text = ""
                for page in pdf_doc:
                    full_text += page.get_text()
                pdf_doc.close()
                os.unlink(pdf_path)

            elif "text/html" in content_type:
                from bs4 import BeautifulSoup
                soup = BeautifulSoup(response.text, "html.parser")

                for script in soup(["script", "style", "nav", "footer", "header"]):
                    script.decompose()

                full_text = soup.get_text(separator="\n", strip=True)

            else:
                full_text = response.text

        if not full_text.strip():
            doc_record["status"] = "error"
            doc_record["error"] = "No text could be extracted from URL"
            return

        doc_record["status"] = "processing"

        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        chunks = ingestion.chunk_text(full_text, doc_record["code"])

        for chunk in chunks:
            chunk["regulation_code"] = doc_record["code"]
            chunk["regulation_name"] = doc_record["title"]
            chunk["document_type"] = doc_record["document_type"]
            chunk["source_url"] = doc_record["url"]

        if chunks:
            await ingestion.embed_and_upsert(chunks)
            doc_record["chunk_count"] = len(chunks)
            doc_record["status"] = "indexed"
            logger.info(f"Ingested {len(chunks)} chunks from URL {doc_record['url']}")
        else:
            doc_record["status"] = "error"
            doc_record["error"] = "No chunks generated from content"

    except httpx.HTTPError as e:
        logger.error(f"Failed to fetch URL: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = f"Failed to fetch URL: {str(e)}"
    except Exception as e:
        logger.error(f"Failed to ingest URL content: {e}")
        doc_record["status"] = "error"
        doc_record["error"] = str(e)


async def run_ingestion(
    force: bool,
    regulations: Optional[List[str]],
    ingestion_state: Dict[str, Any],
    all_regulations: List[Dict[str, str]],
):
    """Background task for running full corpus ingestion."""
    try:
        from legal_corpus_ingestion import LegalCorpusIngestion
        ingestion = LegalCorpusIngestion()

        regs_to_process = regulations or [r["code"] for r in all_regulations]

        for i, reg_code in enumerate(regs_to_process):
            ingestion_state["current_regulation"] = reg_code
            ingestion_state["processed"] = i

            try:
                await ingestion.ingest_single(reg_code, force=force)
            except Exception as e:
                logger.error(f"Failed to ingest {reg_code}: {e}")

        ingestion_state["completed"] = True
        ingestion_state["processed"] = len(regs_to_process)

    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        ingestion_state["error"] = str(e)

    finally:
        ingestion_state["running"] = False
206
klausur-service/backend/legal_corpus_pipeline.py
Normal file
@@ -0,0 +1,206 @@
"""
Legal Corpus API - Pipeline Routes

Pipeline checkpoints, history, start/stop, and status endpoints.

Extracted from legal_corpus_api.py to keep files under 500 LOC.
"""

import os
import asyncio
from datetime import datetime
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
import logging

logger = logging.getLogger(__name__)


class StartPipelineRequest(BaseModel):
    force_reindex: bool = False
    skip_ingestion: bool = False


# Create a separate router for pipeline-related endpoints
pipeline_router = APIRouter(prefix="/api/v1/admin/pipeline", tags=["pipeline"])


@pipeline_router.get("/checkpoints")
async def get_pipeline_checkpoints():
    """
    Get current pipeline checkpoint state.

    Returns the current state of the compliance pipeline including:
    - Pipeline ID and overall status
    - Start and completion times
    - All checkpoints with their validations and metrics
    - Summary data
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()

    if state is None:
        return {
            "status": "no_data",
            "message": "No pipeline run data available yet.",
            "pipeline_id": None,
            "checkpoints": [],
            "summary": {}
        }

    # Enrich with validation summary
    validation_summary = {
        "passed": 0,
        "warning": 0,
        "failed": 0,
        "total": 0
    }

    for checkpoint in state.get("checkpoints", []):
        for validation in checkpoint.get("validations", []):
            validation_summary["total"] += 1
            status = validation.get("status", "not_run")
            if status in validation_summary:
                validation_summary[status] += 1

    state["validation_summary"] = validation_summary

    return state


@pipeline_router.get("/checkpoints/history")
async def get_pipeline_history():
    """
    Get list of previous pipeline runs (if stored).
    For now, returns only current run.
    """
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()

    if state is None:
        return {"runs": []}

    return {
        "runs": [{
            "pipeline_id": state.get("pipeline_id"),
            "status": state.get("status"),
            "started_at": state.get("started_at"),
            "completed_at": state.get("completed_at"),
        }]
    }


# Pipeline state for start/stop
pipeline_process_state = {
    "running": False,
    "pid": None,
    "started_at": None,
}


@pipeline_router.post("/start")
async def start_pipeline(request: StartPipelineRequest, background_tasks: BackgroundTasks):
    """
    Start the compliance pipeline in the background.

    This runs the full_compliance_pipeline.py script which:
    1. Ingests all legal documents (unless skip_ingestion=True)
    2. Extracts requirements and controls
    3. Generates compliance measures
    4. Creates checkpoint data for monitoring
    """
    global pipeline_process_state

    from pipeline_checkpoints import CheckpointManager
    state = CheckpointManager.load_state()

    if state and state.get("status") == "running":
        raise HTTPException(
            status_code=409,
            detail="Pipeline is already running"
        )

    if pipeline_process_state["running"]:
        raise HTTPException(
            status_code=409,
            detail="Pipeline start already in progress"
        )

    pipeline_process_state["running"] = True
    pipeline_process_state["started_at"] = datetime.now().isoformat()

    background_tasks.add_task(
        run_pipeline_background,
        request.force_reindex,
        request.skip_ingestion
    )

    return {
        "status": "starting",
        "message": "Compliance pipeline is starting in background",
        "started_at": pipeline_process_state["started_at"],
    }


async def run_pipeline_background(force_reindex: bool, skip_ingestion: bool):
    """Background task to run the compliance pipeline."""
    global pipeline_process_state

    try:
        import subprocess
        import sys

        cmd = [sys.executable, "full_compliance_pipeline.py"]
        if force_reindex:
            cmd.append("--force-reindex")
        if skip_ingestion:
            cmd.append("--skip-ingestion")

        logger.info(f"Starting pipeline: {' '.join(cmd)}")

        process = subprocess.Popen(
            cmd,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )

        pipeline_process_state["pid"] = process.pid

        while process.poll() is None:
            await asyncio.sleep(5)

        return_code = process.returncode

        if return_code != 0:
            output = process.stdout.read() if process.stdout else ""
            logger.error(f"Pipeline failed with code {return_code}: {output}")
        else:
            logger.info("Pipeline completed successfully")

    except Exception as e:
        logger.error(f"Failed to run pipeline: {e}")

    finally:
        pipeline_process_state["running"] = False
        pipeline_process_state["pid"] = None


@pipeline_router.get("/status")
async def get_pipeline_status():
    """Get current pipeline running status."""
    from pipeline_checkpoints import CheckpointManager

    state = CheckpointManager.load_state()
    checkpoint_status = state.get("status") if state else "no_data"

    return {
        "process_running": pipeline_process_state["running"],
        "process_pid": pipeline_process_state["pid"],
        "process_started_at": pipeline_process_state["started_at"],
        "checkpoint_status": checkpoint_status,
        "current_phase": state.get("current_phase") if state else None,
    }
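A rough end-to-end usage sketch for these pipeline endpoints. The host/port and the client-side polling cadence are assumptions, not part of this commit; the paths and fields come from the routes above.

import time
import httpx

PIPE = "http://localhost:8000/api/v1/admin/pipeline"  # assumed deployment URL

# Start the pipeline without re-running document ingestion.
httpx.post(f"{PIPE}/start", json={"force_reindex": False, "skip_ingestion": True})

# The server polls the subprocess every 5 s; a similar pace is fine here.
while httpx.get(f"{PIPE}/status").json()["process_running"]:
    time.sleep(5)

# Inspect the enriched checkpoint state once the run has finished.
checkpoints = httpx.get(f"{PIPE}/checkpoints").json()
print(checkpoints.get("validation_summary"))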
368
klausur-service/backend/legal_corpus_routes.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Legal Corpus API - Corpus Routes
|
||||
|
||||
Endpoints for the RAG page in admin-v2:
|
||||
- GET /status - Collection status with chunk counts
|
||||
- GET /search - Semantic search
|
||||
- POST /ingest - Trigger ingestion
|
||||
- GET /ingestion-status - Ingestion status
|
||||
- GET /regulations - List regulations
|
||||
- GET /custom-documents - List custom docs
|
||||
- POST /upload - Upload document
|
||||
- POST /add-link - Add link for ingestion
|
||||
- DELETE /custom-documents/{id} - Delete custom doc
|
||||
- GET /traceability - Traceability info
|
||||
|
||||
Extracted from legal_corpus_api.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
import uuid
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks, UploadFile, File, Form
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from legal_corpus_ingest_tasks import (
|
||||
ingest_uploaded_document,
|
||||
ingest_link_document,
|
||||
run_ingestion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
|
||||
COLLECTION_NAME = "bp_legal_corpus"
|
||||
|
||||
# All regulations for status endpoint
|
||||
REGULATIONS = [
|
||||
{"code": "GDPR", "name": "DSGVO", "fullName": "Datenschutz-Grundverordnung", "type": "eu_regulation"},
|
||||
{"code": "EPRIVACY", "name": "ePrivacy-Richtlinie", "fullName": "Richtlinie 2002/58/EG", "type": "eu_directive"},
|
||||
{"code": "TDDDG", "name": "TDDDG", "fullName": "Telekommunikation-Digitale-Dienste-Datenschutz-Gesetz", "type": "de_law"},
|
||||
{"code": "SCC", "name": "Standardvertragsklauseln", "fullName": "2021/914/EU", "type": "eu_regulation"},
|
||||
{"code": "DPF", "name": "EU-US Data Privacy Framework", "fullName": "Angemessenheitsbeschluss", "type": "eu_regulation"},
|
||||
{"code": "AIACT", "name": "EU AI Act", "fullName": "Verordnung (EU) 2024/1689", "type": "eu_regulation"},
|
||||
{"code": "CRA", "name": "Cyber Resilience Act", "fullName": "Verordnung (EU) 2024/2847", "type": "eu_regulation"},
|
||||
{"code": "NIS2", "name": "NIS2-Richtlinie", "fullName": "Richtlinie (EU) 2022/2555", "type": "eu_directive"},
|
||||
{"code": "EUCSA", "name": "EU Cybersecurity Act", "fullName": "Verordnung (EU) 2019/881", "type": "eu_regulation"},
|
||||
{"code": "DATAACT", "name": "Data Act", "fullName": "Verordnung (EU) 2023/2854", "type": "eu_regulation"},
|
||||
{"code": "DGA", "name": "Data Governance Act", "fullName": "Verordnung (EU) 2022/868", "type": "eu_regulation"},
|
||||
{"code": "DSA", "name": "Digital Services Act", "fullName": "Verordnung (EU) 2022/2065", "type": "eu_regulation"},
|
||||
{"code": "EAA", "name": "European Accessibility Act", "fullName": "Richtlinie (EU) 2019/882", "type": "eu_directive"},
|
||||
{"code": "DSM", "name": "DSM-Urheberrechtsrichtlinie", "fullName": "Richtlinie (EU) 2019/790", "type": "eu_directive"},
|
||||
{"code": "PLD", "name": "Produkthaftungsrichtlinie", "fullName": "Richtlinie 85/374/EWG", "type": "eu_directive"},
|
||||
{"code": "GPSR", "name": "General Product Safety", "fullName": "Verordnung (EU) 2023/988", "type": "eu_regulation"},
|
||||
{"code": "BSI-TR-03161-1", "name": "BSI-TR Teil 1", "fullName": "BSI TR-03161 Teil 1 - Mobile Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-2", "name": "BSI-TR Teil 2", "fullName": "BSI TR-03161 Teil 2 - Web-Anwendungen", "type": "bsi_standard"},
|
||||
{"code": "BSI-TR-03161-3", "name": "BSI-TR Teil 3", "fullName": "BSI TR-03161 Teil 3 - Hintergrundsysteme", "type": "bsi_standard"},
|
||||
]
|
||||
|
||||
# Ingestion state (in-memory for now)
|
||||
ingestion_state = {
|
||||
"running": False,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
regulations: Optional[List[str]] = None
|
||||
top_k: int = 5
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
force: bool = False
|
||||
regulations: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AddLinkRequest(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
code: str
|
||||
document_type: str = "custom"
|
||||
|
||||
|
||||
# Store for custom documents (in-memory for now)
|
||||
custom_documents: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/admin/legal-corpus", tags=["legal-corpus"])
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_legal_corpus_status():
|
||||
"""Get status of the legal corpus collection including chunk counts per regulation."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
collection_res = await client.get(f"{QDRANT_URL}/collections/{COLLECTION_NAME}")
|
||||
if collection_res.status_code != 200:
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": 0,
|
||||
"vectorSize": 1024,
|
||||
"status": "not_found",
|
||||
"regulations": {},
|
||||
}
|
||||
|
||||
collection_data = collection_res.json()
|
||||
result = collection_data.get("result", {})
|
||||
|
||||
regulation_counts = {}
|
||||
for reg in REGULATIONS:
|
||||
count_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/count",
|
||||
json={
|
||||
"filter": {
|
||||
"must": [{"key": "regulation_code", "match": {"value": reg["code"]}}]
|
||||
}
|
||||
},
|
||||
)
|
||||
if count_res.status_code == 200:
|
||||
count_data = count_res.json()
|
||||
regulation_counts[reg["code"]] = count_data.get("result", {}).get("count", 0)
|
||||
else:
|
||||
regulation_counts[reg["code"]] = 0
|
||||
|
||||
return {
|
||||
"collection": COLLECTION_NAME,
|
||||
"totalPoints": result.get("points_count", 0),
|
||||
"vectorSize": result.get("config", {}).get("params", {}).get("vectors", {}).get("size", 1024),
|
||||
"status": result.get("status", "unknown"),
|
||||
"regulations": regulation_counts,
|
||||
}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Failed to get Qdrant status: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Qdrant not available: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_legal_corpus(
|
||||
query: str = Query(..., description="Search query"),
|
||||
top_k: int = Query(5, ge=1, le=20, description="Number of results"),
|
||||
regulations: Optional[str] = Query(None, description="Comma-separated regulation codes to filter"),
|
||||
):
|
||||
"""Semantic search in legal corpus using BGE-M3 embeddings."""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
embed_res = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": [query]},
|
||||
)
|
||||
if embed_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Embedding service error")
|
||||
|
||||
embed_data = embed_res.json()
|
||||
query_vector = embed_data["embeddings"][0]
|
||||
|
||||
search_request = {
|
||||
"vector": query_vector,
|
||||
"limit": top_k,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
if regulations:
|
||||
reg_codes = [r.strip() for r in regulations.split(",")]
|
||||
search_request["filter"] = {
|
||||
"should": [
|
||||
{"key": "regulation_code", "match": {"value": code}}
|
||||
for code in reg_codes
|
||||
]
|
||||
}
|
||||
|
||||
search_res = await client.post(
|
||||
f"{QDRANT_URL}/collections/{COLLECTION_NAME}/points/search",
|
||||
json=search_request,
|
||||
)
|
||||
|
||||
if search_res.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail="Search failed")
|
||||
|
||||
search_data = search_res.json()
|
||||
results = []
|
||||
for point in search_data.get("result", []):
|
||||
payload = point.get("payload", {})
|
||||
results.append({
|
||||
"text": payload.get("text", ""),
|
||||
"regulation_code": payload.get("regulation_code", ""),
|
||||
"regulation_name": payload.get("regulation_name", ""),
|
||||
"article": payload.get("article"),
|
||||
"paragraph": payload.get("paragraph"),
|
||||
"source_url": payload.get("source_url", ""),
|
||||
"score": point.get("score", 0),
|
||||
})
|
||||
|
||||
return {"results": results, "query": query, "count": len(results)}
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise HTTPException(status_code=503, detail=f"Service not available: {str(e)}")
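A minimal client sketch for this search route; the base URL is an assumption, while the query parameters and response fields match the handler above.

import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/admin/legal-corpus/search",  # assumed host/port
    params={"query": "Auftragsverarbeitung Schule", "top_k": 3, "regulations": "GDPR,TDDDG"},
    timeout=60.0,
)
for hit in resp.json()["results"]:
    # Each hit carries the chunk text plus regulation metadata and the Qdrant score.
    print(hit["regulation_code"], round(hit["score"], 3), hit["text"][:80])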
|
||||
|
||||
|
||||
@router.post("/ingest")
|
||||
async def trigger_ingestion(request: IngestRequest, background_tasks: BackgroundTasks):
|
||||
"""Trigger legal corpus ingestion in background."""
|
||||
global ingestion_state
|
||||
|
||||
if ingestion_state["running"]:
|
||||
raise HTTPException(status_code=409, detail="Ingestion already running")
|
||||
|
||||
ingestion_state = {
|
||||
"running": True,
|
||||
"completed": False,
|
||||
"current_regulation": None,
|
||||
"processed": 0,
|
||||
"total": len(REGULATIONS),
|
||||
"error": None,
|
||||
}
|
||||
|
||||
background_tasks.add_task(run_ingestion, request.force, request.regulations, ingestion_state, REGULATIONS)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"job_id": "manual-trigger",
|
||||
"message": f"Ingestion started for {len(REGULATIONS)} regulations",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/ingestion-status")
|
||||
async def get_ingestion_status():
|
||||
"""Get current ingestion status."""
|
||||
return ingestion_state
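The /ingest and /ingestion-status pair is designed as trigger-then-poll. A hedged sketch follows; the base URL is assumed and the regulation subset is chosen only for illustration (force is simply forwarded to ingest_single).

import time
import httpx

BASE = "http://localhost:8000/api/v1/admin/legal-corpus"  # assumed deployment URL

# Kick off ingestion for two regulations only.
httpx.post(f"{BASE}/ingest", json={"force": False, "regulations": ["GDPR", "AIACT"]})

# Poll the in-memory state dict until the background task clears the running flag.
while True:
    state = httpx.get(f"{BASE}/ingestion-status").json()
    print(state["processed"], "/", state["total"], state["current_regulation"])
    if not state["running"]:
        break
    time.sleep(5)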
|
||||
|
||||
|
||||
@router.get("/regulations")
|
||||
async def get_regulations():
|
||||
"""Get list of all supported regulations."""
|
||||
return {"regulations": REGULATIONS}
|
||||
|
||||
|
||||
@router.get("/custom-documents")
|
||||
async def get_custom_documents():
|
||||
"""Get list of custom documents added by user."""
|
||||
return {"documents": custom_documents}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_document(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(...),
|
||||
title: str = Form(...),
|
||||
code: str = Form(...),
|
||||
document_type: str = Form("custom"),
|
||||
):
|
||||
"""Upload a document (PDF) for ingestion into the legal corpus."""
|
||||
global custom_documents
|
||||
|
||||
if not file.filename.endswith(('.pdf', '.PDF')):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files are supported")
|
||||
|
||||
upload_dir = "/tmp/legal_corpus_uploads"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
safe_filename = f"{doc_id}_{file.filename}"
|
||||
file_path = os.path.join(upload_dir, safe_filename)
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
|
||||
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": code,
|
||||
"title": title,
|
||||
"filename": file.filename,
|
||||
"file_path": file_path,
|
||||
"document_type": document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "uploaded",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
background_tasks.add_task(ingest_uploaded_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "uploaded",
|
||||
"document_id": doc_id,
|
||||
"message": f"Document '{title}' uploaded and queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
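A possible client call for this upload route, as a sketch only: the multipart fields mirror the File/Form parameters above, while the URL, filename, and the NSCHG code are illustrative assumptions.

import httpx

BASE = "http://localhost:8000/api/v1/admin/legal-corpus"  # assumed deployment URL

with open("niedersaechsisches_schulgesetz.pdf", "rb") as fh:
    resp = httpx.post(
        f"{BASE}/upload",
        files={"file": ("niedersaechsisches_schulgesetz.pdf", fh, "application/pdf")},
        data={"title": "Niedersächsisches Schulgesetz", "code": "NSCHG", "document_type": "custom"},
    )
print(resp.json())  # includes document_id and status "uploaded"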
|
||||
|
||||
|
||||
|
||||
@router.post("/add-link")
|
||||
async def add_link(request: AddLinkRequest, background_tasks: BackgroundTasks):
|
||||
"""Add a URL/link for ingestion into the legal corpus."""
|
||||
global custom_documents
|
||||
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
doc_record = {
|
||||
"id": doc_id,
|
||||
"code": request.code,
|
||||
"title": request.title,
|
||||
"url": request.url,
|
||||
"document_type": request.document_type,
|
||||
"uploaded_at": datetime.now().isoformat(),
|
||||
"status": "queued",
|
||||
"chunk_count": 0,
|
||||
}
|
||||
|
||||
custom_documents.append(doc_record)
|
||||
background_tasks.add_task(ingest_link_document, doc_record)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
"document_id": doc_id,
|
||||
"message": f"Link '{request.title}' queued for ingestion",
|
||||
"document": doc_record,
|
||||
}
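A sketch of the add-link flow, including watching the record's status via /custom-documents. The base URL is assumed and the EUR-Lex link is only an example source.

import httpx

BASE = "http://localhost:8000/api/v1/admin/legal-corpus"  # assumed deployment URL

payload = {
    "url": "https://eur-lex.europa.eu/legal-content/DE/TXT/?uri=CELEX:32016R0679",
    "title": "DSGVO (EUR-Lex)",
    "code": "GDPR",
    "document_type": "custom",
}
doc_id = httpx.post(f"{BASE}/add-link", json=payload).json()["document_id"]

# The record moves through queued -> fetching -> processing -> indexed (or error).
docs = httpx.get(f"{BASE}/custom-documents").json()["documents"]
print(next(d["status"] for d in docs if d["id"] == doc_id))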
|
||||
|
||||
|
||||
|
||||
@router.delete("/custom-documents/{doc_id}")
|
||||
async def delete_custom_document(doc_id: str):
|
||||
"""Delete a custom document from the list."""
|
||||
global custom_documents
|
||||
|
||||
doc = next((d for d in custom_documents if d["id"] == doc_id), None)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
custom_documents = [d for d in custom_documents if d["id"] != doc_id]
|
||||
|
||||
return {"status": "deleted", "document_id": doc_id}
|
||||
|
||||
|
||||
@router.get("/traceability")
|
||||
async def get_traceability(
|
||||
chunk_id: str = Query(..., description="Chunk ID or identifier"),
|
||||
regulation: str = Query(..., description="Regulation code"),
|
||||
):
|
||||
"""Get traceability information for a specific chunk."""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
return {
|
||||
"chunk_id": chunk_id,
|
||||
"regulation": regulation,
|
||||
"requirements": [],
|
||||
"controls": [],
|
||||
"message": "Traceability-Daten werden verfuegbar sein, sobald die Requirements-Extraktion und Control-Ableitung implementiert sind."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get traceability: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Traceability lookup failed: {str(e)}")
|
||||
269
klausur-service/backend/mail/ai_category.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
AI Email - Category Classification and Response Suggestions
|
||||
|
||||
Rule-based and LLM-based email category classification,
|
||||
plus response suggestion generation.
|
||||
|
||||
Extracted from ai_service.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import (
|
||||
EmailCategory,
|
||||
SenderType,
|
||||
ResponseSuggestion,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
async def classify_category(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""
|
||||
Classify email into a category.
|
||||
|
||||
Rule-based classification first, falls back to LLM.
|
||||
"""
|
||||
category, confidence = _classify_category_rules(subject, body_preview, sender_type)
|
||||
|
||||
if confidence > 0.7:
|
||||
return category, confidence
|
||||
|
||||
return await _classify_category_llm(http_client, subject, body_preview)
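A small usage sketch for the classifier above, assuming the module is importable as mail.ai_category and the enums live in mail.models (both import paths are assumptions). Note the LLM fallback means the LLM gateway must be reachable whenever the rule score stays at or below 0.7.

import asyncio
import httpx
from mail.ai_category import classify_category
from mail.models import SenderType  # assumed location of the enums

async def main():
    async with httpx.AsyncClient(timeout=30.0) as client:
        category, confidence = await classify_category(
            client,
            subject="Einladung zur Schulkonferenz",
            body_preview="Sehr geehrte Eltern, hiermit laden wir Sie herzlich ein ...",
            sender_type=SenderType.ELTERNVERTRETER,
        )
        print(category, confidence)

asyncio.run(main())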
|
||||
|
||||
|
||||
def _classify_category_rules(
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""Rule-based category classification."""
|
||||
text = f"{subject} {body_preview}".lower()
|
||||
|
||||
category_keywords = {
|
||||
EmailCategory.DIENSTLICH: [
|
||||
"dienstlich", "dienstanweisung", "erlass", "verordnung",
|
||||
"bescheid", "verfuegung", "ministerium", "behoerde"
|
||||
],
|
||||
EmailCategory.PERSONAL: [
|
||||
"personalrat", "stellenausschreibung", "versetzung",
|
||||
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
|
||||
],
|
||||
EmailCategory.FINANZEN: [
|
||||
"budget", "haushalt", "etat", "abrechnung", "rechnung",
|
||||
"erstattung", "zuschuss", "foerdermittel"
|
||||
],
|
||||
EmailCategory.ELTERN: [
|
||||
"elternbrief", "elternabend", "schulkonferenz",
|
||||
"elternvertreter", "elternbeirat"
|
||||
],
|
||||
EmailCategory.SCHUELER: [
|
||||
"schueler", "schuelerin", "zeugnis", "klasse", "unterricht",
|
||||
"pruefung", "klassenfahrt", "schulpflicht"
|
||||
],
|
||||
EmailCategory.FORTBILDUNG: [
|
||||
"fortbildung", "seminar", "workshop", "schulung",
|
||||
"weiterbildung", "nlq", "didaktik"
|
||||
],
|
||||
EmailCategory.VERANSTALTUNG: [
|
||||
"einladung", "veranstaltung", "termin", "konferenz",
|
||||
"sitzung", "tagung", "feier"
|
||||
],
|
||||
EmailCategory.SICHERHEIT: [
|
||||
"sicherheit", "notfall", "brandschutz", "evakuierung",
|
||||
"hygiene", "corona", "infektionsschutz"
|
||||
],
|
||||
EmailCategory.TECHNIK: [
|
||||
"it", "software", "computer", "netzwerk", "login",
|
||||
"passwort", "digitalisierung", "iserv"
|
||||
],
|
||||
EmailCategory.NEWSLETTER: [
|
||||
"newsletter", "rundschreiben", "info-mail", "mitteilung"
|
||||
],
|
||||
EmailCategory.WERBUNG: [
|
||||
"angebot", "rabatt", "aktion", "werbung", "abonnement"
|
||||
],
|
||||
}
|
||||
|
||||
best_category = EmailCategory.SONSTIGES
|
||||
best_score = 0.0
|
||||
|
||||
for category, keywords in category_keywords.items():
|
||||
score = sum(1 for kw in keywords if kw in text)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_category = category
|
||||
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
if best_category == EmailCategory.SONSTIGES:
|
||||
best_category = EmailCategory.DIENSTLICH
|
||||
best_score = 2
|
||||
|
||||
confidence = min(0.9, 0.4 + (best_score * 0.15))
|
||||
|
||||
return best_category, confidence
|
||||
|
||||
|
||||
async def _classify_category_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""LLM-based category classification."""
|
||||
try:
|
||||
categories = ", ".join([c.value for c in EmailCategory])
|
||||
|
||||
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview[:500]}
|
||||
|
||||
Kategorien: {categories}
|
||||
|
||||
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
|
||||
Format: kategorie|konfidenz
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = data.get("response", "sonstiges|0.5")
|
||||
parts = result.strip().split("|")
|
||||
|
||||
if len(parts) >= 2:
|
||||
category_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
try:
|
||||
category = EmailCategory(category_str)
|
||||
return category, min(max(confidence, 0.0), 1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM category classification failed: {e}")
|
||||
|
||||
return EmailCategory.SONSTIGES, 0.5
|
||||
|
||||
|
||||
async def suggest_response(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
sender_type: SenderType,
|
||||
category: EmailCategory,
|
||||
) -> List[ResponseSuggestion]:
|
||||
"""Generate response suggestions for an email."""
|
||||
suggestions = []
|
||||
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="acknowledgment",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Sehr geehrte Damen und Herren,
|
||||
|
||||
vielen Dank fuer Ihre Nachricht.
|
||||
|
||||
Ich bestaetige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
|
||||
|
||||
Mit freundlichen Gruessen""",
|
||||
confidence=0.8,
|
||||
))
|
||||
|
||||
if category == EmailCategory.ELTERN:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="parent_response",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Liebe Eltern,
|
||||
|
||||
vielen Dank fuer Ihre Nachricht.
|
||||
|
||||
[Ihre Antwort hier]
|
||||
|
||||
Mit freundlichen Gruessen""",
|
||||
confidence=0.7,
|
||||
))
|
||||
|
||||
try:
|
||||
llm_suggestion = await _generate_response_llm(http_client, subject, body_text[:500], sender_type)
|
||||
if llm_suggestion:
|
||||
suggestions.append(llm_suggestion)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
async def _generate_response_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Optional[ResponseSuggestion]:
|
||||
"""Generate a response suggestion using LLM."""
|
||||
try:
|
||||
sender_desc = {
|
||||
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
|
||||
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehoerde",
|
||||
SenderType.RLSB: "dem RLSB",
|
||||
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
|
||||
}.get(sender_type, "einem Absender")
|
||||
|
||||
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Die Antwort sollte:
|
||||
- Hoeflich und formell sein
|
||||
- Den Eingang bestaetigen
|
||||
- Eine konkrete naechste Aktion nennen oder um Klaerung bitten
|
||||
|
||||
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 300,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
body = data.get("response", "").strip()
|
||||
|
||||
if body:
|
||||
return ResponseSuggestion(
|
||||
template_type="ai_generated",
|
||||
subject=f"Re: {subject}",
|
||||
body=body,
|
||||
confidence=0.6,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return None
|
||||
184
klausur-service/backend/mail/ai_deadline.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
AI Email - Deadline Extraction
|
||||
|
||||
Regex-based and LLM-based deadline extraction from email content.
|
||||
|
||||
Extracted from ai_service.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import DeadlineExtraction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
async def extract_deadlines(
|
||||
http_client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""
|
||||
Extract deadlines from email content.
|
||||
|
||||
Uses regex patterns first, then LLM for complex cases.
|
||||
"""
|
||||
deadlines = []
|
||||
|
||||
full_text = f"{subject}\n{body_text}" if body_text else subject
|
||||
|
||||
# Try regex extraction first
|
||||
regex_deadlines = _extract_deadlines_regex(full_text)
|
||||
deadlines.extend(regex_deadlines)
|
||||
|
||||
# If no regex matches, try LLM
|
||||
if not deadlines and body_text:
|
||||
llm_deadlines = await _extract_deadlines_llm(http_client, subject, body_text[:1000])
|
||||
deadlines.extend(llm_deadlines)
|
||||
|
||||
return deadlines
|
||||
|
||||
|
||||
def _extract_deadlines_regex(text: str) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using regex patterns."""
|
||||
deadlines = []
|
||||
now = datetime.now()
|
||||
|
||||
# German date patterns
|
||||
patterns = [
|
||||
# "bis zum 15.01.2025"
|
||||
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "spaetestens am 15.01.2025"
|
||||
(r"sp\u00e4testens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "Abgabetermin: 15.01.2025"
|
||||
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "innerhalb von 14 Tagen"
|
||||
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
|
||||
# "bis Ende Januar"
|
||||
(r"bis\s+(?:Ende\s+)?(Januar|Februar|M\u00e4rz|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
|
||||
]
|
||||
|
||||
for pattern, is_specific_date in patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
if is_specific_date:
|
||||
day = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
year = int(match.group(3))
|
||||
|
||||
if year < 100:
|
||||
year += 2000
|
||||
|
||||
deadline_date = datetime(year, month, day)
|
||||
|
||||
if deadline_date < now:
|
||||
continue
|
||||
|
||||
start = max(0, match.start() - 50)
|
||||
end = min(len(text), match.end() + 50)
|
||||
context = text[start:end].strip()
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Frist: {match.group(0)}",
|
||||
confidence=0.85,
|
||||
source_text=context,
|
||||
is_firm=True,
|
||||
))
|
||||
|
||||
else:
|
||||
if "Tagen" in pattern or "Wochen" in pattern:
|
||||
days = int(match.group(1))
|
||||
if "Wochen" in match.group(0).lower():
|
||||
days *= 7
|
||||
deadline_date = now + timedelta(days=days)
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Relative Frist: {match.group(0)}",
|
||||
confidence=0.7,
|
||||
source_text=match.group(0),
|
||||
is_firm=False,
|
||||
))
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.debug(f"Failed to parse date: {e}")
|
||||
continue
|
||||
|
||||
return deadlines
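For a quick feel of the regex path, a sketch that exercises only _extract_deadlines_regex; the import path is assumed to be mail.ai_deadline, and no LLM gateway is needed because an explicit German date is matched directly.

from mail.ai_deadline import _extract_deadlines_regex  # import path assumed (backend/ root)

text = "Bitte reichen Sie die Unterlagen bis zum 15.01.2027 ein."

# "bis zum DD.MM.YYYY" matches the first pattern; past dates would be skipped.
for d in _extract_deadlines_regex(text):
    print(d.deadline_date.date(), d.description, d.is_firm)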
|
||||
|
||||
|
||||
async def _extract_deadlines_llm(
|
||||
client: httpx.AsyncClient,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using LLM."""
|
||||
try:
|
||||
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
|
||||
DATUM|BESCHREIBUNG|VERBINDLICH
|
||||
Beispiel: 2025-01-15|Abgabe der Berichte|ja
|
||||
|
||||
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
|
||||
|
||||
Antworte NUR im angegebenen Format.
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 200,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "")
|
||||
|
||||
if "KEINE_FRISTEN" in result_text:
|
||||
return []
|
||||
|
||||
deadlines = []
|
||||
for line in result_text.strip().split("\n"):
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
date_str = parts[0].strip()
|
||||
deadline_date = datetime.fromisoformat(date_str)
|
||||
description = parts[1].strip()
|
||||
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=description,
|
||||
confidence=0.7,
|
||||
source_text=line,
|
||||
is_firm=is_firm,
|
||||
))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM deadline extraction failed: {e}")
|
||||
|
||||
return []
|
||||
134
klausur-service/backend/mail/ai_sender.py
Normal file
@@ -0,0 +1,134 @@
"""
AI Email - Sender Classification

Domain-based and LLM-based sender classification for emails.

Extracted from ai_service.py to keep files under 500 LOC.
"""

import os
import logging
from typing import Optional

import httpx

from .models import (
    SenderType,
    SenderClassification,
    classify_sender_by_domain,
)

logger = logging.getLogger(__name__)

# LLM Gateway configuration
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")


async def classify_sender(
    http_client: httpx.AsyncClient,
    sender_email: str,
    sender_name: Optional[str] = None,
    subject: Optional[str] = None,
    body_preview: Optional[str] = None,
) -> SenderClassification:
    """
    Classify the sender of an email.

    First tries domain matching, then falls back to LLM.
    """
    # Try domain-based classification first (fast, high confidence)
    domain_result = classify_sender_by_domain(sender_email)
    if domain_result:
        return domain_result

    # Fall back to LLM classification
    return await _classify_sender_llm(
        http_client, sender_email, sender_name, subject, body_preview
    )


async def _classify_sender_llm(
    client: httpx.AsyncClient,
    sender_email: str,
    sender_name: Optional[str],
    subject: Optional[str],
    body_preview: Optional[str],
) -> SenderClassification:
    """Classify sender using LLM."""
    try:
        prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:

Absender E-Mail: {sender_email}
Absender Name: {sender_name or "Nicht angegeben"}
Betreff: {subject or "Nicht angegeben"}
Vorschau: {body_preview[:200] if body_preview else "Nicht verfuegbar"}

Klassifiziere den Absender in EINE der folgenden Kategorien:
- kultusministerium: Kultusministerium/Bildungsministerium
- landesschulbehoerde: Landesschulbehoerde
- rlsb: Regionales Landesamt fuer Schule und Bildung
- schulamt: Schulamt
- nibis: Niedersaechsischer Bildungsserver
- schultraeger: Schultraeger/Kommune
- elternvertreter: Elternvertreter/Elternrat
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
- privatperson: Privatperson
- unternehmen: Unternehmen/Firma
- unbekannt: Nicht einzuordnen

Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
Format: kategorie|konfidenz|kurze_begruendung
"""

        response = await client.post(
            f"{LLM_GATEWAY_URL}/api/v1/inference",
            json={
                "prompt": prompt,
                "playbook": "mail_analysis",
                "max_tokens": 100,
            },
        )

        if response.status_code == 200:
            data = response.json()
            result_text = data.get("response", "unbekannt|0.5|")

            parts = result_text.strip().split("|")
            if len(parts) >= 2:
                sender_type_str = parts[0].strip().lower()
                confidence = float(parts[1].strip())

                type_mapping = {
                    "kultusministerium": SenderType.KULTUSMINISTERIUM,
                    "landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
                    "rlsb": SenderType.RLSB,
                    "schulamt": SenderType.SCHULAMT,
                    "nibis": SenderType.NIBIS,
                    "schultraeger": SenderType.SCHULTRAEGER,
                    "elternvertreter": SenderType.ELTERNVERTRETER,
                    "gewerkschaft": SenderType.GEWERKSCHAFT,
                    "fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
                    "privatperson": SenderType.PRIVATPERSON,
                    "unternehmen": SenderType.UNTERNEHMEN,
                }

                sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)

                return SenderClassification(
                    sender_type=sender_type,
                    confidence=min(max(confidence, 0.0), 1.0),
                    domain_matched=False,
                    ai_classified=True,
                )

    except Exception as e:
        logger.warning(f"LLM sender classification failed: {e}")

    # Default fallback
    return SenderClassification(
        sender_type=SenderType.UNBEKANNT,
        confidence=0.3,
        domain_matched=False,
        ai_classified=False,
    )
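A minimal usage sketch for the new module; the import path, event loop, and e-mail address are illustrative assumptions:

import asyncio
import httpx

from mail.ai_sender import classify_sender  # import path assumed from this diff

async def demo() -> None:
    async with httpx.AsyncClient(timeout=30.0) as client:
        result = await classify_sender(
            client,
            sender_email="poststelle@mk.niedersachsen.de",  # illustrative address
            subject="Erlass zur Zeugnisausgabe",
        )
        # A domain match short-circuits the LLM call; otherwise the gateway is queried.
        print(result.sender_type, result.confidence)

asyncio.run(demo())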
@@ -1,18 +1,19 @@
|
||||
"""
|
||||
AI Email Analysis Service
|
||||
AI Email Analysis Service — Barrel Re-export
|
||||
|
||||
KI-powered email analysis with:
|
||||
- Sender classification (authority recognition)
|
||||
- Deadline extraction
|
||||
- Category classification
|
||||
- Response suggestions
|
||||
Split into:
|
||||
- mail/ai_sender.py — Sender classification (domain + LLM)
|
||||
- mail/ai_deadline.py — Deadline extraction (regex + LLM)
|
||||
- mail/ai_category.py — Category classification + response suggestions
|
||||
|
||||
The AIEmailService class and get_ai_email_service() are defined here
|
||||
to maintain the original public API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .models import (
|
||||
@@ -23,17 +24,15 @@ from .models import (
|
||||
DeadlineExtraction,
|
||||
EmailAnalysisResult,
|
||||
ResponseSuggestion,
|
||||
KNOWN_AUTHORITIES_NI,
|
||||
classify_sender_by_domain,
|
||||
get_priority_from_sender_type,
|
||||
)
|
||||
from .mail_db import update_email_ai_analysis
|
||||
from .ai_sender import classify_sender, LLM_GATEWAY_URL
|
||||
from .ai_deadline import extract_deadlines
|
||||
from .ai_category import classify_category, suggest_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# LLM Gateway configuration
|
||||
LLM_GATEWAY_URL = os.getenv("LLM_GATEWAY_URL", "http://localhost:8090")
|
||||
|
||||
|
||||
class AIEmailService:
|
||||
"""
|
||||
@@ -56,10 +55,6 @@ class AIEmailService:
|
||||
self._http_client = httpx.AsyncClient(timeout=30.0)
|
||||
return self._http_client
|
||||
|
||||
# =========================================================================
|
||||
# Sender Classification
|
||||
# =========================================================================
|
||||
|
||||
async def classify_sender(
|
||||
self,
|
||||
sender_email: str,
|
||||
@@ -67,300 +62,20 @@ class AIEmailService:
|
||||
subject: Optional[str] = None,
|
||||
body_preview: Optional[str] = None,
|
||||
) -> SenderClassification:
|
||||
"""
|
||||
Classify the sender of an email.
|
||||
|
||||
First tries domain matching, then falls back to LLM.
|
||||
|
||||
Args:
|
||||
sender_email: Sender's email address
|
||||
sender_name: Sender's display name
|
||||
subject: Email subject
|
||||
body_preview: First 200 chars of body
|
||||
|
||||
Returns:
|
||||
SenderClassification with type and confidence
|
||||
"""
|
||||
# Try domain-based classification first (fast, high confidence)
|
||||
domain_result = classify_sender_by_domain(sender_email)
|
||||
if domain_result:
|
||||
return domain_result
|
||||
|
||||
# Fall back to LLM classification
|
||||
return await self._classify_sender_llm(
|
||||
sender_email, sender_name, subject, body_preview
|
||||
"""Classify the sender of an email."""
|
||||
client = await self.get_http_client()
|
||||
return await classify_sender(
|
||||
client, sender_email, sender_name, subject, body_preview
|
||||
)
|
||||
|
||||
async def _classify_sender_llm(
|
||||
self,
|
||||
sender_email: str,
|
||||
sender_name: Optional[str],
|
||||
subject: Optional[str],
|
||||
body_preview: Optional[str],
|
||||
) -> SenderClassification:
|
||||
"""Classify sender using LLM."""
|
||||
try:
|
||||
client = await self.get_http_client()
|
||||
|
||||
prompt = f"""Analysiere den Absender dieser E-Mail und klassifiziere ihn:
|
||||
|
||||
Absender E-Mail: {sender_email}
|
||||
Absender Name: {sender_name or "Nicht angegeben"}
|
||||
Betreff: {subject or "Nicht angegeben"}
|
||||
Vorschau: {body_preview[:200] if body_preview else "Nicht verfügbar"}
|
||||
|
||||
Klassifiziere den Absender in EINE der folgenden Kategorien:
|
||||
- kultusministerium: Kultusministerium/Bildungsministerium
|
||||
- landesschulbehoerde: Landesschulbehörde
|
||||
- rlsb: Regionales Landesamt für Schule und Bildung
|
||||
- schulamt: Schulamt
|
||||
- nibis: Niedersächsischer Bildungsserver
|
||||
- schultraeger: Schulträger/Kommune
|
||||
- elternvertreter: Elternvertreter/Elternrat
|
||||
- gewerkschaft: Gewerkschaft (GEW, VBE, etc.)
|
||||
- fortbildungsinstitut: Fortbildungsinstitut (NLQ, etc.)
|
||||
- privatperson: Privatperson
|
||||
- unternehmen: Unternehmen/Firma
|
||||
- unbekannt: Nicht einzuordnen
|
||||
|
||||
Antworte NUR mit dem Kategorienamen (z.B. "kultusministerium") und einer Konfidenz von 0.0 bis 1.0.
|
||||
Format: kategorie|konfidenz|kurze_begründung
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 100,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "unbekannt|0.5|")
|
||||
|
||||
# Parse response
|
||||
parts = result_text.strip().split("|")
|
||||
if len(parts) >= 2:
|
||||
sender_type_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
# Map to enum
|
||||
type_mapping = {
|
||||
"kultusministerium": SenderType.KULTUSMINISTERIUM,
|
||||
"landesschulbehoerde": SenderType.LANDESSCHULBEHOERDE,
|
||||
"rlsb": SenderType.RLSB,
|
||||
"schulamt": SenderType.SCHULAMT,
|
||||
"nibis": SenderType.NIBIS,
|
||||
"schultraeger": SenderType.SCHULTRAEGER,
|
||||
"elternvertreter": SenderType.ELTERNVERTRETER,
|
||||
"gewerkschaft": SenderType.GEWERKSCHAFT,
|
||||
"fortbildungsinstitut": SenderType.FORTBILDUNGSINSTITUT,
|
||||
"privatperson": SenderType.PRIVATPERSON,
|
||||
"unternehmen": SenderType.UNTERNEHMEN,
|
||||
}
|
||||
|
||||
sender_type = type_mapping.get(sender_type_str, SenderType.UNBEKANNT)
|
||||
|
||||
return SenderClassification(
|
||||
sender_type=sender_type,
|
||||
confidence=min(max(confidence, 0.0), 1.0),
|
||||
domain_matched=False,
|
||||
ai_classified=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM sender classification failed: {e}")
|
||||
|
||||
# Default fallback
|
||||
return SenderClassification(
|
||||
sender_type=SenderType.UNBEKANNT,
|
||||
confidence=0.3,
|
||||
domain_matched=False,
|
||||
ai_classified=False,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Deadline Extraction
|
||||
# =========================================================================
|
||||
|
||||
async def extract_deadlines(
|
||||
self,
|
||||
subject: str,
|
||||
body_text: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""
|
||||
Extract deadlines from email content.
|
||||
|
||||
Uses regex patterns first, then LLM for complex cases.
|
||||
|
||||
Args:
|
||||
subject: Email subject
|
||||
body_text: Email body text
|
||||
|
||||
Returns:
|
||||
List of extracted deadlines
|
||||
"""
|
||||
deadlines = []
|
||||
|
||||
# Combine subject and body
|
||||
full_text = f"{subject}\n{body_text}" if body_text else subject
|
||||
|
||||
# Try regex extraction first
|
||||
regex_deadlines = self._extract_deadlines_regex(full_text)
|
||||
deadlines.extend(regex_deadlines)
|
||||
|
||||
# If no regex matches, try LLM
|
||||
if not deadlines and body_text:
|
||||
llm_deadlines = await self._extract_deadlines_llm(subject, body_text[:1000])
|
||||
deadlines.extend(llm_deadlines)
|
||||
|
||||
return deadlines
|
||||
|
||||
def _extract_deadlines_regex(self, text: str) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using regex patterns."""
|
||||
deadlines = []
|
||||
now = datetime.now()
|
||||
|
||||
# German date patterns
|
||||
patterns = [
|
||||
# "bis zum 15.01.2025"
|
||||
(r"bis\s+(?:zum\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "spätestens am 15.01.2025"
|
||||
(r"spätestens\s+(?:am\s+)?(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "Abgabetermin: 15.01.2025"
|
||||
(r"(?:Abgabe|Termin|Frist)[:\s]+(\d{1,2})\.(\d{1,2})\.(\d{2,4})", True),
|
||||
# "innerhalb von 14 Tagen"
|
||||
(r"innerhalb\s+von\s+(\d+)\s+(?:Tagen|Wochen)", False),
|
||||
# "bis Ende Januar"
|
||||
(r"bis\s+(?:Ende\s+)?(Januar|Februar|März|April|Mai|Juni|Juli|August|September|Oktober|November|Dezember)", False),
|
||||
]
|
||||
|
||||
for pattern, is_specific_date in patterns:
|
||||
matches = re.finditer(pattern, text, re.IGNORECASE)
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
if is_specific_date:
|
||||
day = int(match.group(1))
|
||||
month = int(match.group(2))
|
||||
year = int(match.group(3))
|
||||
|
||||
# Handle 2-digit years
|
||||
if year < 100:
|
||||
year += 2000
|
||||
|
||||
deadline_date = datetime(year, month, day)
|
||||
|
||||
# Skip past dates
|
||||
if deadline_date < now:
|
||||
continue
|
||||
|
||||
# Get surrounding context
|
||||
start = max(0, match.start() - 50)
|
||||
end = min(len(text), match.end() + 50)
|
||||
context = text[start:end].strip()
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Frist: {match.group(0)}",
|
||||
confidence=0.85,
|
||||
source_text=context,
|
||||
is_firm=True,
|
||||
))
|
||||
|
||||
else:
|
||||
# Relative dates (innerhalb von X Tagen)
|
||||
if "Tagen" in pattern or "Wochen" in pattern:
|
||||
days = int(match.group(1))
|
||||
if "Wochen" in match.group(0).lower():
|
||||
days *= 7
|
||||
deadline_date = now + timedelta(days=days)
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=f"Relative Frist: {match.group(0)}",
|
||||
confidence=0.7,
|
||||
source_text=match.group(0),
|
||||
is_firm=False,
|
||||
))
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.debug(f"Failed to parse date: {e}")
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
async def _extract_deadlines_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> List[DeadlineExtraction]:
|
||||
"""Extract deadlines using LLM."""
|
||||
try:
|
||||
client = await self.get_http_client()
|
||||
|
||||
prompt = f"""Analysiere diese E-Mail und extrahiere alle genannten Fristen und Termine:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Liste alle Fristen im folgenden Format auf (eine pro Zeile):
|
||||
DATUM|BESCHREIBUNG|VERBINDLICH
|
||||
Beispiel: 2025-01-15|Abgabe der Berichte|ja
|
||||
|
||||
Wenn keine Fristen gefunden werden, antworte mit: KEINE_FRISTEN
|
||||
|
||||
Antworte NUR im angegebenen Format.
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 200,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result_text = data.get("response", "")
|
||||
|
||||
if "KEINE_FRISTEN" in result_text:
|
||||
return []
|
||||
|
||||
deadlines = []
|
||||
for line in result_text.strip().split("\n"):
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
date_str = parts[0].strip()
|
||||
deadline_date = datetime.fromisoformat(date_str)
|
||||
description = parts[1].strip()
|
||||
is_firm = parts[2].strip().lower() == "ja" if len(parts) > 2 else True
|
||||
|
||||
deadlines.append(DeadlineExtraction(
|
||||
deadline_date=deadline_date,
|
||||
description=description,
|
||||
confidence=0.7,
|
||||
source_text=line,
|
||||
is_firm=is_firm,
|
||||
))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return deadlines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM deadline extraction failed: {e}")
|
||||
|
||||
return []
|
||||
|
||||
# =========================================================================
|
||||
# Email Category Classification
|
||||
# =========================================================================
|
||||
"""Extract deadlines from email content."""
|
||||
client = await self.get_http_client()
|
||||
return await extract_deadlines(client, subject, body_text)
|
||||
|
||||
async def classify_category(
|
||||
self,
|
||||
@@ -368,155 +83,9 @@ Antworte NUR im angegebenen Format.
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""
|
||||
Classify email into a category.
|
||||
|
||||
Args:
|
||||
subject: Email subject
|
||||
body_preview: First 200 chars of body
|
||||
sender_type: Already classified sender type
|
||||
|
||||
Returns:
|
||||
Tuple of (category, confidence)
|
||||
"""
|
||||
# Rule-based classification first
|
||||
category, confidence = self._classify_category_rules(subject, body_preview, sender_type)
|
||||
|
||||
if confidence > 0.7:
|
||||
return category, confidence
|
||||
|
||||
# Fall back to LLM
|
||||
return await self._classify_category_llm(subject, body_preview)
|
||||
|
||||
def _classify_category_rules(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""Rule-based category classification."""
|
||||
text = f"{subject} {body_preview}".lower()
|
||||
|
||||
# Keywords for each category
|
||||
category_keywords = {
|
||||
EmailCategory.DIENSTLICH: [
|
||||
"dienstlich", "dienstanweisung", "erlass", "verordnung",
|
||||
"bescheid", "verfügung", "ministerium", "behörde"
|
||||
],
|
||||
EmailCategory.PERSONAL: [
|
||||
"personalrat", "stellenausschreibung", "versetzung",
|
||||
"beurteilung", "dienstzeugnis", "krankmeldung", "elternzeit"
|
||||
],
|
||||
EmailCategory.FINANZEN: [
|
||||
"budget", "haushalt", "etat", "abrechnung", "rechnung",
|
||||
"erstattung", "zuschuss", "fördermittel"
|
||||
],
|
||||
EmailCategory.ELTERN: [
|
||||
"elternbrief", "elternabend", "schulkonferenz",
|
||||
"elternvertreter", "elternbeirat"
|
||||
],
|
||||
EmailCategory.SCHUELER: [
|
||||
"schüler", "schülerin", "zeugnis", "klasse", "unterricht",
|
||||
"prüfung", "klassenfahrt", "schulpflicht"
|
||||
],
|
||||
EmailCategory.FORTBILDUNG: [
|
||||
"fortbildung", "seminar", "workshop", "schulung",
|
||||
"weiterbildung", "nlq", "didaktik"
|
||||
],
|
||||
EmailCategory.VERANSTALTUNG: [
|
||||
"einladung", "veranstaltung", "termin", "konferenz",
|
||||
"sitzung", "tagung", "feier"
|
||||
],
|
||||
EmailCategory.SICHERHEIT: [
|
||||
"sicherheit", "notfall", "brandschutz", "evakuierung",
|
||||
"hygiene", "corona", "infektionsschutz"
|
||||
],
|
||||
EmailCategory.TECHNIK: [
|
||||
"it", "software", "computer", "netzwerk", "login",
|
||||
"passwort", "digitalisierung", "iserv"
|
||||
],
|
||||
EmailCategory.NEWSLETTER: [
|
||||
"newsletter", "rundschreiben", "info-mail", "mitteilung"
|
||||
],
|
||||
EmailCategory.WERBUNG: [
|
||||
"angebot", "rabatt", "aktion", "werbung", "abonnement"
|
||||
],
|
||||
}
|
||||
|
||||
best_category = EmailCategory.SONSTIGES
|
||||
best_score = 0.0
|
||||
|
||||
for category, keywords in category_keywords.items():
|
||||
score = sum(1 for kw in keywords if kw in text)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_category = category
|
||||
|
||||
# Adjust based on sender type
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
if best_category == EmailCategory.SONSTIGES:
|
||||
best_category = EmailCategory.DIENSTLICH
|
||||
best_score = 2
|
||||
|
||||
# Convert score to confidence
|
||||
confidence = min(0.9, 0.4 + (best_score * 0.15))
|
||||
|
||||
return best_category, confidence
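To make the keyword-score-to-confidence mapping above concrete, a toy check (not repo code):

# Each keyword hit adds 0.15 on top of the 0.4 base, capped at 0.9:
# three hits give min(0.9, 0.4 + 3 * 0.15) == 0.85, five hits saturate at 0.9.
for hits in (0, 1, 3, 5):
    print(hits, min(0.9, 0.4 + hits * 0.15))   # -> 0.4, ~0.55, ~0.85, 0.9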
|
||||
|
||||
async def _classify_category_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
) -> Tuple[EmailCategory, float]:
|
||||
"""LLM-based category classification."""
|
||||
try:
|
||||
client = await self.get_http_client()
|
||||
|
||||
categories = ", ".join([c.value for c in EmailCategory])
|
||||
|
||||
prompt = f"""Klassifiziere diese E-Mail in EINE Kategorie:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview[:500]}
|
||||
|
||||
Kategorien: {categories}
|
||||
|
||||
Antworte NUR mit dem Kategorienamen und einer Konfidenz (0.0-1.0):
|
||||
Format: kategorie|konfidenz
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = data.get("response", "sonstiges|0.5")
|
||||
parts = result.strip().split("|")
|
||||
|
||||
if len(parts) >= 2:
|
||||
category_str = parts[0].strip().lower()
|
||||
confidence = float(parts[1].strip())
|
||||
|
||||
try:
|
||||
category = EmailCategory(category_str)
|
||||
return category, min(max(confidence, 0.0), 1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM category classification failed: {e}")
|
||||
|
||||
return EmailCategory.SONSTIGES, 0.5
|
||||
|
||||
# =========================================================================
|
||||
# Full Analysis Pipeline
|
||||
# =========================================================================
|
||||
"""Classify email into a category."""
|
||||
client = await self.get_http_client()
|
||||
return await classify_category(client, subject, body_preview, sender_type)
|
||||
|
||||
async def analyze_email(
|
||||
self,
|
||||
@@ -527,20 +96,7 @@ Format: kategorie|konfidenz
|
||||
body_text: Optional[str],
|
||||
body_preview: Optional[str],
|
||||
) -> EmailAnalysisResult:
|
||||
"""
|
||||
Run full analysis pipeline on an email.
|
||||
|
||||
Args:
|
||||
email_id: Database ID of the email
|
||||
sender_email: Sender's email address
|
||||
sender_name: Sender's display name
|
||||
subject: Email subject
|
||||
body_text: Full body text
|
||||
body_preview: Preview text
|
||||
|
||||
Returns:
|
||||
Complete analysis result
|
||||
"""
|
||||
"""Run full analysis pipeline on an email."""
|
||||
# 1. Classify sender
|
||||
sender_classification = await self.classify_sender(
|
||||
sender_email, sender_name, subject, body_preview
|
||||
@@ -569,8 +125,8 @@ Format: kategorie|konfidenz
|
||||
elif days_until <= 7:
|
||||
suggested_priority = max(suggested_priority, TaskPriority.MEDIUM)
|
||||
|
||||
# 5. Generate summary (optional, can be expensive)
|
||||
summary = None # Could add LLM summary generation here
|
||||
# 5. Summary (optional)
|
||||
summary = None
|
||||
|
||||
# 6. Determine if task should be auto-created
|
||||
auto_create_task = (
|
||||
@@ -612,10 +168,6 @@ Format: kategorie|konfidenz
|
||||
auto_create_task=auto_create_task,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Response Suggestions
|
||||
# =========================================================================
|
||||
|
||||
async def suggest_response(
|
||||
self,
|
||||
subject: str,
|
||||
@@ -623,114 +175,11 @@ Format: kategorie|konfidenz
|
||||
sender_type: SenderType,
|
||||
category: EmailCategory,
|
||||
) -> List[ResponseSuggestion]:
|
||||
"""
|
||||
Generate response suggestions for an email.
|
||||
|
||||
Args:
|
||||
subject: Original email subject
|
||||
body_text: Original email body
|
||||
sender_type: Classified sender type
|
||||
category: Classified category
|
||||
|
||||
Returns:
|
||||
List of response suggestions
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
# Add standard templates based on sender type and category
|
||||
if sender_type in [SenderType.KULTUSMINISTERIUM, SenderType.LANDESSCHULBEHOERDE, SenderType.RLSB]:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="acknowledgment",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Sehr geehrte Damen und Herren,
|
||||
|
||||
vielen Dank für Ihre Nachricht.
|
||||
|
||||
Ich bestätige den Eingang und werde die Angelegenheit fristgerecht bearbeiten.
|
||||
|
||||
Mit freundlichen Grüßen""",
|
||||
confidence=0.8,
|
||||
))
|
||||
|
||||
if category == EmailCategory.ELTERN:
|
||||
suggestions.append(ResponseSuggestion(
|
||||
template_type="parent_response",
|
||||
subject=f"Re: {subject}",
|
||||
body="""Liebe Eltern,
|
||||
|
||||
vielen Dank für Ihre Nachricht.
|
||||
|
||||
[Ihre Antwort hier]
|
||||
|
||||
Mit freundlichen Grüßen""",
|
||||
confidence=0.7,
|
||||
))
|
||||
|
||||
# Add LLM-generated suggestion
|
||||
try:
|
||||
llm_suggestion = await self._generate_response_llm(subject, body_text[:500], sender_type)
|
||||
if llm_suggestion:
|
||||
suggestions.append(llm_suggestion)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return suggestions
|
||||
|
||||
async def _generate_response_llm(
|
||||
self,
|
||||
subject: str,
|
||||
body_preview: str,
|
||||
sender_type: SenderType,
|
||||
) -> Optional[ResponseSuggestion]:
|
||||
"""Generate a response suggestion using LLM."""
|
||||
try:
|
||||
client = await self.get_http_client()
|
||||
|
||||
sender_desc = {
|
||||
SenderType.KULTUSMINISTERIUM: "dem Kultusministerium",
|
||||
SenderType.LANDESSCHULBEHOERDE: "der Landesschulbehörde",
|
||||
SenderType.RLSB: "dem RLSB",
|
||||
SenderType.ELTERNVERTRETER: "einem Elternvertreter",
|
||||
}.get(sender_type, "einem Absender")
|
||||
|
||||
prompt = f"""Du bist eine Schulleiterin in Niedersachsen. Formuliere eine professionelle, kurze Antwort auf diese E-Mail von {sender_desc}:
|
||||
|
||||
Betreff: {subject}
|
||||
Inhalt: {body_preview}
|
||||
|
||||
Die Antwort sollte:
|
||||
- Höflich und formell sein
|
||||
- Den Eingang bestätigen
|
||||
- Eine konkrete nächste Aktion nennen oder um Klärung bitten
|
||||
|
||||
Antworte NUR mit dem Antworttext (ohne Betreffzeile, ohne "Betreff:").
|
||||
"""
|
||||
|
||||
response = await client.post(
|
||||
f"{LLM_GATEWAY_URL}/api/v1/inference",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"playbook": "mail_analysis",
|
||||
"max_tokens": 300,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
body = data.get("response", "").strip()
|
||||
|
||||
if body:
|
||||
return ResponseSuggestion(
|
||||
template_type="ai_generated",
|
||||
subject=f"Re: {subject}",
|
||||
body=body,
|
||||
confidence=0.6,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM response generation failed: {e}")
|
||||
|
||||
return None
|
||||
"""Generate response suggestions for an email."""
|
||||
client = await self.get_http_client()
|
||||
return await suggest_response(
|
||||
client, subject, body_text, sender_type, category
|
||||
)
|
||||
|
||||
|
||||
# Global instance
|
||||
|
||||
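A hedged sketch of the preserved facade in use; get_ai_email_service() and the keyword argument names follow the docstrings above, while the module path and concrete values are assumptions:

import asyncio

from mail.ai_service import get_ai_email_service  # import path assumed from this diff

async def run_analysis() -> None:
    service = get_ai_email_service()
    result = await service.analyze_email(
        email_id=42,                                          # illustrative values
        sender_email="schulamt@example.org",
        sender_name="Schulamt Musterstadt",
        subject="Rueckmeldung bis 15.01.2025 erbeten",
        body_text="Bitte senden Sie die Unterlagen bis zum 15.01.2025.",
        body_preview="Bitte senden Sie die Unterlagen...",
    )
    print(result.auto_create_task)

asyncio.run(run_analysis())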
@@ -1,833 +1,36 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database Service
|
||||
Stores search feedback, calculates quality metrics (Precision, Recall, MRR).
|
||||
PostgreSQL Metrics Database Service — Barrel Re-export
|
||||
|
||||
Split into:
|
||||
- metrics_db_core.py — Pool, feedback, metrics, relevance
|
||||
- metrics_db_schema.py — Table initialization (DDL)
|
||||
- metrics_db_zeugnis.py — Zeugnis source/document/stats operations
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
|
||||
# Database Configuration - uses test default if not configured (for CI)
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
|
||||
|
||||
# Connection pool
|
||||
_pool = None
|
||||
|
||||
|
||||
async def get_pool():
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
try:
|
||||
import asyncpg
|
||||
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
except ImportError:
|
||||
print("Warning: asyncpg not installed. Metrics storage disabled.")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to connect to PostgreSQL: {e}")
|
||||
return None
|
||||
return _pool
|
||||
|
||||
|
||||
async def init_metrics_tables() -> bool:
|
||||
"""Initialize metrics tables in PostgreSQL."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
create_tables_sql = """
|
||||
-- RAG Search Feedback Table
|
||||
CREATE TABLE IF NOT EXISTS rag_search_feedback (
|
||||
id SERIAL PRIMARY KEY,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT,
|
||||
collection_name VARCHAR(100),
|
||||
score FLOAT,
|
||||
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
|
||||
notes TEXT,
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for efficient querying
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
|
||||
|
||||
-- RAG Search Logs Table (for latency tracking)
|
||||
CREATE TABLE IF NOT EXISTS rag_search_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_text TEXT NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
result_count INTEGER,
|
||||
latency_ms INTEGER,
|
||||
top_score FLOAT,
|
||||
filters JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
|
||||
|
||||
-- RAG Upload History Table
|
||||
CREATE TABLE IF NOT EXISTS rag_upload_history (
|
||||
id SERIAL PRIMARY KEY,
|
||||
filename VARCHAR(500) NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
year INTEGER,
|
||||
pdfs_extracted INTEGER,
|
||||
minio_path VARCHAR(1000),
|
||||
uploaded_by VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
|
||||
|
||||
-- Binäre Relevanz-Judgments für echte Precision/Recall
|
||||
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT NOT NULL,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
result_rank INTEGER,
|
||||
is_relevant BOOLEAN NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
|
||||
|
||||
-- Zeugnisse Source Tracking
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_sources (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
bundesland VARCHAR(10) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_url TEXT,
|
||||
license_type VARCHAR(50) NOT NULL,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
verified_by VARCHAR(100),
|
||||
verified_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
|
||||
|
||||
-- Zeugnisse Seed URLs
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
url TEXT NOT NULL,
|
||||
doc_type VARCHAR(50),
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
last_crawled TIMESTAMP,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
|
||||
|
||||
-- Zeugnisse Documents
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_documents (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
|
||||
title VARCHAR(500),
|
||||
url TEXT NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
|
||||
file_size INTEGER,
|
||||
content_type VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
|
||||
|
||||
-- Zeugnisse Document Versions
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
version INTEGER NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
change_summary TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
|
||||
|
||||
-- Zeugnisse Usage Events (Audit Trail)
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
user_id VARCHAR(100),
|
||||
details JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
|
||||
|
||||
-- Crawler Queue
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
priority INTEGER DEFAULT 5,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
documents_found INTEGER DEFAULT 0,
|
||||
documents_indexed INTEGER DEFAULT 0,
|
||||
error_count INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
|
||||
"""
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(create_tables_sql)
|
||||
print("RAG metrics tables initialized")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize metrics tables: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feedback Storage
|
||||
# =============================================================================
|
||||
|
||||
async def store_feedback(
|
||||
result_id: str,
|
||||
rating: int,
|
||||
query_text: Optional[str] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
score: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store search result feedback."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_feedback
|
||||
(result_id, query_text, collection_name, score, rating, notes, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
result_id, query_text, collection_name, score, rating, notes, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store feedback: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_search(
|
||||
query_text: str,
|
||||
collection_name: str,
|
||||
result_count: int,
|
||||
latency_ms: int,
|
||||
top_score: Optional[float] = None,
|
||||
filters: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a search for metrics tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_logs
|
||||
(query_text, collection_name, result_count, latency_ms, top_score, filters)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
query_text, collection_name, result_count, latency_ms, top_score,
|
||||
json.dumps(filters) if filters else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log search: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_upload(
|
||||
filename: str,
|
||||
collection_name: str,
|
||||
year: int,
|
||||
pdfs_extracted: int,
|
||||
minio_path: Optional[str] = None,
|
||||
uploaded_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Log an upload for history tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_upload_history
|
||||
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log upload: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics Calculation
|
||||
# =============================================================================
|
||||
|
||||
async def calculate_metrics(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate RAG quality metrics from stored feedback.
|
||||
|
||||
Returns:
|
||||
Dict with precision, recall, MRR, latency, etc.
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Date filter
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
# Collection filter
|
||||
collection_filter = ""
|
||||
params = [since]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $2"
|
||||
params.append(collection_name)
|
||||
|
||||
# Total feedback count
|
||||
total_feedback = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Rating distribution
|
||||
rating_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT rating, COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
GROUP BY rating
|
||||
ORDER BY rating DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Average rating (proxy for precision)
|
||||
avg_rating = await conn.fetchval(
|
||||
f"""
|
||||
SELECT AVG(rating) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Score distribution
|
||||
score_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN score >= 0.9 THEN '0.9+'
|
||||
WHEN score >= 0.7 THEN '0.7-0.9'
|
||||
WHEN score >= 0.5 THEN '0.5-0.7'
|
||||
ELSE '<0.5'
|
||||
END as range,
|
||||
COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
|
||||
GROUP BY range
|
||||
ORDER BY range DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Search logs for latency
|
||||
latency_stats = await conn.fetchrow(
|
||||
f"""
|
||||
SELECT
|
||||
AVG(latency_ms) as avg_latency,
|
||||
COUNT(*) as total_searches,
|
||||
AVG(result_count) as avg_results
|
||||
FROM rag_search_logs
|
||||
WHERE created_at >= $1 {collection_filter.replace('collection_name', 'collection_name')}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
# Calculate precision@5 (% of top 5 rated 4+)
|
||||
precision_at_5 = await conn.fetchval(
|
||||
f"""
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) > 0
|
||||
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
|
||||
ELSE 0 END
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Calculate MRR (Mean Reciprocal Rank) - simplified
|
||||
# Using average rating as proxy for relevance
|
||||
mrr = (avg_rating or 0) / 5.0
|
||||
|
||||
# Error rate (ratings of 1 or 2)
|
||||
error_count = sum(
|
||||
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
|
||||
)
|
||||
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
|
||||
|
||||
# Score distribution as percentages
|
||||
total_scored = sum(s['count'] for s in score_dist)
|
||||
score_distribution = {}
|
||||
for s in score_dist:
|
||||
if total_scored > 0:
|
||||
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
|
||||
else:
|
||||
score_distribution[s['range']] = 0
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"precision_at_5": round(precision_at_5, 2),
|
||||
"recall_at_10": round(precision_at_5 * 1.1, 2), # Estimated
|
||||
"mrr": round(mrr, 2),
|
||||
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
|
||||
"total_ratings": total_feedback,
|
||||
"total_searches": latency_stats['total_searches'] or 0,
|
||||
"error_rate": round(error_rate, 1),
|
||||
"score_distribution": score_distribution,
|
||||
"rating_distribution": {
|
||||
str(r['rating']): r['count'] for r in rating_dist if r['rating']
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate metrics: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent feedback entries."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
|
||||
FROM rag_search_feedback
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"result_id": r['result_id'],
|
||||
"rating": r['rating'],
|
||||
"query_text": r['query_text'],
|
||||
"collection_name": r['collection_name'],
|
||||
"score": r['score'],
|
||||
"notes": r['notes'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get recent feedback: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_upload_history(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent upload history."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
|
||||
FROM rag_upload_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"filename": r['filename'],
|
||||
"collection_name": r['collection_name'],
|
||||
"year": r['year'],
|
||||
"pdfs_extracted": r['pdfs_extracted'],
|
||||
"minio_path": r['minio_path'],
|
||||
"uploaded_by": r['uploaded_by'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get upload history: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Relevance Judgments (Binary Precision/Recall)
|
||||
# =============================================================================
|
||||
|
||||
async def store_relevance_judgment(
|
||||
query_id: str,
|
||||
query_text: str,
|
||||
result_id: str,
|
||||
is_relevant: bool,
|
||||
result_rank: Optional[int] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store binary relevance judgment for Precision/Recall calculation."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_relevance_judgments
|
||||
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store relevance judgment: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def calculate_precision_recall(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
k: int = 10,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate true Precision@k and Recall@k from binary relevance judgments.
|
||||
|
||||
Precision@k = (Relevant docs in top k) / k
|
||||
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since, k]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $3"
|
||||
params.append(collection_name)
|
||||
|
||||
# Get precision@k per query, then average
|
||||
precision_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_precision AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(*), 1) as precision
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
AND (result_rank IS NULL OR result_rank <= $2)
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(precision) FROM query_precision
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Get recall@k per query, then average
|
||||
recall_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_recall AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(recall) FROM query_recall
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# Total judgments
|
||||
total_judgments = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
# Unique queries
|
||||
unique_queries = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
since, *([collection_name] if collection_name else [])
|
||||
)
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"k": k,
|
||||
"precision_at_k": round(precision_result, 3),
|
||||
"recall_at_k": round(recall_result, 3),
|
||||
"f1_score": round(
|
||||
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
|
||||
),
|
||||
"total_judgments": total_judgments or 0,
|
||||
"unique_queries": unique_queries or 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate precision/recall: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Zeugnis Database Operations
|
||||
# =============================================================================
|
||||
|
||||
async def get_zeugnis_sources() -> List[Dict]:
|
||||
"""Get all zeugnis sources (Bundesländer)."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, bundesland, name, base_url, license_type, training_allowed,
|
||||
verified_by, verified_at, created_at, updated_at
|
||||
FROM zeugnis_sources
|
||||
ORDER BY bundesland
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def upsert_zeugnis_source(
|
||||
id: str,
|
||||
bundesland: str,
|
||||
name: str,
|
||||
license_type: str,
|
||||
training_allowed: bool,
|
||||
base_url: Optional[str] = None,
|
||||
verified_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert or update a zeugnis source."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
base_url = EXCLUDED.base_url,
|
||||
license_type = EXCLUDED.license_type,
|
||||
training_allowed = EXCLUDED.training_allowed,
|
||||
verified_by = EXCLUDED.verified_by,
|
||||
verified_at = NOW(),
|
||||
updated_at = NOW()
|
||||
""",
|
||||
id, bundesland, name, base_url, license_type, training_allowed, verified_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to upsert zeugnis source: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_zeugnis_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[Dict]:
|
||||
"""Get zeugnis documents with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
if bundesland:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE s.bundesland = $1
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
""",
|
||||
bundesland, limit, offset
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
""",
|
||||
limit, offset
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis documents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_zeugnis_stats() -> Dict:
|
||||
"""Get zeugnis crawler statistics."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available"}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
# Total sources
|
||||
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
|
||||
|
||||
# Total documents
|
||||
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
|
||||
|
||||
# Indexed documents
|
||||
indexed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
|
||||
)
|
||||
|
||||
# Training allowed
|
||||
training_allowed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
|
||||
)
|
||||
|
||||
# Per Bundesland stats
|
||||
per_bundesland = await conn.fetch(
|
||||
"""
|
||||
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
GROUP BY s.bundesland, s.name, s.training_allowed
|
||||
ORDER BY s.bundesland
|
||||
"""
|
||||
)
|
||||
|
||||
# Active crawls
|
||||
active_crawls = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_sources": sources or 0,
|
||||
"total_documents": documents or 0,
|
||||
"indexed_documents": indexed or 0,
|
||||
"training_allowed_documents": training_allowed or 0,
|
||||
"active_crawls": active_crawls or 0,
|
||||
"per_bundesland": [dict(r) for r in per_bundesland],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def log_zeugnis_event(
|
||||
document_id: str,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
details: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a zeugnis usage event for audit trail."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
import uuid
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""",
|
||||
str(uuid.uuid4()), document_id, event_type, user_id,
|
||||
json.dumps(details) if details else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log zeugnis event: {e}")
|
||||
return False
|
||||
# Schema: table initialization
from metrics_db_schema import init_metrics_tables  # noqa: F401

# Core: pool, feedback, search logs, metrics, relevance
from metrics_db_core import (  # noqa: F401
    DATABASE_URL,
    get_pool,
    store_feedback,
    log_search,
    log_upload,
    calculate_metrics,
    get_recent_feedback,
    get_upload_history,
    store_relevance_judgment,
    calculate_precision_recall,
)

# Zeugnis operations
from metrics_db_zeugnis import (  # noqa: F401
    get_zeugnis_sources,
    upsert_zeugnis_source,
    get_zeugnis_documents,
    get_zeugnis_stats,
    log_zeugnis_event,
)
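Because all public names are re-exported, existing call sites should keep working unchanged. A minimal sketch, assuming a reachable PostgreSQL instance; the collection and query values are illustrative:

import asyncio

from metrics_db import log_search, calculate_metrics  # same imports as before the split

async def record_and_report() -> dict:
    await log_search("zeugnis formulierungen", "zeugnisse",
                     result_count=8, latency_ms=120, top_score=0.83)
    return await calculate_metrics(collection_name="zeugnisse", days=7)

print(asyncio.run(record_and_report()))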
459
klausur-service/backend/metrics_db_core.py
Normal file
@@ -0,0 +1,459 @@
"""
PostgreSQL Metrics Database - Core Operations

Connection pool, feedback storage, search logging, upload history,
metrics calculation, and relevance judgments.

Extracted from metrics_db.py to keep files under 500 LOC.
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Database Configuration - uses test default if not configured (for CI)
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
|
||||
|
||||
# Connection pool
|
||||
_pool = None
|
||||
|
||||
|
||||
async def get_pool():
|
||||
"""Get or create database connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
try:
|
||||
import asyncpg
|
||||
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
|
||||
except ImportError:
|
||||
print("Warning: asyncpg not installed. Metrics storage disabled.")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to connect to PostgreSQL: {e}")
|
||||
return None
|
||||
return _pool
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feedback Storage
|
||||
# =============================================================================
|
||||
|
||||
async def store_feedback(
|
||||
result_id: str,
|
||||
rating: int,
|
||||
query_text: Optional[str] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
score: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store search result feedback."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_feedback
|
||||
(result_id, query_text, collection_name, score, rating, notes, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
result_id, query_text, collection_name, score, rating, notes, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store feedback: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_search(
|
||||
query_text: str,
|
||||
collection_name: str,
|
||||
result_count: int,
|
||||
latency_ms: int,
|
||||
top_score: Optional[float] = None,
|
||||
filters: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a search for metrics tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_search_logs
|
||||
(query_text, collection_name, result_count, latency_ms, top_score, filters)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
query_text, collection_name, result_count, latency_ms, top_score,
|
||||
json.dumps(filters) if filters else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log search: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def log_upload(
|
||||
filename: str,
|
||||
collection_name: str,
|
||||
year: int,
|
||||
pdfs_extracted: int,
|
||||
minio_path: Optional[str] = None,
|
||||
uploaded_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Log an upload for history tracking."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_upload_history
|
||||
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log upload: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics Calculation
|
||||
# =============================================================================
|
||||
|
||||
async def calculate_metrics(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate RAG quality metrics from stored feedback.
|
||||
|
||||
Returns:
|
||||
Dict with precision, recall, MRR, latency, etc.
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $2"
|
||||
params.append(collection_name)
|
||||
|
||||
total_feedback = await conn.fetchval(
|
||||
f"""
|
||||
SELECT COUNT(*) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
rating_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT rating, COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
GROUP BY rating
|
||||
ORDER BY rating DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
avg_rating = await conn.fetchval(
|
||||
f"""
|
||||
SELECT AVG(rating) FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
score_dist = await conn.fetch(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN score >= 0.9 THEN '0.9+'
|
||||
WHEN score >= 0.7 THEN '0.7-0.9'
|
||||
WHEN score >= 0.5 THEN '0.5-0.7'
|
||||
ELSE '<0.5'
|
||||
END as range,
|
||||
COUNT(*) as count
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
|
||||
GROUP BY range
|
||||
ORDER BY range DESC
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
latency_stats = await conn.fetchrow(
|
||||
f"""
|
||||
SELECT
|
||||
AVG(latency_ms) as avg_latency,
|
||||
COUNT(*) as total_searches,
|
||||
AVG(result_count) as avg_results
|
||||
FROM rag_search_logs
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
)
|
||||
|
||||
precision_at_5 = await conn.fetchval(
|
||||
f"""
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) > 0
|
||||
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
|
||||
ELSE 0 END
|
||||
FROM rag_search_feedback
|
||||
WHERE created_at >= $1 {collection_filter}
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
mrr = (avg_rating or 0) / 5.0  # proxy for MRR: normalized average rating (a true MRR needs result ranks)
|
||||
|
||||
error_count = sum(
|
||||
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
|
||||
)
|
||||
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
|
||||
|
||||
total_scored = sum(s['count'] for s in score_dist)
|
||||
score_distribution = {}
|
||||
for s in score_dist:
|
||||
if total_scored > 0:
|
||||
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
|
||||
else:
|
||||
score_distribution[s['range']] = 0
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"precision_at_5": round(precision_at_5, 2),
|
||||
"recall_at_10": round(precision_at_5 * 1.1, 2),
|
||||
"mrr": round(mrr, 2),
|
||||
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
|
||||
"total_ratings": total_feedback,
|
||||
"total_searches": latency_stats['total_searches'] or 0,
|
||||
"error_rate": round(error_rate, 1),
|
||||
"score_distribution": score_distribution,
|
||||
"rating_distribution": {
|
||||
str(r['rating']): r['count'] for r in rating_dist if r['rating']
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate metrics: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
|
||||
|
||||
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent feedback entries."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
|
||||
FROM rag_search_feedback
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"result_id": r['result_id'],
|
||||
"rating": r['rating'],
|
||||
"query_text": r['query_text'],
|
||||
"collection_name": r['collection_name'],
|
||||
"score": r['score'],
|
||||
"notes": r['notes'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get recent feedback: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_upload_history(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent upload history."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
|
||||
FROM rag_upload_history
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1
|
||||
""",
|
||||
limit
|
||||
)
|
||||
return [
|
||||
{
|
||||
"filename": r['filename'],
|
||||
"collection_name": r['collection_name'],
|
||||
"year": r['year'],
|
||||
"pdfs_extracted": r['pdfs_extracted'],
|
||||
"minio_path": r['minio_path'],
|
||||
"uploaded_by": r['uploaded_by'],
|
||||
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Failed to get upload history: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Relevance Judgments (Binary Precision/Recall)
|
||||
# =============================================================================
|
||||
|
||||
async def store_relevance_judgment(
|
||||
query_id: str,
|
||||
query_text: str,
|
||||
result_id: str,
|
||||
is_relevant: bool,
|
||||
result_rank: Optional[int] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Store binary relevance judgment for Precision/Recall calculation."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO rag_relevance_judgments
|
||||
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to store relevance judgment: {e}")
|
||||
return False
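# Illustrative sketch: judge the top-k results of one query in rank order so that
# calculate_precision_recall() below has per-rank binary labels to work with.
# `results` is assumed to be [(result_id, is_relevant), ...].
async def _example_judge_results(query_id: str, query_text: str, results: list) -> None:
    for rank, (result_id, is_relevant) in enumerate(results, start=1):
        await store_relevance_judgment(
            query_id=query_id,
            query_text=query_text,
            result_id=result_id,
            is_relevant=is_relevant,
            result_rank=rank,
            collection_name="legal_corpus",
        )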
|
||||
|
||||
|
||||
async def calculate_precision_recall(
|
||||
collection_name: Optional[str] = None,
|
||||
days: int = 7,
|
||||
k: int = 10,
|
||||
) -> Dict:
|
||||
"""
|
||||
Calculate true Precision@k and Recall@k from binary relevance judgments.
|
||||
|
||||
Precision@k = (Relevant docs in top k) / k
|
||||
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
|
||||
"""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available", "connected": False}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
collection_filter = ""
|
||||
params = [since, k]
|
||||
if collection_name:
|
||||
collection_filter = "AND collection_name = $3"
|
||||
params.append(collection_name)
|
||||
|
||||
precision_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_precision AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(*), 1) as precision
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
AND (result_rank IS NULL OR result_rank <= $2)
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(precision) FROM query_precision
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
recall_result = await conn.fetchval(
|
||||
f"""
|
||||
WITH query_recall AS (
|
||||
SELECT
|
||||
query_id,
|
||||
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
|
||||
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
|
||||
FROM rag_relevance_judgments
|
||||
WHERE created_at >= $1
|
||||
{collection_filter}
|
||||
GROUP BY query_id
|
||||
)
|
||||
SELECT AVG(recall) FROM query_recall
|
||||
""",
|
||||
*params
|
||||
) or 0
|
||||
|
||||
# The main filter binds collection_name to $3 (params are since, k, collection_name).
# The count queries below only take (since, collection_name), so rebind the placeholder
# to $2; passing (since, collection_name) against a $3 placeholder would fail.
count_filter = collection_filter.replace("$3", "$2")
count_params = [since] + ([collection_name] if collection_name else [])

total_judgments = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_relevance_judgments
WHERE created_at >= $1 {count_filter}
""",
*count_params
)

unique_queries = await conn.fetchval(
f"""
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
WHERE created_at >= $1 {count_filter}
""",
*count_params
)
|
||||
|
||||
return {
|
||||
"connected": True,
|
||||
"period_days": days,
|
||||
"k": k,
|
||||
"precision_at_k": round(precision_result, 3),
|
||||
"recall_at_k": round(recall_result, 3),
|
||||
"f1_score": round(
|
||||
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
|
||||
),
|
||||
"total_judgments": total_judgments or 0,
|
||||
"unique_queries": unique_queries or 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to calculate precision/recall: {e}")
|
||||
return {"error": str(e), "connected": False}
|
||||
182
klausur-service/backend/metrics_db_schema.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database - Schema Initialization
|
||||
|
||||
Table creation DDL for all metrics, feedback, and zeugnis tables.
|
||||
|
||||
Extracted from metrics_db_core.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
from metrics_db_core import get_pool
|
||||
|
||||
|
||||
async def init_metrics_tables() -> bool:
|
||||
"""Initialize metrics tables in PostgreSQL."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
create_tables_sql = """
|
||||
-- RAG Search Feedback Table
|
||||
CREATE TABLE IF NOT EXISTS rag_search_feedback (
|
||||
id SERIAL PRIMARY KEY,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT,
|
||||
collection_name VARCHAR(100),
|
||||
score FLOAT,
|
||||
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
|
||||
notes TEXT,
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Index for efficient querying
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
|
||||
|
||||
-- RAG Search Logs Table (for latency tracking)
|
||||
CREATE TABLE IF NOT EXISTS rag_search_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_text TEXT NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
result_count INTEGER,
|
||||
latency_ms INTEGER,
|
||||
top_score FLOAT,
|
||||
filters JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
|
||||
|
||||
-- RAG Upload History Table
|
||||
CREATE TABLE IF NOT EXISTS rag_upload_history (
|
||||
id SERIAL PRIMARY KEY,
|
||||
filename VARCHAR(500) NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
year INTEGER,
|
||||
pdfs_extracted INTEGER,
|
||||
minio_path VARCHAR(1000),
|
||||
uploaded_by VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
|
||||
|
||||
-- Binary relevance judgments for true Precision/Recall
|
||||
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
query_id VARCHAR(255) NOT NULL,
|
||||
query_text TEXT NOT NULL,
|
||||
result_id VARCHAR(255) NOT NULL,
|
||||
result_rank INTEGER,
|
||||
is_relevant BOOLEAN NOT NULL,
|
||||
collection_name VARCHAR(100),
|
||||
user_id VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
|
||||
|
||||
-- Zeugnisse Source Tracking
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_sources (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
bundesland VARCHAR(10) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
base_url TEXT,
|
||||
license_type VARCHAR(50) NOT NULL,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
verified_by VARCHAR(100),
|
||||
verified_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
|
||||
|
||||
-- Zeugnisse Seed URLs
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
url TEXT NOT NULL,
|
||||
doc_type VARCHAR(50),
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
last_crawled TIMESTAMP,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
|
||||
|
||||
-- Zeugnisse Documents
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_documents (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
|
||||
title VARCHAR(500),
|
||||
url TEXT NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
training_allowed BOOLEAN DEFAULT FALSE,
|
||||
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
|
||||
file_size INTEGER,
|
||||
content_type VARCHAR(100),
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
|
||||
|
||||
-- Zeugnisse Document Versions
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
version INTEGER NOT NULL,
|
||||
content_hash VARCHAR(64),
|
||||
minio_path TEXT,
|
||||
change_summary TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
|
||||
|
||||
-- Zeugnisse Usage Events (Audit Trail)
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
user_id VARCHAR(100),
|
||||
details JSONB,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
|
||||
|
||||
-- Crawler Queue
|
||||
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
|
||||
priority INTEGER DEFAULT 5,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
documents_found INTEGER DEFAULT 0,
|
||||
documents_indexed INTEGER DEFAULT 0,
|
||||
error_count INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
|
||||
"""
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(create_tables_sql)
|
||||
print("RAG metrics tables initialized")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize metrics tables: {e}")
|
||||
return False
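# Usage sketch: the DDL above is idempotent (CREATE ... IF NOT EXISTS), so the service
# can call init_metrics_tables() on every startup; a manual one-off run also works:
if __name__ == "__main__":
    import asyncio
    asyncio.run(init_metrics_tables())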
|
||||
193
klausur-service/backend/metrics_db_zeugnis.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
PostgreSQL Metrics Database - Zeugnis Operations
|
||||
|
||||
Zeugnis source management, document queries, statistics, and event logging.
|
||||
|
||||
Extracted from metrics_db.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from metrics_db_core import get_pool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Zeugnis Database Operations
|
||||
# =============================================================================
|
||||
|
||||
async def get_zeugnis_sources() -> List[Dict]:
|
||||
"""Get all zeugnis sources (Bundeslaender)."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT id, bundesland, name, base_url, license_type, training_allowed,
|
||||
verified_by, verified_at, created_at, updated_at
|
||||
FROM zeugnis_sources
|
||||
ORDER BY bundesland
|
||||
"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis sources: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def upsert_zeugnis_source(
|
||||
id: str,
|
||||
bundesland: str,
|
||||
name: str,
|
||||
license_type: str,
|
||||
training_allowed: bool,
|
||||
base_url: Optional[str] = None,
|
||||
verified_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert or update a zeugnis source."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
base_url = EXCLUDED.base_url,
|
||||
license_type = EXCLUDED.license_type,
|
||||
training_allowed = EXCLUDED.training_allowed,
|
||||
verified_by = EXCLUDED.verified_by,
|
||||
verified_at = NOW(),
|
||||
updated_at = NOW()
|
||||
""",
|
||||
id, bundesland, name, base_url, license_type, training_allowed, verified_by
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to upsert zeugnis source: {e}")
|
||||
return False
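# Illustrative call (all values hypothetical): register or refresh one Bundesland source.
async def _example_register_source() -> bool:
    return await upsert_zeugnis_source(
        id="src-ni",
        bundesland="NI",
        name="Niedersachsen (Zeugnisvorlagen)",
        license_type="unverified",
        training_allowed=False,
        base_url="https://example.org/zeugnisse",
        verified_by="admin",
    )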
|
||||
|
||||
|
||||
async def get_zeugnis_documents(
|
||||
bundesland: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> List[Dict]:
|
||||
"""Get zeugnis documents with optional filtering."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
if bundesland:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
WHERE s.bundesland = $1
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
""",
|
||||
bundesland, limit, offset
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT d.*, s.bundesland, s.name as source_name
|
||||
FROM zeugnis_documents d
|
||||
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
|
||||
JOIN zeugnis_sources s ON u.source_id = s.id
|
||||
ORDER BY d.created_at DESC
|
||||
LIMIT $1 OFFSET $2
|
||||
""",
|
||||
limit, offset
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis documents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_zeugnis_stats() -> Dict:
|
||||
"""Get zeugnis crawler statistics."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return {"error": "Database not available"}
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
|
||||
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
|
||||
|
||||
indexed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
|
||||
)
|
||||
|
||||
training_allowed = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
|
||||
)
|
||||
|
||||
per_bundesland = await conn.fetch(
|
||||
"""
|
||||
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
|
||||
FROM zeugnis_sources s
|
||||
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
|
||||
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
|
||||
GROUP BY s.bundesland, s.name, s.training_allowed
|
||||
ORDER BY s.bundesland
|
||||
"""
|
||||
)
|
||||
|
||||
active_crawls = await conn.fetchval(
|
||||
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_sources": sources or 0,
|
||||
"total_documents": documents or 0,
|
||||
"indexed_documents": indexed or 0,
|
||||
"training_allowed_documents": training_allowed or 0,
|
||||
"active_crawls": active_crawls or 0,
|
||||
"per_bundesland": [dict(r) for r in per_bundesland],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Failed to get zeugnis stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def log_zeugnis_event(
|
||||
document_id: str,
|
||||
event_type: str,
|
||||
user_id: Optional[str] = None,
|
||||
details: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""Log a zeugnis usage event for audit trail."""
|
||||
pool = await get_pool()
|
||||
if pool is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
import json
|
||||
import uuid
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""",
|
||||
str(uuid.uuid4()), document_id, event_type, user_id,
|
||||
json.dumps(details) if details else None
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to log zeugnis event: {e}")
|
||||
return False
|
||||
@@ -1,845 +1,81 @@
|
||||
"""
|
||||
OCR Labeling API for Handwriting Training Data Collection
|
||||
OCR Labeling API — Barrel Re-export
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- All processing happens locally (Mac Mini with Ollama)
- No data is sent to external servers
- Images are deduplicated via SHA256 hash
- Export is for local fine-tuning only (TrOCR, llama3.2-vision)
|
||||
Split into:
|
||||
- ocr_labeling_models.py — Pydantic models and constants
|
||||
- ocr_labeling_helpers.py — OCR wrappers, image storage, hashing
|
||||
- ocr_labeling_routes.py — Session/queue/labeling route handlers
|
||||
- ocr_labeling_upload_routes.py — Upload, run-OCR, export route handlers
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions - Create labeling session
|
||||
- POST /sessions/{id}/upload - Upload images for labeling
|
||||
- GET /queue - Get next items to label
|
||||
- POST /confirm - Confirm OCR as correct
|
||||
- POST /correct - Save corrected ground truth
|
||||
- POST /skip - Skip unusable item
|
||||
- GET /stats - Get labeling statistics
|
||||
- POST /export - Export training data
|
||||
All public names are re-exported here for backward compatibility.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import hashlib
|
||||
import os
|
||||
import base64
|
||||
|
||||
# Import database functions
|
||||
from metrics_db import (
|
||||
create_ocr_labeling_session,
|
||||
get_ocr_labeling_sessions,
|
||||
get_ocr_labeling_session,
|
||||
add_ocr_labeling_item,
|
||||
get_ocr_labeling_queue,
|
||||
get_ocr_labeling_item,
|
||||
confirm_ocr_label,
|
||||
correct_ocr_label,
|
||||
skip_ocr_item,
|
||||
get_ocr_labeling_stats,
|
||||
export_training_samples,
|
||||
get_training_samples,
|
||||
# Models
|
||||
from ocr_labeling_models import ( # noqa: F401
|
||||
LOCAL_STORAGE_PATH,
|
||||
SessionCreate,
|
||||
SessionResponse,
|
||||
ItemResponse,
|
||||
ConfirmRequest,
|
||||
CorrectRequest,
|
||||
SkipRequest,
|
||||
ExportRequest,
|
||||
StatsResponse,
|
||||
)
|
||||
|
||||
# Try to import Vision OCR service
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
|
||||
VISION_OCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
VISION_OCR_AVAILABLE = False
|
||||
print("Warning: Vision OCR service not available")
|
||||
# Helpers
|
||||
from ocr_labeling_helpers import ( # noqa: F401
|
||||
VISION_OCR_AVAILABLE,
|
||||
PADDLEOCR_AVAILABLE,
|
||||
TROCR_AVAILABLE,
|
||||
DONUT_AVAILABLE,
|
||||
MINIO_AVAILABLE,
|
||||
TRAINING_EXPORT_AVAILABLE,
|
||||
compute_image_hash,
|
||||
run_ocr_on_image,
|
||||
run_vision_ocr_wrapper,
|
||||
run_paddleocr_wrapper,
|
||||
run_trocr_wrapper,
|
||||
run_donut_wrapper,
|
||||
save_image_locally,
|
||||
get_image_url,
|
||||
)
|
||||
|
||||
# Try to import PaddleOCR from hybrid_vocab_extractor
|
||||
# Conditional re-exports from helpers' optional imports
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr
|
||||
PADDLEOCR_AVAILABLE = True
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET # noqa: F401
|
||||
except ImportError:
|
||||
PADDLEOCR_AVAILABLE = False
|
||||
print("Warning: PaddleOCR not available")
|
||||
pass
|
||||
|
||||
# Try to import TrOCR service
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr
|
||||
TROCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
TROCR_AVAILABLE = False
|
||||
print("Warning: TrOCR service not available")
|
||||
|
||||
# Try to import Donut service
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr
|
||||
DONUT_AVAILABLE = True
|
||||
except ImportError:
|
||||
DONUT_AVAILABLE = False
|
||||
print("Warning: Donut OCR service not available")
|
||||
|
||||
# Try to import MinIO storage
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
|
||||
MINIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
MINIO_AVAILABLE = False
|
||||
print("Warning: MinIO storage not available, using local storage")
|
||||
|
||||
# Try to import Training Export Service
|
||||
try:
|
||||
from training_export_service import (
|
||||
from training_export_service import ( # noqa: F401
|
||||
TrainingExportService,
|
||||
TrainingSample,
|
||||
get_training_export_service,
|
||||
)
|
||||
TRAINING_EXPORT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRAINING_EXPORT_AVAILABLE = False
|
||||
print("Warning: Training export service not available")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
# Local storage path (fallback if MinIO not available)
|
||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class SessionCreate(BaseModel):
|
||||
name: str
|
||||
source_type: str = "klausur" # klausur, handwriting_sample, scan
|
||||
description: Optional[str] = None
|
||||
ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
source_type: str
|
||||
description: Optional[str]
|
||||
ocr_model: Optional[str]
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
skipped_items: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
session_name: str
|
||||
image_path: str
|
||||
image_url: Optional[str]
|
||||
ocr_text: Optional[str]
|
||||
ocr_confidence: Optional[float]
|
||||
ground_truth: Optional[str]
|
||||
status: str
|
||||
metadata: Optional[Dict]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ConfirmRequest(BaseModel):
|
||||
item_id: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class CorrectRequest(BaseModel):
|
||||
item_id: str
|
||||
ground_truth: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
item_id: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
export_format: str = "generic" # generic, trocr, llama_vision
|
||||
session_id: Optional[str] = None
|
||||
batch_id: Optional[str] = None
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_sessions: Optional[int] = None
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
pending_items: int
|
||||
exportable_items: Optional[int] = None
|
||||
accuracy_rate: float
|
||||
avg_label_time_seconds: Optional[float] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data."""
|
||||
return hashlib.sha256(image_data).hexdigest()
|
||||
|
||||
|
||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
|
||||
"""
|
||||
Run OCR on an image using the specified model.
|
||||
|
||||
Models:
|
||||
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
|
||||
- trocr: Microsoft TrOCR (fast for printed text)
|
||||
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
|
||||
- donut: Document Understanding Transformer (structured documents)
|
||||
|
||||
Returns:
|
||||
Tuple of (ocr_text, confidence)
|
||||
"""
|
||||
print(f"Running OCR with model: {model}")
|
||||
|
||||
# Route to appropriate OCR service based on model
|
||||
if model == "paddleocr":
|
||||
return await run_paddleocr_wrapper(image_data, filename)
|
||||
elif model == "donut":
|
||||
return await run_donut_wrapper(image_data, filename)
|
||||
elif model == "trocr":
|
||||
return await run_trocr_wrapper(image_data, filename)
|
||||
else:
|
||||
# Default: Vision LLM (llama3.2-vision or similar)
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Vision LLM OCR wrapper."""
|
||||
if not VISION_OCR_AVAILABLE:
|
||||
print("Vision OCR service not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
service = get_vision_ocr_service()
|
||||
if not await service.is_available():
|
||||
print("Vision OCR service not available (is_available check failed)")
|
||||
return None, 0.0
|
||||
|
||||
result = await service.extract_text(
|
||||
image_data,
|
||||
filename=filename,
|
||||
is_handwriting=True
|
||||
)
|
||||
return result.text, result.confidence
|
||||
except Exception as e:
|
||||
print(f"Vision OCR failed: {e}")
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
|
||||
if not PADDLEOCR_AVAILABLE:
|
||||
print("PaddleOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
# run_paddle_ocr returns (regions, raw_text)
|
||||
regions, raw_text = run_paddle_ocr(image_data)
|
||||
|
||||
if not raw_text:
|
||||
print("PaddleOCR returned empty text")
|
||||
return None, 0.0
|
||||
|
||||
# Calculate average confidence from regions
|
||||
if regions:
|
||||
avg_confidence = sum(r.confidence for r in regions) / len(regions)
|
||||
else:
|
||||
avg_confidence = 0.5
|
||||
|
||||
return raw_text, avg_confidence
|
||||
except Exception as e:
|
||||
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""TrOCR wrapper."""
|
||||
if not TROCR_AVAILABLE:
|
||||
print("TrOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_trocr_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"TrOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Donut OCR wrapper."""
|
||||
if not DONUT_AVAILABLE:
|
||||
print("Donut not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_donut_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
|
||||
"""Save image to local storage."""
|
||||
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
|
||||
os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
filename = f"{item_id}.{extension}"
|
||||
filepath = os.path.join(session_dir, filename)
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def get_image_url(image_path: str) -> str:
|
||||
"""Get URL for an image."""
|
||||
# For local images, return a relative path that the frontend can use
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
|
||||
return f"/api/v1/ocr-label/images/{relative_path}"
|
||||
# For MinIO images, the path is already a URL or key
|
||||
return image_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(session: SessionCreate):
|
||||
"""
|
||||
Create a new OCR labeling session.
|
||||
|
||||
A session groups related images for labeling (e.g., all scans from one class).
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
success = await create_ocr_labeling_session(
|
||||
session_id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create session")
|
||||
|
||||
return SessionResponse(
|
||||
id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
total_items=0,
|
||||
labeled_items=0,
|
||||
confirmed_items=0,
|
||||
corrected_items=0,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
||||
"""List all OCR labeling sessions."""
|
||||
sessions = await get_ocr_labeling_sessions(limit=limit)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s['id'],
|
||||
name=s['name'],
|
||||
source_type=s['source_type'],
|
||||
description=s.get('description'),
|
||||
ocr_model=s.get('ocr_model'),
|
||||
total_items=s.get('total_items', 0),
|
||||
labeled_items=s.get('labeled_items', 0),
|
||||
confirmed_items=s.get('confirmed_items', 0),
|
||||
corrected_items=s.get('corrected_items', 0),
|
||||
skipped_items=s.get('skipped_items', 0),
|
||||
created_at=s.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
||||
async def get_session(session_id: str):
|
||||
"""Get a specific OCR labeling session."""
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionResponse(
|
||||
id=session['id'],
|
||||
name=session['name'],
|
||||
source_type=session['source_type'],
|
||||
description=session.get('description'),
|
||||
ocr_model=session.get('ocr_model'),
|
||||
total_items=session.get('total_items', 0),
|
||||
labeled_items=session.get('labeled_items', 0),
|
||||
confirmed_items=session.get('confirmed_items', 0),
|
||||
corrected_items=session.get('corrected_items', 0),
|
||||
skipped_items=session.get('skipped_items', 0),
|
||||
created_at=session.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/upload")
|
||||
async def upload_images(
|
||||
session_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
files: List[UploadFile] = File(...),
|
||||
run_ocr: bool = Form(True),
|
||||
metadata: Optional[str] = Form(None), # JSON string
|
||||
):
|
||||
"""
|
||||
Upload images to a labeling session.
|
||||
|
||||
Args:
|
||||
session_id: Session to add images to
|
||||
files: Image files to upload (PNG, JPG, PDF)
|
||||
run_ocr: Whether to run OCR immediately (default: True)
|
||||
metadata: Optional JSON metadata (subject, year, etc.)
|
||||
"""
|
||||
import json
|
||||
|
||||
# Verify session exists
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Parse metadata
|
||||
meta_dict = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
meta_dict = {"raw": metadata}
|
||||
|
||||
results = []
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
||||
|
||||
for file in files:
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Compute hash for deduplication
|
||||
image_hash = compute_image_hash(content)
|
||||
|
||||
# Generate item ID
|
||||
item_id = str(uuid.uuid4())
|
||||
|
||||
# Determine file extension
|
||||
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
||||
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
||||
extension = 'png'
|
||||
|
||||
# Save image
|
||||
if MINIO_AVAILABLE:
|
||||
# Upload to MinIO
|
||||
try:
|
||||
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
||||
except Exception as e:
|
||||
print(f"MinIO upload failed, using local storage: {e}")
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
else:
|
||||
# Save locally
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
|
||||
# Run OCR if requested
|
||||
ocr_text = None
|
||||
ocr_confidence = None
|
||||
|
||||
if run_ocr and extension != 'pdf': # Skip OCR for PDFs for now
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
content,
|
||||
file.filename or f"{item_id}.{extension}",
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
# Add to database
|
||||
success = await add_ocr_labeling_item(
|
||||
item_id=item_id,
|
||||
session_id=session_id,
|
||||
image_path=image_path,
|
||||
image_hash=image_hash,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=ocr_confidence,
|
||||
ocr_model=ocr_model if ocr_text else None,
|
||||
metadata=meta_dict,
|
||||
)
|
||||
|
||||
if success:
|
||||
results.append({
|
||||
"id": item_id,
|
||||
"filename": file.filename,
|
||||
"image_path": image_path,
|
||||
"image_hash": image_hash,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"status": "pending",
|
||||
})
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"uploaded_count": len(results),
|
||||
"items": results,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/queue", response_model=List[ItemResponse])
|
||||
async def get_labeling_queue(
|
||||
session_id: Optional[str] = Query(None),
|
||||
status: str = Query("pending"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
):
|
||||
"""
|
||||
Get items from the labeling queue.
|
||||
|
||||
Args:
|
||||
session_id: Optional filter by session
|
||||
status: Filter by status (pending, confirmed, corrected, skipped)
|
||||
limit: Number of items to return
|
||||
"""
|
||||
items = await get_ocr_labeling_queue(
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
async def get_item(item_id: str):
|
||||
"""Get a specific labeling item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/confirm")
|
||||
async def confirm_item(request: ConfirmRequest):
|
||||
"""
|
||||
Confirm that OCR text is correct.
|
||||
|
||||
Sets ground_truth = ocr_text and marks item as confirmed.
|
||||
"""
|
||||
success = await confirm_ocr_label(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||
|
||||
return {"status": "confirmed", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/correct")
|
||||
async def correct_item(request: CorrectRequest):
|
||||
"""
|
||||
Save corrected ground truth for an item.
|
||||
|
||||
Use this when OCR text is wrong and needs manual correction.
|
||||
"""
|
||||
success = await correct_ocr_label(
|
||||
item_id=request.item_id,
|
||||
ground_truth=request.ground_truth,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||
|
||||
return {"status": "corrected", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/skip")
|
||||
async def skip_item(request: SkipRequest):
|
||||
"""
|
||||
Skip an item (unusable image, etc.).
|
||||
|
||||
Skipped items are not included in training exports.
|
||||
"""
|
||||
success = await skip_ocr_item(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||
|
||||
return {"status": "skipped", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(session_id: Optional[str] = Query(None)):
|
||||
"""
|
||||
Get labeling statistics.
|
||||
|
||||
Args:
|
||||
session_id: Optional session ID for session-specific stats
|
||||
"""
|
||||
stats = await get_ocr_labeling_stats(session_id=session_id)
|
||||
|
||||
if "error" in stats:
|
||||
raise HTTPException(status_code=500, detail=stats["error"])
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_data(request: ExportRequest):
|
||||
"""
|
||||
Export labeled data for training.
|
||||
|
||||
Formats:
|
||||
- generic: JSONL with image_path and ground_truth
|
||||
- trocr: Format for TrOCR/Microsoft Transformer fine-tuning
|
||||
- llama_vision: Format for llama3.2-vision fine-tuning
|
||||
|
||||
Exports are saved to disk at /app/ocr-exports/{format}/{batch_id}/
|
||||
"""
|
||||
# First, get samples from database
|
||||
db_samples = await export_training_samples(
|
||||
export_format=request.export_format,
|
||||
session_id=request.session_id,
|
||||
batch_id=request.batch_id,
|
||||
exported_by="admin", # TODO: Get from auth
|
||||
)
|
||||
|
||||
if not db_samples:
|
||||
return {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id,
|
||||
"exported_count": 0,
|
||||
"samples": [],
|
||||
"message": "No labeled samples found to export",
|
||||
}
|
||||
|
||||
# If training export service is available, also write to disk
|
||||
export_result = None
|
||||
if TRAINING_EXPORT_AVAILABLE:
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
|
||||
# Convert DB samples to TrainingSample objects
|
||||
training_samples = []
|
||||
for s in db_samples:
|
||||
training_samples.append(TrainingSample(
|
||||
id=s.get('id', s.get('item_id', '')),
|
||||
image_path=s.get('image_path', ''),
|
||||
ground_truth=s.get('ground_truth', ''),
|
||||
ocr_text=s.get('ocr_text'),
|
||||
ocr_confidence=s.get('ocr_confidence'),
|
||||
metadata=s.get('metadata'),
|
||||
))
|
||||
|
||||
# Export to files
|
||||
export_result = export_service.export(
|
||||
samples=training_samples,
|
||||
export_format=request.export_format,
|
||||
batch_id=request.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Training export failed: {e}")
|
||||
# Continue without file export
|
||||
|
||||
response = {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
||||
"exported_count": len(db_samples),
|
||||
"samples": db_samples,
|
||||
}
|
||||
|
||||
if export_result:
|
||||
response["export_path"] = export_result.export_path
|
||||
response["manifest_path"] = export_result.manifest_path
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/training-samples")
|
||||
async def list_training_samples(
|
||||
export_format: Optional[str] = Query(None),
|
||||
batch_id: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
):
|
||||
"""Get exported training samples."""
|
||||
samples = await get_training_samples(
|
||||
export_format=export_format,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/images/{path:path}")
|
||||
async def get_image(path: str):
|
||||
"""
|
||||
Serve an image from local storage.
|
||||
|
||||
This endpoint is used when images are stored locally (not in MinIO).
|
||||
"""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Determine content type
|
||||
extension = filepath.split('.')[-1].lower()
|
||||
content_type = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'pdf': 'application/pdf',
|
||||
}.get(extension, 'application/octet-stream')
|
||||
|
||||
return FileResponse(filepath, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/run-ocr/{item_id}")
|
||||
async def run_ocr_for_item(item_id: str):
|
||||
"""
|
||||
Run OCR on an existing item.
|
||||
|
||||
Use this to re-run OCR or run it if it was skipped during upload.
|
||||
"""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
# Load image
|
||||
image_path = item['image_path']
|
||||
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
# Load from local storage
|
||||
if not os.path.exists(image_path):
|
||||
raise HTTPException(status_code=404, detail="Image file not found")
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
elif MINIO_AVAILABLE:
|
||||
# Load from MinIO
|
||||
try:
|
||||
image_data = get_ocr_image(image_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Cannot load image")
|
||||
|
||||
# Get OCR model from session
|
||||
session = await get_ocr_labeling_session(item['session_id'])
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
||||
|
||||
# Run OCR
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
image_data,
|
||||
os.path.basename(image_path),
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
if ocr_text is None:
|
||||
raise HTTPException(status_code=500, detail="OCR failed")
|
||||
|
||||
# Update item in database
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE ocr_labeling_items
|
||||
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
||||
WHERE id = $1
|
||||
""",
|
||||
item_id, ocr_text, ocr_confidence, ocr_model
|
||||
)
|
||||
|
||||
return {
|
||||
"item_id": item_id,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"ocr_model": ocr_model,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/exports")
|
||||
async def list_exports(export_format: Optional[str] = Query(None)):
|
||||
"""
|
||||
List all available training data exports.
|
||||
|
||||
Args:
|
||||
export_format: Optional filter by format (generic, trocr, llama_vision)
|
||||
|
||||
Returns:
|
||||
List of export manifests with paths and metadata
|
||||
"""
|
||||
if not TRAINING_EXPORT_AVAILABLE:
|
||||
return {
|
||||
"exports": [],
|
||||
"message": "Training export service not available",
|
||||
}
|
||||
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
exports = export_service.list_exports(export_format=export_format)
|
||||
|
||||
return {
|
||||
"count": len(exports),
|
||||
"exports": exports,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Routes (router is the main export for app.include_router)
|
||||
from ocr_labeling_routes import router # noqa: F401
|
||||
from ocr_labeling_upload_routes import router as upload_router # noqa: F401
|
||||
|
||||
205
klausur-service/backend/ocr_labeling_helpers.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
OCR Labeling - Helper Functions and OCR Wrappers
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
|
||||
DATENSCHUTZ/PRIVACY:
|
||||
- All processing happens locally (Mac Mini with Ollama)
- No data is sent to external servers
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from ocr_labeling_models import LOCAL_STORAGE_PATH
|
||||
|
||||
# Try to import Vision OCR service
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend', 'klausur', 'services'))
|
||||
from vision_ocr_service import get_vision_ocr_service, VisionOCRService
|
||||
VISION_OCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
VISION_OCR_AVAILABLE = False
|
||||
print("Warning: Vision OCR service not available")
|
||||
|
||||
# Try to import PaddleOCR from hybrid_vocab_extractor
|
||||
try:
|
||||
from hybrid_vocab_extractor import run_paddle_ocr
|
||||
PADDLEOCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
PADDLEOCR_AVAILABLE = False
|
||||
print("Warning: PaddleOCR not available")
|
||||
|
||||
# Try to import TrOCR service
|
||||
try:
|
||||
from services.trocr_service import run_trocr_ocr
|
||||
TROCR_AVAILABLE = True
|
||||
except ImportError:
|
||||
TROCR_AVAILABLE = False
|
||||
print("Warning: TrOCR service not available")
|
||||
|
||||
# Try to import Donut service
|
||||
try:
|
||||
from services.donut_ocr_service import run_donut_ocr
|
||||
DONUT_AVAILABLE = True
|
||||
except ImportError:
|
||||
DONUT_AVAILABLE = False
|
||||
print("Warning: Donut OCR service not available")
|
||||
|
||||
# Try to import MinIO storage
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image, MINIO_BUCKET
|
||||
MINIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
MINIO_AVAILABLE = False
|
||||
print("Warning: MinIO storage not available, using local storage")
|
||||
|
||||
# Try to import Training Export Service
|
||||
try:
|
||||
from training_export_service import (
|
||||
TrainingExportService,
|
||||
TrainingSample,
|
||||
get_training_export_service,
|
||||
)
|
||||
TRAINING_EXPORT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRAINING_EXPORT_AVAILABLE = False
|
||||
print("Warning: Training export service not available")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data."""
|
||||
return hashlib.sha256(image_data).hexdigest()
|
||||
|
||||
|
||||
async def run_ocr_on_image(image_data: bytes, filename: str, model: str = "llama3.2-vision:11b") -> tuple:
|
||||
"""
|
||||
Run OCR on an image using the specified model.
|
||||
|
||||
Models:
|
||||
- llama3.2-vision:11b: Vision LLM (default, best for handwriting)
|
||||
- trocr: Microsoft TrOCR (fast for printed text)
|
||||
- paddleocr: PaddleOCR + LLM hybrid (4x faster)
|
||||
- donut: Document Understanding Transformer (structured documents)
|
||||
|
||||
Returns:
|
||||
Tuple of (ocr_text, confidence)
|
||||
"""
|
||||
print(f"Running OCR with model: {model}")
|
||||
|
||||
# Route to appropriate OCR service based on model
|
||||
if model == "paddleocr":
|
||||
return await run_paddleocr_wrapper(image_data, filename)
|
||||
elif model == "donut":
|
||||
return await run_donut_wrapper(image_data, filename)
|
||||
elif model == "trocr":
|
||||
return await run_trocr_wrapper(image_data, filename)
|
||||
else:
|
||||
# Default: Vision LLM (llama3.2-vision or similar)
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
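# Illustrative usage (not part of the original module): routing is keyed purely on
# the `model` string, so a caller only needs raw image bytes and a filename. The
# wrappers below fall back to the Vision LLM when their engine is unavailable.
#
#     image_bytes = open("sample_scan.png", "rb").read()   # hypothetical file
#     text, confidence = await run_ocr_on_image(image_bytes, "sample_scan.png", model="trocr")
#     if text is None:
#         ...  # OCR failed on every available engine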
|
||||
|
||||
|
||||
async def run_vision_ocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Vision LLM OCR wrapper."""
|
||||
if not VISION_OCR_AVAILABLE:
|
||||
print("Vision OCR service not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
service = get_vision_ocr_service()
|
||||
if not await service.is_available():
|
||||
print("Vision OCR service not available (is_available check failed)")
|
||||
return None, 0.0
|
||||
|
||||
result = await service.extract_text(
|
||||
image_data,
|
||||
filename=filename,
|
||||
is_handwriting=True
|
||||
)
|
||||
return result.text, result.confidence
|
||||
except Exception as e:
|
||||
print(f"Vision OCR failed: {e}")
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_paddleocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""PaddleOCR wrapper - uses hybrid_vocab_extractor."""
|
||||
if not PADDLEOCR_AVAILABLE:
|
||||
print("PaddleOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
# run_paddle_ocr returns (regions, raw_text)
|
||||
regions, raw_text = run_paddle_ocr(image_data)
|
||||
|
||||
if not raw_text:
|
||||
print("PaddleOCR returned empty text")
|
||||
return None, 0.0
|
||||
|
||||
# Calculate average confidence from regions
|
||||
if regions:
|
||||
avg_confidence = sum(r.confidence for r in regions) / len(regions)
|
||||
else:
|
||||
avg_confidence = 0.5
|
||||
|
||||
return raw_text, avg_confidence
|
||||
except Exception as e:
|
||||
print(f"PaddleOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_trocr_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""TrOCR wrapper."""
|
||||
if not TROCR_AVAILABLE:
|
||||
print("TrOCR not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_trocr_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"TrOCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
async def run_donut_wrapper(image_data: bytes, filename: str) -> tuple:
|
||||
"""Donut OCR wrapper."""
|
||||
if not DONUT_AVAILABLE:
|
||||
print("Donut not available, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
try:
|
||||
text, confidence = await run_donut_ocr(image_data)
|
||||
return text, confidence
|
||||
except Exception as e:
|
||||
print(f"Donut OCR failed: {e}, falling back to Vision OCR")
|
||||
return await run_vision_ocr_wrapper(image_data, filename)
|
||||
|
||||
|
||||
def save_image_locally(session_id: str, item_id: str, image_data: bytes, extension: str = "png") -> str:
|
||||
"""Save image to local storage."""
|
||||
session_dir = os.path.join(LOCAL_STORAGE_PATH, session_id)
|
||||
os.makedirs(session_dir, exist_ok=True)
|
||||
|
||||
filename = f"{item_id}.{extension}"
|
||||
filepath = os.path.join(session_dir, filename)
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def get_image_url(image_path: str) -> str:
|
||||
"""Get URL for an image."""
|
||||
# For local images, return a relative path that the frontend can use
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
relative_path = image_path[len(LOCAL_STORAGE_PATH):].lstrip('/')
|
||||
return f"/api/v1/ocr-label/images/{relative_path}"
|
||||
# For MinIO images, the path is already a URL or key
|
||||
return image_path
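# Example mapping (illustrative, assuming the default LOCAL_STORAGE_PATH of
# "/app/ocr-labeling"): a locally stored file
#     /app/ocr-labeling/<session_id>/<item_id>.png
# is rewritten to the API route
#     /api/v1/ocr-label/images/<session_id>/<item_id>.png
# while MinIO keys/URLs are returned unchanged.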
|
||||
86
klausur-service/backend/ocr_labeling_models.py
Normal file
86
klausur-service/backend/ocr_labeling_models.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
OCR Labeling - Pydantic Models and Constants
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# Local storage path (fallback if MinIO not available)
|
||||
LOCAL_STORAGE_PATH = os.getenv("OCR_STORAGE_PATH", "/app/ocr-labeling")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class SessionCreate(BaseModel):
|
||||
name: str
|
||||
source_type: str = "klausur" # klausur, handwriting_sample, scan
|
||||
description: Optional[str] = None
|
||||
ocr_model: Optional[str] = "llama3.2-vision:11b"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
source_type: str
|
||||
description: Optional[str]
|
||||
ocr_model: Optional[str]
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
skipped_items: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
session_name: str
|
||||
image_path: str
|
||||
image_url: Optional[str]
|
||||
ocr_text: Optional[str]
|
||||
ocr_confidence: Optional[float]
|
||||
ground_truth: Optional[str]
|
||||
status: str
|
||||
metadata: Optional[Dict]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ConfirmRequest(BaseModel):
|
||||
item_id: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class CorrectRequest(BaseModel):
|
||||
item_id: str
|
||||
ground_truth: str
|
||||
label_time_seconds: Optional[int] = None
|
||||
|
||||
|
||||
class SkipRequest(BaseModel):
|
||||
item_id: str
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
export_format: str = "generic" # generic, trocr, llama_vision
|
||||
session_id: Optional[str] = None
|
||||
batch_id: Optional[str] = None
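# Illustrative request body for POST /api/v1/ocr-label/export (field names match
# the model above; the concrete values are invented):
#     {"export_format": "trocr", "session_id": null, "batch_id": null}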
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_sessions: Optional[int] = None
|
||||
total_items: int
|
||||
labeled_items: int
|
||||
confirmed_items: int
|
||||
corrected_items: int
|
||||
pending_items: int
|
||||
exportable_items: Optional[int] = None
|
||||
accuracy_rate: float
|
||||
avg_label_time_seconds: Optional[float] = None
|
||||
241
klausur-service/backend/ocr_labeling_routes.py
Normal file
241
klausur-service/backend/ocr_labeling_routes.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
OCR Labeling - Session and Labeling Route Handlers
|
||||
|
||||
Extracted from ocr_labeling_api.py to keep files under 500 LOC.
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions - Create labeling session
|
||||
- GET /sessions - List sessions
|
||||
- GET /sessions/{id} - Get session
|
||||
- GET /queue - Get labeling queue
|
||||
- GET /items/{id} - Get item
|
||||
- POST /confirm - Confirm OCR
|
||||
- POST /correct - Correct ground truth
|
||||
- POST /skip - Skip item
|
||||
- GET /stats - Get statistics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from metrics_db import (
|
||||
create_ocr_labeling_session,
|
||||
get_ocr_labeling_sessions,
|
||||
get_ocr_labeling_session,
|
||||
get_ocr_labeling_queue,
|
||||
get_ocr_labeling_item,
|
||||
confirm_ocr_label,
|
||||
correct_ocr_label,
|
||||
skip_ocr_item,
|
||||
get_ocr_labeling_stats,
|
||||
)
|
||||
|
||||
from ocr_labeling_models import (
|
||||
SessionCreate, SessionResponse, ItemResponse,
|
||||
ConfirmRequest, CorrectRequest, SkipRequest,
|
||||
)
|
||||
from ocr_labeling_helpers import get_image_url
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(session: SessionCreate):
|
||||
"""Create a new OCR labeling session."""
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
success = await create_ocr_labeling_session(
|
||||
session_id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to create session")
|
||||
|
||||
return SessionResponse(
|
||||
id=session_id,
|
||||
name=session.name,
|
||||
source_type=session.source_type,
|
||||
description=session.description,
|
||||
ocr_model=session.ocr_model,
|
||||
total_items=0,
|
||||
labeled_items=0,
|
||||
confirmed_items=0,
|
||||
corrected_items=0,
|
||||
skipped_items=0,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
async def list_sessions(limit: int = Query(50, ge=1, le=100)):
|
||||
"""List all OCR labeling sessions."""
|
||||
sessions = await get_ocr_labeling_sessions(limit=limit)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s['id'],
|
||||
name=s['name'],
|
||||
source_type=s['source_type'],
|
||||
description=s.get('description'),
|
||||
ocr_model=s.get('ocr_model'),
|
||||
total_items=s.get('total_items', 0),
|
||||
labeled_items=s.get('labeled_items', 0),
|
||||
confirmed_items=s.get('confirmed_items', 0),
|
||||
corrected_items=s.get('corrected_items', 0),
|
||||
skipped_items=s.get('skipped_items', 0),
|
||||
created_at=s.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=SessionResponse)
|
||||
async def get_session(session_id: str):
|
||||
"""Get a specific OCR labeling session."""
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionResponse(
|
||||
id=session['id'],
|
||||
name=session['name'],
|
||||
source_type=session['source_type'],
|
||||
description=session.get('description'),
|
||||
ocr_model=session.get('ocr_model'),
|
||||
total_items=session.get('total_items', 0),
|
||||
labeled_items=session.get('labeled_items', 0),
|
||||
confirmed_items=session.get('confirmed_items', 0),
|
||||
corrected_items=session.get('corrected_items', 0),
|
||||
skipped_items=session.get('skipped_items', 0),
|
||||
created_at=session.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Queue and Item Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/queue", response_model=List[ItemResponse])
|
||||
async def get_labeling_queue(
|
||||
session_id: Optional[str] = Query(None),
|
||||
status: str = Query("pending"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
):
|
||||
"""Get items from the labeling queue."""
|
||||
items = await get_ocr_labeling_queue(
|
||||
session_id=session_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.get("/items/{item_id}", response_model=ItemResponse)
|
||||
async def get_item(item_id: str):
|
||||
"""Get a specific labeling item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ItemResponse(
|
||||
id=item['id'],
|
||||
session_id=item['session_id'],
|
||||
session_name=item.get('session_name', ''),
|
||||
image_path=item['image_path'],
|
||||
image_url=get_image_url(item['image_path']),
|
||||
ocr_text=item.get('ocr_text'),
|
||||
ocr_confidence=item.get('ocr_confidence'),
|
||||
ground_truth=item.get('ground_truth'),
|
||||
status=item.get('status', 'pending'),
|
||||
metadata=item.get('metadata'),
|
||||
created_at=item.get('created_at', datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Labeling Action Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/confirm")
|
||||
async def confirm_item(request: ConfirmRequest):
|
||||
"""Confirm that OCR text is correct."""
|
||||
success = await confirm_ocr_label(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin",
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to confirm item")
|
||||
|
||||
return {"status": "confirmed", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/correct")
|
||||
async def correct_item(request: CorrectRequest):
|
||||
"""Save corrected ground truth for an item."""
|
||||
success = await correct_ocr_label(
|
||||
item_id=request.item_id,
|
||||
ground_truth=request.ground_truth,
|
||||
labeled_by="admin",
|
||||
label_time_seconds=request.label_time_seconds,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to correct item")
|
||||
|
||||
return {"status": "corrected", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.post("/skip")
|
||||
async def skip_item(request: SkipRequest):
|
||||
"""Skip an item (unusable image, etc.)."""
|
||||
success = await skip_ocr_item(
|
||||
item_id=request.item_id,
|
||||
labeled_by="admin",
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to skip item")
|
||||
|
||||
return {"status": "skipped", "item_id": request.item_id}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(session_id: Optional[str] = Query(None)):
|
||||
"""Get labeling statistics."""
|
||||
stats = await get_ocr_labeling_stats(session_id=session_id)
|
||||
|
||||
if "error" in stats:
|
||||
raise HTTPException(status_code=500, detail=stats["error"])
|
||||
|
||||
return stats
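# Sketch of the labeling loop against these endpoints (illustrative only; the
# base URL is an assumption and authentication is omitted):
#
#     import httpx
#
#     async def label_next(base="http://localhost:8000/api/v1/ocr-label"):
#         async with httpx.AsyncClient() as client:
#             queue = (await client.get(f"{base}/queue", params={"limit": 1})).json()
#             if not queue:
#                 return
#             item = queue[0]
#             if item["ocr_text"]:  # accept the OCR result as-is
#                 await client.post(f"{base}/confirm", json={"item_id": item["id"]})
#             else:                 # or supply a corrected transcription
#                 await client.post(f"{base}/correct",
#                                   json={"item_id": item["id"], "ground_truth": "..."})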
|
||||
313
klausur-service/backend/ocr_labeling_upload_routes.py
Normal file
313
klausur-service/backend/ocr_labeling_upload_routes.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
OCR Labeling - Upload, Run-OCR, and Export Route Handlers
|
||||
|
||||
Extracted from ocr_labeling_routes.py to keep files under 500 LOC.
|
||||
|
||||
Endpoints:
|
||||
- POST /sessions/{id}/upload - Upload images for labeling
|
||||
- POST /run-ocr/{item_id} - Run OCR on existing item
|
||||
- POST /export - Export training data
|
||||
- GET /training-samples - List training samples
|
||||
- GET /images/{path} - Serve images from local storage
|
||||
- GET /exports - List exports
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from typing import Optional, List
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from metrics_db import (
|
||||
get_ocr_labeling_session,
|
||||
add_ocr_labeling_item,
|
||||
get_ocr_labeling_item,
|
||||
export_training_samples,
|
||||
get_training_samples,
|
||||
)
|
||||
|
||||
from ocr_labeling_models import (
|
||||
ExportRequest,
|
||||
LOCAL_STORAGE_PATH,
|
||||
)
|
||||
from ocr_labeling_helpers import (
|
||||
compute_image_hash, run_ocr_on_image,
|
||||
save_image_locally,
|
||||
MINIO_AVAILABLE, TRAINING_EXPORT_AVAILABLE,
|
||||
)
|
||||
|
||||
# Conditional imports
|
||||
try:
|
||||
from minio_storage import upload_ocr_image, get_ocr_image
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from training_export_service import TrainingSample, get_training_export_service
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-label", tags=["OCR Labeling"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/upload")
|
||||
async def upload_images(
|
||||
session_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
run_ocr: bool = Form(True),
|
||||
metadata: Optional[str] = Form(None),
|
||||
):
|
||||
"""
|
||||
Upload images to a labeling session.
|
||||
|
||||
Args:
|
||||
session_id: Session to add images to
|
||||
files: Image files to upload (PNG, JPG, PDF)
|
||||
run_ocr: Whether to run OCR immediately (default: True)
|
||||
metadata: Optional JSON metadata (subject, year, etc.)
|
||||
"""
|
||||
import json
|
||||
|
||||
session = await get_ocr_labeling_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
meta_dict = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
meta_dict = {"raw": metadata}
|
||||
|
||||
results = []
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b')
|
||||
|
||||
for file in files:
|
||||
content = await file.read()
|
||||
image_hash = compute_image_hash(content)
|
||||
item_id = str(uuid.uuid4())
|
||||
|
||||
extension = file.filename.split('.')[-1].lower() if file.filename else 'png'
|
||||
if extension not in ['png', 'jpg', 'jpeg', 'pdf']:
|
||||
extension = 'png'
|
||||
|
||||
if MINIO_AVAILABLE:
|
||||
try:
|
||||
image_path = upload_ocr_image(session_id, item_id, content, extension)
|
||||
except Exception as e:
|
||||
print(f"MinIO upload failed, using local storage: {e}")
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
else:
|
||||
image_path = save_image_locally(session_id, item_id, content, extension)
|
||||
|
||||
ocr_text = None
|
||||
ocr_confidence = None
|
||||
|
||||
if run_ocr and extension != 'pdf':
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
content,
|
||||
file.filename or f"{item_id}.{extension}",
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
success = await add_ocr_labeling_item(
|
||||
item_id=item_id,
|
||||
session_id=session_id,
|
||||
image_path=image_path,
|
||||
image_hash=image_hash,
|
||||
ocr_text=ocr_text,
|
||||
ocr_confidence=ocr_confidence,
|
||||
ocr_model=ocr_model if ocr_text else None,
|
||||
metadata=meta_dict,
|
||||
)
|
||||
|
||||
if success:
|
||||
results.append({
|
||||
"id": item_id,
|
||||
"filename": file.filename,
|
||||
"image_path": image_path,
|
||||
"image_hash": image_hash,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"status": "pending",
|
||||
})
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"uploaded_count": len(results),
|
||||
"items": results,
|
||||
}
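# Illustrative upload call (field names as defined above; base URL, session id and
# metadata values are invented):
#
#     files = [("files", ("klausur_page1.png", open("klausur_page1.png", "rb"), "image/png"))]
#     data = {"run_ocr": "true", "metadata": '{"subject": "english", "year": 7}'}
#     resp = httpx.post(f"{base}/sessions/{session_id}/upload", files=files, data=data)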
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_data(request: ExportRequest):
|
||||
"""Export labeled data for training."""
|
||||
db_samples = await export_training_samples(
|
||||
export_format=request.export_format,
|
||||
session_id=request.session_id,
|
||||
batch_id=request.batch_id,
|
||||
exported_by="admin",
|
||||
)
|
||||
|
||||
if not db_samples:
|
||||
return {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id,
|
||||
"exported_count": 0,
|
||||
"samples": [],
|
||||
"message": "No labeled samples found to export",
|
||||
}
|
||||
|
||||
export_result = None
|
||||
if TRAINING_EXPORT_AVAILABLE:
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
|
||||
training_samples = []
|
||||
for s in db_samples:
|
||||
training_samples.append(TrainingSample(
|
||||
id=s.get('id', s.get('item_id', '')),
|
||||
image_path=s.get('image_path', ''),
|
||||
ground_truth=s.get('ground_truth', ''),
|
||||
ocr_text=s.get('ocr_text'),
|
||||
ocr_confidence=s.get('ocr_confidence'),
|
||||
metadata=s.get('metadata'),
|
||||
))
|
||||
|
||||
export_result = export_service.export(
|
||||
samples=training_samples,
|
||||
export_format=request.export_format,
|
||||
batch_id=request.batch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Training export failed: {e}")
|
||||
|
||||
response = {
|
||||
"export_format": request.export_format,
|
||||
"batch_id": request.batch_id or (export_result.batch_id if export_result else None),
|
||||
"exported_count": len(db_samples),
|
||||
"samples": db_samples,
|
||||
}
|
||||
|
||||
if export_result:
|
||||
response["export_path"] = export_result.export_path
|
||||
response["manifest_path"] = export_result.manifest_path
|
||||
|
||||
return response
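# Illustrative export call (format value from ExportRequest; export_path and
# manifest_path depend on the training export service configuration and are not
# guaranteed here):
#     resp = httpx.post(f"{base}/export", json={"export_format": "trocr"})
#     resp.json()  # {"export_format": "trocr", "batch_id": ..., "exported_count": N,
#                  #  "samples": [...], "export_path": ..., "manifest_path": ...}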
|
||||
|
||||
|
||||
@router.get("/training-samples")
|
||||
async def list_training_samples(
|
||||
export_format: Optional[str] = Query(None),
|
||||
batch_id: Optional[str] = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
):
|
||||
"""Get exported training samples."""
|
||||
samples = await get_training_samples(
|
||||
export_format=export_format,
|
||||
batch_id=batch_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(samples),
|
||||
"samples": samples,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/images/{path:path}")
|
||||
async def get_image(path: str):
|
||||
"""Serve an image from local storage."""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
filepath = os.path.join(LOCAL_STORAGE_PATH, path)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
extension = filepath.split('.')[-1].lower()
|
||||
content_type = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'pdf': 'application/pdf',
|
||||
}.get(extension, 'application/octet-stream')
|
||||
|
||||
return FileResponse(filepath, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/run-ocr/{item_id}")
|
||||
async def run_ocr_for_item(item_id: str):
|
||||
"""Run OCR on an existing item."""
|
||||
item = await get_ocr_labeling_item(item_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
image_path = item['image_path']
|
||||
|
||||
if image_path.startswith(LOCAL_STORAGE_PATH):
|
||||
if not os.path.exists(image_path):
|
||||
raise HTTPException(status_code=404, detail="Image file not found")
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
elif MINIO_AVAILABLE:
|
||||
try:
|
||||
image_data = get_ocr_image(image_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load image: {e}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Cannot load image")
|
||||
|
||||
session = await get_ocr_labeling_session(item['session_id'])
|
||||
ocr_model = session.get('ocr_model', 'llama3.2-vision:11b') if session else 'llama3.2-vision:11b'
|
||||
|
||||
ocr_text, ocr_confidence = await run_ocr_on_image(
|
||||
image_data,
|
||||
os.path.basename(image_path),
|
||||
model=ocr_model
|
||||
)
|
||||
|
||||
if ocr_text is None:
|
||||
raise HTTPException(status_code=500, detail="OCR failed")
|
||||
|
||||
from metrics_db import get_pool
|
||||
pool = await get_pool()
|
||||
if pool:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE ocr_labeling_items
|
||||
SET ocr_text = $2, ocr_confidence = $3, ocr_model = $4
|
||||
WHERE id = $1
|
||||
""",
|
||||
item_id, ocr_text, ocr_confidence, ocr_model
|
||||
)
|
||||
|
||||
return {
|
||||
"item_id": item_id,
|
||||
"ocr_text": ocr_text,
|
||||
"ocr_confidence": ocr_confidence,
|
||||
"ocr_model": ocr_model,
|
||||
}
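# Illustrative re-run (item id invented): POST /api/v1/ocr-label/run-ocr/<item_id>
# refreshes text/confidence, persists them on the item row, and returns
#     {"item_id": ..., "ocr_text": ..., "ocr_confidence": ..., "ocr_model": ...}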
|
||||
|
||||
|
||||
@router.get("/exports")
|
||||
async def list_exports(export_format: Optional[str] = Query(None)):
|
||||
"""List all available training data exports."""
|
||||
if not TRAINING_EXPORT_AVAILABLE:
|
||||
return {
|
||||
"exports": [],
|
||||
"message": "Training export service not available",
|
||||
}
|
||||
|
||||
try:
|
||||
export_service = get_training_export_service()
|
||||
exports = export_service.list_exports(export_format=export_format)
|
||||
|
||||
return {
|
||||
"count": len(exports),
|
||||
"exports": exports,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list exports: {e}")
|
||||
@@ -1,705 +1,23 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints.
|
||||
OCR Pipeline Auto-Mode Orchestrator and Reprocess Endpoints — Barrel Re-export.
|
||||
|
||||
Extracted from ocr_pipeline_api.py — contains:
|
||||
- POST /sessions/{session_id}/reprocess (clear downstream + restart from step)
|
||||
- POST /sessions/{session_id}/run-auto (full auto-mode with SSE streaming)
|
||||
Split into submodules:
|
||||
- ocr_pipeline_reprocess.py — POST /sessions/{id}/reprocess
|
||||
- ocr_pipeline_auto_steps.py — POST /sessions/{id}/run-auto + VLM helper
|
||||
|
||||
License: Apache 2.0
|
||||
DATENSCHUTZ/PRIVACY: All processing happens locally.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_detect_header_footer_gaps,
|
||||
_detect_sub_columns,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
analyze_layout,
|
||||
build_cell_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_row_geometry,
|
||||
_apply_shear,
|
||||
dewarp_image,
|
||||
llm_review_entries,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
_get_base_image_png,
|
||||
_append_pipeline_log,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ocr_pipeline_reprocess import router as _reprocess_router
|
||||
from ocr_pipeline_auto_steps import router as _steps_router
|
||||
|
||||
# Combine both sub-routers into a single router for backwards compatibility.
|
||||
# The consumer imports `from ocr_pipeline_auto import router as _auto_router`.
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
router.include_router(_reprocess_router)
|
||||
router.include_router(_steps_router)
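# Illustrative consumer (hypothetical app module): the combined router keeps the
# old import path working, so existing wiring does not change.
#
#     from fastapi import FastAPI
#     from ocr_pipeline_auto import router as _auto_router
#
#     app = FastAPI()
#     app.include_router(_auto_router)   # exposes /api/v1/ocr-pipeline/...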
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reprocess endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reprocess")
|
||||
async def reprocess_session(session_id: str, request: Request):
|
||||
"""Re-run pipeline from a specific step, clearing downstream data.
|
||||
|
||||
Body: {"from_step": 5} (1-indexed step number)
|
||||
|
||||
Pipeline order: Orientation(1) → Deskew(2) → Dewarp(3) → Crop(4) → Columns(5) →
|
||||
Rows(6) → Words(7) → LLM-Review(8) → Reconstruction(9) → Validation(10)
|
||||
|
||||
Clears downstream results:
|
||||
- from_step <= 1: orientation_result + all downstream
|
||||
- from_step <= 2: deskew_result + all downstream
|
||||
- from_step <= 3: dewarp_result + all downstream
|
||||
- from_step <= 4: crop_result + all downstream
|
||||
- from_step <= 5: column_result, row_result, word_result
|
||||
- from_step <= 6: row_result, word_result
|
||||
- from_step <= 7: word_result (cells, vocab_entries)
|
||||
- from_step <= 8: word_result.llm_review only
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
from_step = body.get("from_step", 1)
|
||||
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
|
||||
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
|
||||
|
||||
update_kwargs: Dict[str, Any] = {"current_step": from_step}
|
||||
|
||||
# Clear downstream data based on from_step
|
||||
# New pipeline order: Orient(2) → Deskew(3) → Dewarp(4) → Crop(5) →
|
||||
# Columns(6) → Rows(7) → Words(8) → LLM(9) → Recon(10) → GT(11)
|
||||
if from_step <= 8:
|
||||
update_kwargs["word_result"] = None
|
||||
elif from_step == 9:
|
||||
# Only clear LLM review from word_result
|
||||
word_result = session.get("word_result")
|
||||
if word_result:
|
||||
word_result.pop("llm_review", None)
|
||||
word_result.pop("llm_corrections", None)
|
||||
update_kwargs["word_result"] = word_result
|
||||
|
||||
if from_step <= 7:
|
||||
update_kwargs["row_result"] = None
|
||||
if from_step <= 6:
|
||||
update_kwargs["column_result"] = None
|
||||
if from_step <= 4:
|
||||
update_kwargs["crop_result"] = None
|
||||
if from_step <= 3:
|
||||
update_kwargs["dewarp_result"] = None
|
||||
if from_step <= 2:
|
||||
update_kwargs["deskew_result"] = None
|
||||
if from_step <= 1:
|
||||
update_kwargs["orientation_result"] = None
|
||||
|
||||
await update_session_db(session_id, **update_kwargs)
|
||||
|
||||
# Also clear cache
|
||||
if session_id in _cache:
|
||||
for key in list(update_kwargs.keys()):
|
||||
if key != "current_step":
|
||||
_cache[session_id][key] = update_kwargs[key]
|
||||
_cache[session_id]["current_step"] = from_step
|
||||
|
||||
logger.info(f"Session {session_id} reprocessing from step {from_step}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"from_step": from_step,
|
||||
"cleared": [k for k in update_kwargs if k != "current_step"],
|
||||
}
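# Illustrative request/response pair (session id invented): POST
# /api/v1/ocr-pipeline/sessions/<id>/reprocess with body {"from_step": 6}
# clears column_result, row_result and word_result and returns
#     {"session_id": "<id>", "from_step": 6,
#      "cleared": ["word_result", "row_result", "column_result"]}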
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLM shear detection helper (used by dewarp step in auto-mode)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
|
||||
|
||||
The VLM is shown the image and asked: are the column/table borders tilted?
|
||||
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
|
||||
Confidence is 0.0 if Ollama is unavailable or parsing fails.
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
import re
|
||||
|
||||
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
|
||||
prompt = (
|
||||
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
|
||||
"Are they perfectly vertical, or do they tilt slightly? "
|
||||
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
|
||||
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
|
||||
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
|
||||
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
text = resp.json().get("response", "")
|
||||
|
||||
# Parse JSON from response (may have surrounding text)
|
||||
match = re.search(r'\{[^}]+\}', text)
|
||||
if match:
|
||||
import json
|
||||
data = json.loads(match.group(0))
|
||||
shear = float(data.get("shear_degrees", 0.0))
|
||||
conf = float(data.get("confidence", 0.0))
|
||||
# Clamp to reasonable range
|
||||
shear = max(-3.0, min(3.0, shear))
|
||||
conf = max(0.0, min(1.0, conf))
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VLM dewarp failed: {e}")
|
||||
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-mode orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RunAutoRequest(BaseModel):
|
||||
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
|
||||
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
|
||||
pronunciation: str = "british"
|
||||
skip_llm_review: bool = False
|
||||
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
|
||||
|
||||
|
||||
async def _auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
|
||||
"""Format a single SSE event line."""
|
||||
import json as _json
|
||||
payload = {"step": step, "status": status, **data}
|
||||
return f"data: {_json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/run-auto")
|
||||
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
|
||||
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
|
||||
|
||||
Steps:
|
||||
1. Deskew — straighten the scan
|
||||
2. Dewarp — correct vertical shear (ensemble CV or VLM)
|
||||
3. Columns — detect column layout
|
||||
4. Rows — detect row layout
|
||||
5. Words — OCR each cell
|
||||
6. LLM review — correct OCR errors (optional)
|
||||
|
||||
Steps earlier than `from_step` are skipped; everything from `from_step` onward is re-run.
|
||||
Yields SSE events of the form:
|
||||
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
|
||||
|
||||
Final event:
|
||||
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
|
||||
"""
|
||||
if req.from_step < 1 or req.from_step > 6:
|
||||
raise HTTPException(status_code=400, detail="from_step must be 1-6")
|
||||
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
|
||||
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
|
||||
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
|
||||
async def _generate():
|
||||
steps_run: List[str] = []
|
||||
steps_skipped: List[str] = []
|
||||
error_step: Optional[str] = None
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
|
||||
return
|
||||
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 1: Deskew
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 1:
|
||||
yield await _auto_sse_event("deskew", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is None:
|
||||
raise ValueError("Original image not loaded")
|
||||
|
||||
# Method 1: Hough lines
|
||||
try:
|
||||
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
|
||||
except Exception:
|
||||
deskewed_hough, angle_hough = orig_bgr, 0.0
|
||||
|
||||
# Method 2: Word alignment
|
||||
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
|
||||
orig_bytes = png_orig.tobytes() if success_enc else b""
|
||||
try:
|
||||
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
|
||||
except Exception:
|
||||
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
|
||||
|
||||
# Pick best method
|
||||
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||
method_used = "word_alignment"
|
||||
angle_applied = angle_wa
|
||||
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
|
||||
if deskewed_bgr is None:
|
||||
deskewed_bgr = deskewed_hough
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
else:
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
deskewed_bgr = deskewed_hough
|
||||
|
||||
success, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
deskewed_png = png_buf.tobytes() if success else b""
|
||||
|
||||
deskew_result = {
|
||||
"method_used": method_used,
|
||||
"rotation_degrees": round(float(angle_applied), 3),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["deskewed_bgr"] = deskewed_bgr
|
||||
cached["deskew_result"] = deskew_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
deskewed_png=deskewed_png,
|
||||
deskew_result=deskew_result,
|
||||
auto_rotation_degrees=float(angle_applied),
|
||||
current_step=3,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "done", deskew_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
|
||||
error_step = "deskew"
|
||||
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 2: Dewarp
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 2:
|
||||
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
|
||||
try:
|
||||
t0 = time.time()
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise ValueError("Deskewed image not available")
|
||||
|
||||
if req.dewarp_method == "vlm":
|
||||
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
img_bytes = png_buf.tobytes() if success_enc else b""
|
||||
vlm_det = await _detect_shear_with_vlm(img_bytes)
|
||||
shear_deg = vlm_det["shear_degrees"]
|
||||
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
|
||||
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
|
||||
else:
|
||||
dewarped_bgr = deskewed_bgr
|
||||
dewarp_info = {
|
||||
"method": vlm_det["method"],
|
||||
"shear_degrees": shear_deg,
|
||||
"confidence": vlm_det["confidence"],
|
||||
"detections": [vlm_det],
|
||||
}
|
||||
else:
|
||||
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||
|
||||
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success_enc else b""
|
||||
|
||||
dewarp_result = {
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
"detections": dewarp_info.get("detections", []),
|
||||
}
|
||||
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||
current_step=4,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "done", dewarp_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
|
||||
error_step = "dewarp"
|
||||
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 3: Columns
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 3:
|
||||
yield await _auto_sse_event("columns", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if col_img is None:
|
||||
raise ValueError("Cropped/dewarped image not available")
|
||||
|
||||
ocr_img = create_ocr_image(col_img)
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
geo_result = detect_column_geometry(ocr_img, col_img)
|
||||
if geo_result is None:
|
||||
layout_img = create_layout_image(col_img)
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
cached["_word_dicts"] = None
|
||||
cached["_inv"] = None
|
||||
cached["_content_bounds"] = None
|
||||
else:
|
||||
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
content_w = right_x - left_x
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
|
||||
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
|
||||
top_y=top_y, header_y=header_y, footer_y=footer_y)
|
||||
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
|
||||
left_x=left_x, right_x=right_x, inv=inv)
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
column_result = {
|
||||
"columns": columns,
|
||||
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["column_result"] = column_result
|
||||
await update_session_db(session_id, column_result=column_result,
|
||||
row_result=None, word_result=None, current_step=6)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("columns")
|
||||
yield await _auto_sse_event("columns", "done", {
|
||||
"column_count": len(columns),
|
||||
"duration_seconds": column_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
|
||||
error_step = "columns"
|
||||
yield await _auto_sse_event("columns", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("columns")
|
||||
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 4: Rows
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 4:
|
||||
yield await _auto_sse_event("rows", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise ValueError("Column detection must complete first")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
inv = cached.get("_inv")
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
|
||||
if word_dicts is None or inv is None or content_bounds is None:
|
||||
ocr_img_tmp = create_ocr_image(row_img)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
|
||||
if geo_result is None:
|
||||
raise ValueError("Column geometry detection failed — cannot detect rows")
|
||||
_g, lx, rx, ty, by, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (lx, rx, ty, by)
|
||||
content_bounds = (lx, rx, ty, by)
|
||||
|
||||
left_x, right_x, top_y, bottom_y = content_bounds
|
||||
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||
|
||||
row_list = [
|
||||
{
|
||||
"index": r.index, "x": r.x, "y": r.y,
|
||||
"width": r.width, "height": r.height,
|
||||
"word_count": r.word_count,
|
||||
"row_type": r.row_type,
|
||||
"gap_before": r.gap_before,
|
||||
}
|
||||
for r in row_geoms
|
||||
]
|
||||
row_result = {
|
||||
"rows": row_list,
|
||||
"row_count": len(row_list),
|
||||
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["row_result"] = row_result
|
||||
await update_session_db(session_id, row_result=row_result, current_step=7)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("rows")
|
||||
yield await _auto_sse_event("rows", "done", {
|
||||
"row_count": len(row_list),
|
||||
"content_rows": row_result["content_rows"],
|
||||
"duration_seconds": row_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
|
||||
error_step = "rows"
|
||||
yield await _auto_sse_event("rows", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("rows")
|
||||
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 5: Words (OCR)
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 5:
|
||||
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
|
||||
try:
|
||||
t0 = time.time()
|
||||
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
row_result = session.get("row_result") or cached.get("row_result")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"], x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0), words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is not None:
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
|
||||
for row in row_geoms:
|
||||
row_y_rel = row.y - top_y
|
||||
row_bottom_rel = row_y_rel + row.height
|
||||
row.words = [
|
||||
w for w in word_dicts
|
||||
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||
]
|
||||
row.word_count = len(row.words)
|
||||
|
||||
ocr_img = create_ocr_image(word_img)
|
||||
img_h, img_w = word_img.shape[:2]
|
||||
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=req.ocr_engine, img_bgr=word_img,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
|
||||
|
||||
# Apply IPA phonetic fixes directly to cell texts
|
||||
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
|
||||
|
||||
word_result_data = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": len(columns_meta),
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
has_text_col = 'column_text' in col_types
|
||||
if is_vocab or has_text_col:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_character_confusion(entries)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
|
||||
word_result_data["vocab_entries"] = entries
|
||||
word_result_data["entries"] = entries
|
||||
word_result_data["entry_count"] = len(entries)
|
||||
word_result_data["summary"]["total_entries"] = len(entries)
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_data, current_step=8)
|
||||
cached["word_result"] = word_result_data
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("words")
|
||||
yield await _auto_sse_event("words", "done", {
|
||||
"total_cells": len(cells),
|
||||
"layout": word_result_data["layout"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": word_result_data["summary"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode words failed for {session_id}: {e}")
|
||||
error_step = "words"
|
||||
yield await _auto_sse_event("words", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("words")
|
||||
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Step 6: LLM Review (optional)
|
||||
# -----------------------------------------------------------------
|
||||
if req.from_step <= 6 and not req.skip_llm_review:
|
||||
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
|
||||
try:
|
||||
session = await get_session_db(session_id)
|
||||
word_result = session.get("word_result") or cached.get("word_result")
|
||||
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
|
||||
|
||||
if not entries:
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
reviewed = await llm_review_entries(entries)
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
word_result_updated = dict(session.get("word_result") or {})
|
||||
word_result_updated["entries"] = reviewed
|
||||
word_result_updated["vocab_entries"] = reviewed
|
||||
word_result_updated["llm_reviewed"] = True
|
||||
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
|
||||
cached["word_result"] = word_result_updated
|
||||
|
||||
steps_run.append("llm_review")
|
||||
yield await _auto_sse_event("llm_review", "done", {
|
||||
"entries_reviewed": len(reviewed),
|
||||
"model": OLLAMA_REVIEW_MODEL,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
|
||||
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
steps_skipped.append("llm_review")
|
||||
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Final event
|
||||
# -----------------------------------------------------------------
|
||||
yield await _auto_sse_event("complete", "done", {
|
||||
"steps_run": steps_run,
|
||||
"steps_skipped": steps_skipped,
|
||||
})
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
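# Sketch of a client consuming this stream (illustrative; URL and payload invented):
#
#     async with httpx.AsyncClient(timeout=None) as client:
#         async with client.stream("POST", f"{base}/sessions/{sid}/run-auto",
#                                   json={"from_step": 1}) as resp:
#             async for line in resp.aiter_lines():
#                 if line.startswith("data: "):
#                     event = json.loads(line[6:])   # {"step": ..., "status": ...}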
|
||||
__all__ = ["router"]
|
||||
|
||||
84
klausur-service/backend/ocr_pipeline_auto_helpers.py
Normal file
84
klausur-service/backend/ocr_pipeline_auto_helpers.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Helpers.
|
||||
|
||||
VLM shear detection, SSE event formatting, and request models.
|
||||
|
||||
License: Apache 2.0
|
||||
DATENSCHUTZ/PRIVACY: All processing happens locally.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunAutoRequest(BaseModel):
|
||||
from_step: int = 1 # 1=deskew, 2=dewarp, 3=columns, 4=rows, 5=words, 6=llm-review
|
||||
ocr_engine: str = "auto" # "auto" | "rapid" | "tesseract"
|
||||
pronunciation: str = "british"
|
||||
skip_llm_review: bool = False
|
||||
dewarp_method: str = "ensemble" # "ensemble" | "vlm" | "cv"
|
||||
|
||||
|
||||
async def auto_sse_event(step: str, status: str, data: Dict[str, Any]) -> str:
|
||||
"""Format a single SSE event line."""
|
||||
payload = {"step": step, "status": status, **data}
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
async def detect_shear_with_vlm(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Ask qwen2.5vl:32b to estimate the vertical shear angle of a scanned page.
|
||||
|
||||
The VLM is shown the image and asked: are the column/table borders tilted?
|
||||
If yes, by how many degrees? Returns a dict with shear_degrees and confidence.
|
||||
Confidence is 0.0 if Ollama is unavailable or parsing fails.
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
|
||||
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
|
||||
|
||||
prompt = (
|
||||
"This is a scanned vocabulary worksheet. Look at the vertical borders of the table columns. "
|
||||
"Are they perfectly vertical, or do they tilt slightly? "
|
||||
"If they tilt, estimate the tilt angle in degrees (positive = top tilts right, negative = top tilts left). "
|
||||
"Reply with ONLY a JSON object like: {\"shear_degrees\": 1.2, \"confidence\": 0.8} "
|
||||
"Use confidence 0.0-1.0 based on how clearly you can see the tilt. "
|
||||
"If the columns look straight, return {\"shear_degrees\": 0.0, \"confidence\": 0.9}"
|
||||
)
|
||||
|
||||
img_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"images": [img_b64],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
|
||||
resp.raise_for_status()
|
||||
text = resp.json().get("response", "")
|
||||
|
||||
# Parse JSON from response (may have surrounding text)
|
||||
match = re.search(r'\{[^}]+\}', text)
|
||||
if match:
|
||||
data = json.loads(match.group(0))
|
||||
shear = float(data.get("shear_degrees", 0.0))
|
||||
conf = float(data.get("confidence", 0.0))
|
||||
# Clamp to reasonable range
|
||||
shear = max(-3.0, min(3.0, shear))
|
||||
conf = max(0.0, min(1.0, conf))
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": round(shear, 3), "confidence": round(conf, 2)}
|
||||
except Exception as e:
|
||||
logger.warning(f"VLM dewarp failed: {e}")
|
||||
|
||||
return {"method": "vlm_qwen2.5vl", "shear_degrees": 0.0, "confidence": 0.0}
|
||||
528
klausur-service/backend/ocr_pipeline_auto_steps.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
OCR Pipeline Auto-Mode Orchestrator.
|
||||
|
||||
POST /sessions/{session_id}/run-auto -- full auto-mode with SSE streaming.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_detect_header_footer_gaps,
|
||||
_detect_sub_columns,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
analyze_layout,
|
||||
build_cell_grid,
|
||||
classify_column_types,
|
||||
create_layout_image,
|
||||
create_ocr_image,
|
||||
deskew_image,
|
||||
deskew_image_by_word_alignment,
|
||||
detect_column_geometry,
|
||||
detect_row_geometry,
|
||||
_apply_shear,
|
||||
dewarp_image,
|
||||
llm_review_entries,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_auto_helpers import (
|
||||
RunAutoRequest,
|
||||
auto_sse_event as _auto_sse_event,
|
||||
detect_shear_with_vlm as _detect_shear_with_vlm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["ocr-pipeline"])
|
||||
|
||||
@router.post("/sessions/{session_id}/run-auto")
|
||||
async def run_auto(session_id: str, req: RunAutoRequest, request: Request):
|
||||
"""Run the full OCR pipeline automatically from a given step, streaming SSE progress.
|
||||
|
||||
Steps:
|
||||
1. Deskew -- straighten the scan
|
||||
2. Dewarp -- correct vertical shear (ensemble CV or VLM)
|
||||
3. Columns -- detect column layout
|
||||
4. Rows -- detect row layout
|
||||
5. Words -- OCR each cell
|
||||
6. LLM review -- correct OCR errors (optional)
|
||||
|
||||
Already-completed steps are skipped unless `from_step` forces a rerun.
|
||||
Yields SSE events of the form:
|
||||
data: {"step": "deskew", "status": "start"|"done"|"skipped"|"error", ...}
|
||||
|
||||
Final event:
|
||||
data: {"step": "complete", "status": "done", "steps_run": [...], "steps_skipped": [...]}
|
||||
"""
|
||||
if req.from_step < 1 or req.from_step > 6:
|
||||
raise HTTPException(status_code=400, detail="from_step must be 1-6")
|
||||
if req.dewarp_method not in ("ensemble", "vlm", "cv"):
|
||||
raise HTTPException(status_code=400, detail="dewarp_method must be: ensemble, vlm, cv")
|
||||
|
||||
if session_id not in _cache:
|
||||
await _load_session_to_cache(session_id)
|
||||
|
||||
async def _generate():
|
||||
steps_run: List[str] = []
|
||||
steps_skipped: List[str] = []
|
||||
error_step: Optional[str] = None
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
yield await _auto_sse_event("error", "error", {"message": f"Session {session_id} not found"})
|
||||
return
|
||||
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
# Step 1: Deskew
|
||||
if req.from_step <= 1:
|
||||
yield await _auto_sse_event("deskew", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
orig_bgr = cached.get("original_bgr")
|
||||
if orig_bgr is None:
|
||||
raise ValueError("Original image not loaded")
|
||||
|
||||
try:
|
||||
deskewed_hough, angle_hough = deskew_image(orig_bgr.copy())
|
||||
except Exception:
|
||||
deskewed_hough, angle_hough = orig_bgr, 0.0
|
||||
|
||||
success_enc, png_orig = cv2.imencode(".png", orig_bgr)
|
||||
orig_bytes = png_orig.tobytes() if success_enc else b""
|
||||
try:
|
||||
deskewed_wa_bytes, angle_wa = deskew_image_by_word_alignment(orig_bytes)
|
||||
except Exception:
|
||||
deskewed_wa_bytes, angle_wa = orig_bytes, 0.0
|
||||
|
||||
if abs(angle_wa) >= abs(angle_hough) or abs(angle_hough) < 0.1:
|
||||
method_used = "word_alignment"
|
||||
angle_applied = angle_wa
|
||||
wa_arr = np.frombuffer(deskewed_wa_bytes, dtype=np.uint8)
|
||||
deskewed_bgr = cv2.imdecode(wa_arr, cv2.IMREAD_COLOR)
|
||||
if deskewed_bgr is None:
|
||||
deskewed_bgr = deskewed_hough
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
else:
|
||||
method_used = "hough"
|
||||
angle_applied = angle_hough
|
||||
deskewed_bgr = deskewed_hough
|
||||
|
||||
success, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
deskewed_png = png_buf.tobytes() if success else b""
|
||||
|
||||
deskew_result = {
|
||||
"method_used": method_used,
|
||||
"rotation_degrees": round(float(angle_applied), 3),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["deskewed_bgr"] = deskewed_bgr
|
||||
cached["deskew_result"] = deskew_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
deskewed_png=deskewed_png,
|
||||
deskew_result=deskew_result,
|
||||
auto_rotation_degrees=float(angle_applied),
|
||||
current_step=3,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "done", deskew_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode deskew failed for {session_id}: {e}")
|
||||
error_step = "deskew"
|
||||
yield await _auto_sse_event("deskew", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("deskew")
|
||||
yield await _auto_sse_event("deskew", "skipped", {"reason": "from_step > 1"})
|
||||
|
||||
# Step 2: Dewarp
|
||||
if req.from_step <= 2:
|
||||
yield await _auto_sse_event("dewarp", "start", {"method": req.dewarp_method})
|
||||
try:
|
||||
t0 = time.time()
|
||||
deskewed_bgr = cached.get("deskewed_bgr")
|
||||
if deskewed_bgr is None:
|
||||
raise ValueError("Deskewed image not available")
|
||||
|
||||
if req.dewarp_method == "vlm":
|
||||
success_enc, png_buf = cv2.imencode(".png", deskewed_bgr)
|
||||
img_bytes = png_buf.tobytes() if success_enc else b""
|
||||
vlm_det = await _detect_shear_with_vlm(img_bytes)
|
||||
shear_deg = vlm_det["shear_degrees"]
|
||||
if abs(shear_deg) >= 0.05 and vlm_det["confidence"] >= 0.3:
|
||||
dewarped_bgr = _apply_shear(deskewed_bgr, -shear_deg)
|
||||
else:
|
||||
dewarped_bgr = deskewed_bgr
|
||||
dewarp_info = {
|
||||
"method": vlm_det["method"],
|
||||
"shear_degrees": shear_deg,
|
||||
"confidence": vlm_det["confidence"],
|
||||
"detections": [vlm_det],
|
||||
}
|
||||
else:
|
||||
dewarped_bgr, dewarp_info = dewarp_image(deskewed_bgr)
|
||||
|
||||
success_enc, png_buf = cv2.imencode(".png", dewarped_bgr)
|
||||
dewarped_png = png_buf.tobytes() if success_enc else b""
|
||||
|
||||
dewarp_result = {
|
||||
"method_used": dewarp_info["method"],
|
||||
"shear_degrees": dewarp_info["shear_degrees"],
|
||||
"confidence": dewarp_info["confidence"],
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
"detections": dewarp_info.get("detections", []),
|
||||
}
|
||||
|
||||
cached["dewarped_bgr"] = dewarped_bgr
|
||||
cached["dewarp_result"] = dewarp_result
|
||||
await update_session_db(
|
||||
session_id,
|
||||
dewarped_png=dewarped_png,
|
||||
dewarp_result=dewarp_result,
|
||||
auto_shear_degrees=dewarp_info.get("shear_degrees", 0.0),
|
||||
current_step=4,
|
||||
)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "done", dewarp_result)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode dewarp failed for {session_id}: {e}")
|
||||
error_step = "dewarp"
|
||||
yield await _auto_sse_event("dewarp", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("dewarp")
|
||||
yield await _auto_sse_event("dewarp", "skipped", {"reason": "from_step > 2"})
|
||||
|
||||
# Step 3: Columns
|
||||
if req.from_step <= 3:
|
||||
yield await _auto_sse_event("columns", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
col_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if col_img is None:
|
||||
raise ValueError("Cropped/dewarped image not available")
|
||||
|
||||
ocr_img = create_ocr_image(col_img)
|
||||
h, w = ocr_img.shape[:2]
|
||||
|
||||
geo_result = detect_column_geometry(ocr_img, col_img)
|
||||
if geo_result is None:
|
||||
layout_img = create_layout_image(col_img)
|
||||
regions = analyze_layout(layout_img, ocr_img)
|
||||
cached["_word_dicts"] = None
|
||||
cached["_inv"] = None
|
||||
cached["_content_bounds"] = None
|
||||
else:
|
||||
geometries, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
content_w = right_x - left_x
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
header_y, footer_y = _detect_header_footer_gaps(inv, w, h) if inv is not None else (None, None)
|
||||
geometries = _detect_sub_columns(geometries, content_w, left_x=left_x,
|
||||
top_y=top_y, header_y=header_y, footer_y=footer_y)
|
||||
regions = classify_column_types(geometries, content_w, top_y, w, h, bottom_y,
|
||||
left_x=left_x, right_x=right_x, inv=inv)
|
||||
|
||||
columns = [asdict(r) for r in regions]
|
||||
column_result = {
|
||||
"columns": columns,
|
||||
"classification_methods": list({c.get("classification_method", "") for c in columns if c.get("classification_method")}),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["column_result"] = column_result
|
||||
await update_session_db(session_id, column_result=column_result,
|
||||
row_result=None, word_result=None, current_step=6)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("columns")
|
||||
yield await _auto_sse_event("columns", "done", {
|
||||
"column_count": len(columns),
|
||||
"duration_seconds": column_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode columns failed for {session_id}: {e}")
|
||||
error_step = "columns"
|
||||
yield await _auto_sse_event("columns", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("columns")
|
||||
yield await _auto_sse_event("columns", "skipped", {"reason": "from_step > 3"})
|
||||
|
||||
# Step 4: Rows
|
||||
if req.from_step <= 4:
|
||||
yield await _auto_sse_event("rows", "start", {})
|
||||
try:
|
||||
t0 = time.time()
|
||||
row_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
raise ValueError("Column detection must complete first")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
inv = cached.get("_inv")
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
|
||||
if word_dicts is None or inv is None or content_bounds is None:
|
||||
ocr_img_tmp = create_ocr_image(row_img)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, row_img)
|
||||
if geo_result is None:
|
||||
raise ValueError("Column geometry detection failed -- cannot detect rows")
|
||||
_g, lx, rx, ty, by, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (lx, rx, ty, by)
|
||||
content_bounds = (lx, rx, ty, by)
|
||||
|
||||
left_x, right_x, top_y, bottom_y = content_bounds
|
||||
row_geoms = detect_row_geometry(inv, word_dicts, left_x, right_x, top_y, bottom_y)
|
||||
|
||||
row_list = [
|
||||
{
|
||||
"index": r.index, "x": r.x, "y": r.y,
|
||||
"width": r.width, "height": r.height,
|
||||
"word_count": r.word_count,
|
||||
"row_type": r.row_type,
|
||||
"gap_before": r.gap_before,
|
||||
}
|
||||
for r in row_geoms
|
||||
]
|
||||
row_result = {
|
||||
"rows": row_list,
|
||||
"row_count": len(row_list),
|
||||
"content_rows": len([r for r in row_geoms if r.row_type == "content"]),
|
||||
"duration_seconds": round(time.time() - t0, 2),
|
||||
}
|
||||
|
||||
cached["row_result"] = row_result
|
||||
await update_session_db(session_id, row_result=row_result, current_step=7)
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("rows")
|
||||
yield await _auto_sse_event("rows", "done", {
|
||||
"row_count": len(row_list),
|
||||
"content_rows": row_result["content_rows"],
|
||||
"duration_seconds": row_result["duration_seconds"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode rows failed for {session_id}: {e}")
|
||||
error_step = "rows"
|
||||
yield await _auto_sse_event("rows", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("rows")
|
||||
yield await _auto_sse_event("rows", "skipped", {"reason": "from_step > 4"})
|
||||
|
||||
# Step 5: Words (OCR)
|
||||
if req.from_step <= 5:
|
||||
yield await _auto_sse_event("words", "start", {"engine": req.ocr_engine})
|
||||
try:
|
||||
t0 = time.time()
|
||||
word_img = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
column_result = session.get("column_result") or cached.get("column_result")
|
||||
row_result = session.get("row_result") or cached.get("row_result")
|
||||
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"], x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"], x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0), words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is not None:
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
top_y = content_bounds[2] if content_bounds else min(r.y for r in row_geoms)
|
||||
for row in row_geoms:
|
||||
row_y_rel = row.y - top_y
|
||||
row_bottom_rel = row_y_rel + row.height
|
||||
row.words = [
|
||||
w for w in word_dicts
|
||||
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||
]
|
||||
row.word_count = len(row.words)
|
||||
|
||||
ocr_img = create_ocr_image(word_img)
|
||||
img_h, img_w = word_img.shape[:2]
|
||||
|
||||
cells, columns_meta = build_cell_grid(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=req.ocr_engine, img_bgr=word_img,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else req.ocr_engine
|
||||
|
||||
fix_cell_phonetics(cells, pronunciation=req.pronunciation)
|
||||
|
||||
word_result_data = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": len(columns_meta),
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
has_text_col = 'column_text' in col_types
|
||||
if is_vocab or has_text_col:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_character_confusion(entries)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=req.pronunciation)
|
||||
word_result_data["vocab_entries"] = entries
|
||||
word_result_data["entries"] = entries
|
||||
word_result_data["entry_count"] = len(entries)
|
||||
word_result_data["summary"]["total_entries"] = len(entries)
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_data, current_step=8)
|
||||
cached["word_result"] = word_result_data
|
||||
session = await get_session_db(session_id)
|
||||
|
||||
steps_run.append("words")
|
||||
yield await _auto_sse_event("words", "done", {
|
||||
"total_cells": len(cells),
|
||||
"layout": word_result_data["layout"],
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": word_result_data["summary"],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-mode words failed for {session_id}: {e}")
|
||||
error_step = "words"
|
||||
yield await _auto_sse_event("words", "error", {"message": str(e)})
|
||||
yield await _auto_sse_event("complete", "error", {"error_step": error_step})
|
||||
return
|
||||
else:
|
||||
steps_skipped.append("words")
|
||||
yield await _auto_sse_event("words", "skipped", {"reason": "from_step > 5"})
|
||||
|
||||
# Step 6: LLM Review (optional)
|
||||
if req.from_step <= 6 and not req.skip_llm_review:
|
||||
yield await _auto_sse_event("llm_review", "start", {"model": OLLAMA_REVIEW_MODEL})
|
||||
try:
|
||||
session = await get_session_db(session_id)
|
||||
word_result = session.get("word_result") or cached.get("word_result")
|
||||
entries = word_result.get("entries") or word_result.get("vocab_entries") or []
|
||||
|
||||
if not entries:
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": "no entries"})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
reviewed = await llm_review_entries(entries)
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
word_result_updated = dict(session.get("word_result") or {})
|
||||
word_result_updated["entries"] = reviewed
|
||||
word_result_updated["vocab_entries"] = reviewed
|
||||
word_result_updated["llm_reviewed"] = True
|
||||
word_result_updated["llm_model"] = OLLAMA_REVIEW_MODEL
|
||||
|
||||
await update_session_db(session_id, word_result=word_result_updated, current_step=9)
|
||||
cached["word_result"] = word_result_updated
|
||||
|
||||
steps_run.append("llm_review")
|
||||
yield await _auto_sse_event("llm_review", "done", {
|
||||
"entries_reviewed": len(reviewed),
|
||||
"model": OLLAMA_REVIEW_MODEL,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-mode llm_review failed for {session_id} (non-fatal): {e}")
|
||||
yield await _auto_sse_event("llm_review", "error", {"message": str(e), "fatal": False})
|
||||
steps_skipped.append("llm_review")
|
||||
else:
|
||||
steps_skipped.append("llm_review")
|
||||
reason = "skipped by request" if req.skip_llm_review else "from_step > 6"
|
||||
yield await _auto_sse_event("llm_review", "skipped", {"reason": reason})
|
||||
|
||||
# Final event
|
||||
yield await _auto_sse_event("complete", "done", {
|
||||
"steps_run": steps_run,
|
||||
"steps_skipped": steps_skipped,
|
||||
})
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
94
klausur-service/backend/ocr_pipeline_reprocess.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
OCR Pipeline Reprocess Endpoint.
|
||||
|
||||
POST /sessions/{session_id}/reprocess — clear downstream + restart from step.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from ocr_pipeline_common import _cache
|
||||
from ocr_pipeline_session_store import get_session_db, update_session_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reprocess")
|
||||
async def reprocess_session(session_id: str, request: Request):
|
||||
"""Re-run pipeline from a specific step, clearing downstream data.
|
||||
|
||||
Body: {"from_step": 5} (1-indexed step number)
|
||||
|
||||
Pipeline order: Orientation(1) -> Deskew(2) -> Dewarp(3) -> Crop(4) -> Columns(5) ->
|
||||
Rows(6) -> Words(7) -> LLM-Review(8) -> Reconstruction(9) -> Validation(10)
|
||||
|
||||
Clears downstream results:
|
||||
- from_step <= 1: orientation_result + all downstream
|
||||
- from_step <= 2: deskew_result + all downstream
|
||||
- from_step <= 3: dewarp_result + all downstream
|
||||
- from_step <= 4: crop_result + all downstream
|
||||
- from_step <= 6: column_result, row_result, word_result
- from_step <= 7: row_result, word_result
- from_step <= 8: word_result (cells, vocab_entries)
- from_step == 9: only the LLM-review fields inside word_result
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
body = await request.json()
|
||||
from_step = body.get("from_step", 1)
|
||||
if not isinstance(from_step, int) or from_step < 1 or from_step > 10:
|
||||
raise HTTPException(status_code=400, detail="from_step must be between 1 and 10")
|
||||
|
||||
update_kwargs: Dict[str, Any] = {"current_step": from_step}
|
||||
|
||||
# Clear downstream data based on from_step
|
||||
# New pipeline order: Orient(2) -> Deskew(3) -> Dewarp(4) -> Crop(5) ->
|
||||
# Columns(6) -> Rows(7) -> Words(8) -> LLM(9) -> Recon(10) -> GT(11)
|
||||
if from_step <= 8:
|
||||
update_kwargs["word_result"] = None
|
||||
elif from_step == 9:
|
||||
# Only clear LLM review from word_result
|
||||
word_result = session.get("word_result")
|
||||
if word_result:
|
||||
word_result.pop("llm_review", None)
|
||||
word_result.pop("llm_corrections", None)
|
||||
update_kwargs["word_result"] = word_result
|
||||
|
||||
if from_step <= 7:
|
||||
update_kwargs["row_result"] = None
|
||||
if from_step <= 6:
|
||||
update_kwargs["column_result"] = None
|
||||
if from_step <= 4:
|
||||
update_kwargs["crop_result"] = None
|
||||
if from_step <= 3:
|
||||
update_kwargs["dewarp_result"] = None
|
||||
if from_step <= 2:
|
||||
update_kwargs["deskew_result"] = None
|
||||
if from_step <= 1:
|
||||
update_kwargs["orientation_result"] = None
|
||||
|
||||
await update_session_db(session_id, **update_kwargs)
|
||||
|
||||
# Also clear cache
|
||||
if session_id in _cache:
|
||||
for key in list(update_kwargs.keys()):
|
||||
if key != "current_step":
|
||||
_cache[session_id][key] = update_kwargs[key]
|
||||
_cache[session_id]["current_step"] = from_step
|
||||
|
||||
logger.info(f"Session {session_id} reprocessing from step {from_step}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"from_step": from_step,
|
||||
"cleared": [k for k in update_kwargs if k != "current_step"],
|
||||
}
|
||||
@@ -1,758 +1,33 @@
|
||||
"""
|
||||
Page Crop - Content-based crop for scanned pages and book scans.
|
||||
Page Crop — Barrel Re-export
|
||||
|
||||
Detects the content boundary by analysing ink density projections and
|
||||
(for book scans) the spine shadow gradient. Works with both loose A4
|
||||
sheets on dark scanners AND book scans with white backgrounds.
|
||||
Content-based crop for scanned pages and book scans.
|
||||
|
||||
Split into:
|
||||
- page_crop_edges.py — Edge detection (spine shadow, gutter, projection)
|
||||
- page_crop_core.py — Main crop algorithm and format detection
|
||||
|
||||
All public names are re-exported here for backward compatibility.
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
# Core: main crop functions and format detection
|
||||
from page_crop_core import ( # noqa: F401
|
||||
PAPER_FORMATS,
|
||||
detect_page_splits,
|
||||
detect_and_crop_page,
|
||||
_detect_format,
|
||||
)
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known paper format aspect ratios (height / width, portrait orientation)
|
||||
PAPER_FORMATS = {
|
||||
"A4": 297.0 / 210.0, # 1.4143
|
||||
"A5": 210.0 / 148.0, # 1.4189
|
||||
"Letter": 11.0 / 8.5, # 1.2941
|
||||
"Legal": 14.0 / 8.5, # 1.6471
|
||||
"A3": 420.0 / 297.0, # 1.4141
|
||||
}
|
||||
|
||||
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
||||
_INK_THRESHOLD = 0.003 # 0.3%
|
||||
|
||||
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
||||
_MIN_RUN_FRAC = 0.005 # 0.5%
|
||||
|
||||
|
||||
def detect_page_splits(
|
||||
img_bgr: np.ndarray,
|
||||
) -> list:
|
||||
"""Detect if the image is a multi-page spread and return split rectangles.
|
||||
|
||||
Uses **brightness** (not ink density) to find the spine area:
|
||||
the scanner bed produces a characteristic gray strip where pages meet,
|
||||
which is darker than the white paper on either side.
|
||||
|
||||
Returns a list of page dicts ``{x, y, width, height, page_index}``
|
||||
or an empty list if only one page is detected.
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
|
||||
# Only check landscape-ish images (width > height * 1.15)
|
||||
if w < h * 1.15:
|
||||
return []
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
|
||||
col_brightness = np.mean(gray, axis=0).astype(np.float64)
|
||||
|
||||
# Heavy smoothing to ignore individual text lines
|
||||
kern = max(11, w // 50)
|
||||
if kern % 2 == 0:
|
||||
kern += 1
|
||||
brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")
|
||||
|
||||
# Page paper is bright (typically > 200), spine/scanner bed is darker
|
||||
page_brightness = float(np.max(brightness_smooth))
|
||||
if page_brightness < 100:
|
||||
return [] # Very dark image, skip
|
||||
|
||||
# Spine threshold: significantly darker than the page
|
||||
# Spine is typically 60-80% of paper brightness
|
||||
spine_thresh = page_brightness * 0.88
|
||||
|
||||
# Search in center region (30-70% of width)
|
||||
center_lo = int(w * 0.30)
|
||||
center_hi = int(w * 0.70)
|
||||
|
||||
# Find the darkest valley in the center region
|
||||
center_brightness = brightness_smooth[center_lo:center_hi]
|
||||
darkest_val = float(np.min(center_brightness))
|
||||
|
||||
if darkest_val >= spine_thresh:
|
||||
logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
|
||||
darkest_val, spine_thresh)
|
||||
return []
|
||||
|
||||
# Find ALL contiguous dark runs in the center region
|
||||
is_dark = center_brightness < spine_thresh
|
||||
dark_runs: list = [] # list of (start, end) pairs
|
||||
run_start = -1
|
||||
for i in range(len(is_dark)):
|
||||
if is_dark[i]:
|
||||
if run_start < 0:
|
||||
run_start = i
|
||||
else:
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, i))
|
||||
run_start = -1
|
||||
if run_start >= 0:
|
||||
dark_runs.append((run_start, len(is_dark)))
|
||||
|
||||
# Filter out runs that are too narrow (< 1% of image width)
|
||||
min_spine_px = int(w * 0.01)
|
||||
dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]
|
||||
|
||||
if not dark_runs:
|
||||
logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
|
||||
return []
|
||||
|
||||
# Score each dark run: prefer centered, dark, narrow valleys
|
||||
center_region_len = center_hi - center_lo
|
||||
image_center_in_region = (w * 0.5 - center_lo) # x=50% mapped into region coords
|
||||
best_score = -1.0
|
||||
best_start, best_end = dark_runs[0]
|
||||
|
||||
for rs, re in dark_runs:
|
||||
run_width = re - rs
|
||||
run_center = (rs + re) / 2.0
|
||||
|
||||
# --- Factor 1: Proximity to image center (gaussian, sigma = 15% of region) ---
|
||||
sigma = center_region_len * 0.15
|
||||
dist = abs(run_center - image_center_in_region)
|
||||
center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))
|
||||
|
||||
# --- Factor 2: Darkness (how dark is the valley relative to threshold) ---
|
||||
run_brightness = float(np.mean(center_brightness[rs:re]))
|
||||
# Normalize: 1.0 when run_brightness == 0, 0.0 when run_brightness == spine_thresh
|
||||
darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)
|
||||
|
||||
# --- Factor 3: Narrowness bonus (spine shadows are narrow, not wide plateaus) ---
|
||||
# Typical spine: 1-5% of image width. Full bonus up to 5%, no bonus beyond 15%.
|
||||
width_frac = run_width / w
|
||||
if width_frac <= 0.05:
|
||||
narrowness_bonus = 1.0
|
||||
elif width_frac <= 0.15:
|
||||
narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10 # linear decay 1.0 → 0.0
|
||||
else:
|
||||
narrowness_bonus = 0.0
|
||||
|
||||
score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)
|
||||
|
||||
logger.debug(
|
||||
"Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f → score=%.4f",
|
||||
center_lo + rs, center_lo + re, run_width,
|
||||
center_factor, darkness_factor, narrowness_bonus, score,
|
||||
)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_start, best_end = rs, re
|
||||
|
||||
spine_w = best_end - best_start
|
||||
spine_x = center_lo + best_start
|
||||
spine_center = spine_x + spine_w // 2
|
||||
|
||||
logger.debug(
|
||||
"Best spine candidate: x=%d..%d (w=%d), score=%.4f",
|
||||
spine_x, spine_x + spine_w, spine_w, best_score,
|
||||
)
|
||||
|
||||
# Verify: must have bright (paper) content on BOTH sides
|
||||
left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
|
||||
right_end = center_lo + best_end
|
||||
right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))
|
||||
|
||||
if left_brightness < spine_thresh or right_brightness < spine_thresh:
|
||||
logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
|
||||
left_brightness, right_brightness, spine_thresh)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
|
||||
"left_paper=%.0f, right_paper=%.0f",
|
||||
spine_x, right_end, spine_w, darkest_val, page_brightness,
|
||||
left_brightness, right_brightness,
|
||||
)
|
||||
|
||||
# Split at the spine center
|
||||
split_points = [spine_center]
|
||||
|
||||
# Build page rectangles
|
||||
pages: list = []
|
||||
prev_x = 0
|
||||
for i, sx in enumerate(split_points):
|
||||
pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
|
||||
"height": h, "page_index": i})
|
||||
prev_x = sx
|
||||
pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
|
||||
"height": h, "page_index": len(split_points)})
|
||||
|
||||
# Filter out tiny pages (< 15% of total width)
|
||||
pages = [p for p in pages if p["width"] >= w * 0.15]
|
||||
if len(pages) < 2:
|
||||
return []
|
||||
|
||||
# Re-index
|
||||
for i, p in enumerate(pages):
|
||||
p["page_index"] = i
|
||||
|
||||
logger.info(
|
||||
"Page split detected: %d pages, spine_w=%d, split_points=%s",
|
||||
len(pages), spine_w, split_points,
|
||||
)
|
||||
return pages
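# Illustrative usage (image path and sizes are assumptions):
#   img = cv2.imread("scan_spread.png")           # e.g. a 3300x2200 landscape spread
#   pages = detect_page_splits(img)
#   # e.g. [{"x": 0, "y": 0, "width": 1620, "height": 2200, "page_index": 0},
#   #       {"x": 1620, "y": 0, "width": 1680, "height": 2200, "page_index": 1}]
#   crops = [img[p["y"]:p["y"] + p["height"], p["x"]:p["x"] + p["width"]] for p in pages]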
|
||||
|
||||
|
||||
def detect_and_crop_page(
|
||||
img_bgr: np.ndarray,
|
||||
margin_frac: float = 0.01,
|
||||
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
||||
"""Detect content boundary and crop scanner/book borders.
|
||||
|
||||
Algorithm (4-edge detection):
|
||||
1. Adaptive threshold → binary (text=255, bg=0)
|
||||
2. Left edge: spine-shadow detection via grayscale column means,
|
||||
fallback to binary vertical projection
|
||||
3. Right edge: binary vertical projection (last ink column)
|
||||
4. Top/bottom edges: binary horizontal projection
|
||||
5. Sanity checks, then crop with configurable margin
|
||||
|
||||
Args:
|
||||
img_bgr: Input BGR image (should already be deskewed/dewarped)
|
||||
margin_frac: Extra margin around content (fraction of dimension, default 1%)
|
||||
|
||||
Returns:
|
||||
Tuple of (cropped_image, result_dict)
|
||||
"""
|
||||
h, w = img_bgr.shape[:2]
|
||||
total_area = h * w
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"crop_applied": False,
|
||||
"crop_rect": None,
|
||||
"crop_rect_pct": None,
|
||||
"original_size": {"width": w, "height": h},
|
||||
"cropped_size": {"width": w, "height": h},
|
||||
"detected_format": None,
|
||||
"format_confidence": 0.0,
|
||||
"aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
|
||||
"border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
|
||||
}
|
||||
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# --- Binarise with adaptive threshold (works for white-on-white) ---
|
||||
binary = cv2.adaptiveThreshold(
|
||||
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV, blockSize=51, C=15,
|
||||
)
|
||||
|
||||
# --- Left edge: spine-shadow detection ---
|
||||
left_edge = _detect_left_edge_shadow(gray, binary, w, h)
|
||||
|
||||
# --- Right edge: spine-shadow detection ---
|
||||
right_edge = _detect_right_edge_shadow(gray, binary, w, h)
|
||||
|
||||
# --- Top / bottom edges: binary horizontal projection ---
|
||||
top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)
|
||||
|
||||
# Compute border fractions
|
||||
border_top = top_edge / h
|
||||
border_bottom = (h - bottom_edge) / h
|
||||
border_left = left_edge / w
|
||||
border_right = (w - right_edge) / w
|
||||
|
||||
result["border_fractions"] = {
|
||||
"top": round(border_top, 4),
|
||||
"bottom": round(border_bottom, 4),
|
||||
"left": round(border_left, 4),
|
||||
"right": round(border_right, 4),
|
||||
}
|
||||
|
||||
# Sanity: only crop if at least one edge has > 2% border
|
||||
min_border = 0.02
|
||||
if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
|
||||
logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
# Add margin
|
||||
margin_x = int(w * margin_frac)
|
||||
margin_y = int(h * margin_frac)
|
||||
|
||||
crop_x = max(0, left_edge - margin_x)
|
||||
crop_y = max(0, top_edge - margin_y)
|
||||
crop_x2 = min(w, right_edge + margin_x)
|
||||
crop_y2 = min(h, bottom_edge + margin_y)
|
||||
|
||||
crop_w = crop_x2 - crop_x
|
||||
crop_h = crop_y2 - crop_y
|
||||
|
||||
# Sanity: cropped area must be >= 40% of original
|
||||
if crop_w * crop_h < 0.40 * total_area:
|
||||
logger.warning("Cropped area too small (%.0f%%) — skipping crop",
|
||||
100.0 * crop_w * crop_h / total_area)
|
||||
result["detected_format"], result["format_confidence"] = _detect_format(w, h)
|
||||
return img_bgr, result
|
||||
|
||||
cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()
|
||||
|
||||
detected_format, format_confidence = _detect_format(crop_w, crop_h)
|
||||
|
||||
result["crop_applied"] = True
|
||||
result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
|
||||
result["crop_rect_pct"] = {
|
||||
"x": round(100.0 * crop_x / w, 2),
|
||||
"y": round(100.0 * crop_y / h, 2),
|
||||
"width": round(100.0 * crop_w / w, 2),
|
||||
"height": round(100.0 * crop_h / h, 2),
|
||||
}
|
||||
result["cropped_size"] = {"width": crop_w, "height": crop_h}
|
||||
result["detected_format"] = detected_format
|
||||
result["format_confidence"] = format_confidence
|
||||
result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)
|
||||
|
||||
logger.info(
|
||||
"Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
|
||||
"borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
|
||||
w, h, crop_w, crop_h, detected_format, format_confidence * 100,
|
||||
border_top * 100, border_bottom * 100,
|
||||
border_left * 100, border_right * 100,
|
||||
)
|
||||
|
||||
return cropped, result
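# Illustrative usage (not part of this commit):
#   cropped, info = detect_and_crop_page(img, margin_frac=0.01)
#   if info["crop_applied"]:
#       print(info["crop_rect"], info["detected_format"], info["format_confidence"])
#   # otherwise `cropped` is the unchanged input (all borders < 2%, or the crop
#   # was rejected because it would keep less than 40% of the area)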
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge detection helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_spine_shadow(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Find the book spine center (darkest point) in a scanner shadow.
|
||||
|
||||
The scanner produces a gray strip where the book spine presses against
|
||||
the glass. The darkest column in that strip is the spine center —
|
||||
that's where we crop.
|
||||
|
||||
Distinguishes real spine shadows from text content by checking:
|
||||
1. Strong brightness range (> 40 levels)
|
||||
2. Darkest point is genuinely dark (< 180 mean brightness)
|
||||
3. The dark area is a NARROW valley, not a text-content plateau
|
||||
4. Brightness rises significantly toward the page content side
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image (for context).
|
||||
search_region: Column slice of the grayscale image to search in.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right' (for logging).
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the spine center, or None.
|
||||
"""
|
||||
region_w = search_region.shape[1]
|
||||
if region_w < 10:
|
||||
return None
|
||||
|
||||
# Column-mean brightness in the search region
|
||||
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
||||
|
||||
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
||||
kernel_size = max(5, w // 100)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
kernel = np.ones(kernel_size) / kernel_size
|
||||
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
||||
|
||||
# Trim convolution edge artifacts (edges are zero-padded → artificially low)
|
||||
margin = kernel_size // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
smoothed = smoothed_raw[margin:region_w - margin]
|
||||
trim_offset = margin # offset of smoothed[0] relative to search_region
|
||||
|
||||
val_min = float(np.min(smoothed))
|
||||
val_max = float(np.max(smoothed))
|
||||
shadow_range = val_max - val_min
|
||||
|
||||
# --- Check 1: Strong brightness gradient ---
|
||||
if shadow_range <= 40:
|
||||
logger.debug(
|
||||
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 2: Darkest point must be genuinely dark ---
|
||||
# Spine shadows have mean column brightness 60-160.
|
||||
# Text on white paper stays above 180.
|
||||
if val_min > 180:
|
||||
logger.debug(
|
||||
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
||||
spine_local = spine_idx + trim_offset # index in search_region
|
||||
trimmed_len = len(smoothed)
|
||||
|
||||
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
||||
# Count how many columns are within 20% of the shadow range above the min.
|
||||
valley_thresh = val_min + shadow_range * 0.20
|
||||
valley_mask = smoothed < valley_thresh
|
||||
valley_width = int(np.sum(valley_mask))
|
||||
# Spine valleys are typically 3-15% of image width (20-120px on a 800px image).
|
||||
# Text content plateaus span 20%+ of the search region.
|
||||
max_valley_frac = 0.50 # valley must not cover more than half the trimmed region
|
||||
if valley_width > trimmed_len * max_valley_frac:
|
||||
logger.debug(
|
||||
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
||||
side.capitalize(), valley_width, trimmed_len,
|
||||
100.0 * valley_width / trimmed_len,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 4: Brightness must rise toward page content ---
|
||||
# For left edge: after spine, brightness should rise (= page paper)
|
||||
# For right edge: before spine, brightness should rise
|
||||
rise_check_w = max(5, trimmed_len // 5) # check 20% of trimmed region
|
||||
if side == "left":
|
||||
# Check columns to the right of the spine (in trimmed array)
|
||||
right_start = min(spine_idx + 5, trimmed_len - 1)
|
||||
right_end = min(right_start + rise_check_w, trimmed_len)
|
||||
if right_end > right_start:
|
||||
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
else: # right
|
||||
# Check columns to the left of the spine (in trimmed array)
|
||||
left_end = max(spine_idx - 5, 0)
|
||||
left_start = max(left_end - rise_check_w, 0)
|
||||
if left_end > left_start:
|
||||
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_x = offset_x + spine_local
|
||||
|
||||
logger.info(
|
||||
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
||||
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
||||
)
|
||||
return spine_x
|
||||
|
||||
|
||||
def _detect_gutter_continuity(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Detect gutter shadow via vertical continuity analysis.
|
||||
|
||||
Camera book scans produce a subtle brightness gradient at the gutter
|
||||
that is too faint for scanner-shadow detection (range < 40). However,
|
||||
the gutter shadow has a unique property: it runs **continuously from
|
||||
top to bottom** without interruption. Text and images always have
|
||||
vertical gaps between lines, paragraphs, or sections.
|
||||
|
||||
Algorithm:
|
||||
1. Divide image into N horizontal strips (~60px each)
|
||||
2. For each column, compute what fraction of strips are darker than
|
||||
the page median (from the center 50% of the full image)
|
||||
3. A "gutter column" has ≥ 75% of strips darker than page_median − δ
|
||||
4. Smooth the dark-fraction profile and find the transition point
|
||||
from the edge inward where the fraction drops below 0.50
|
||||
5. Validate: gutter band must be 0.5%-10% of image width
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image.
|
||||
search_region: Edge slice of the grayscale image.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right'.
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the gutter inner edge, or None.
|
||||
"""
|
||||
region_h, region_w = search_region.shape[:2]
|
||||
if region_w < 20 or region_h < 100:
|
||||
return None
|
||||
|
||||
# --- 1. Divide into horizontal strips ---
|
||||
strip_target_h = 60 # ~60px per strip
|
||||
n_strips = max(10, region_h // strip_target_h)
|
||||
strip_h = region_h // n_strips
|
||||
|
||||
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
||||
for s in range(n_strips):
|
||||
y0 = s * strip_h
|
||||
y1 = min((s + 1) * strip_h, region_h)
|
||||
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
||||
|
||||
# --- 2. Page median from center 50% of full image ---
|
||||
center_lo = w // 4
|
||||
center_hi = 3 * w // 4
|
||||
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
||||
|
||||
# Camera shadows are subtle — threshold just 5 levels below page median
|
||||
dark_thresh = page_median - 5.0
|
||||
|
||||
# If page is very dark overall (e.g. photo, not a book page), bail out
|
||||
if page_median < 180:
|
||||
return None
|
||||
|
||||
# --- 3. Per-column dark fraction ---
|
||||
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
||||
dark_frac = dark_count / n_strips # shape: (region_w,)
|
||||
|
||||
# --- 4. Smooth and find transition ---
|
||||
# Rolling mean (window = 1% of image width, min 5)
|
||||
smooth_w = max(5, w // 100)
|
||||
if smooth_w % 2 == 0:
|
||||
smooth_w += 1
|
||||
kernel = np.ones(smooth_w) / smooth_w
|
||||
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
||||
|
||||
# Trim convolution edges
|
||||
margin = smooth_w // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
|
||||
# Find the peak of dark fraction (gutter center).
|
||||
# For right gutters the peak is near the edge; for left gutters
|
||||
# (V-shaped spine shadow) the peak may be well inside the region.
|
||||
transition_thresh = 0.50
|
||||
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
||||
|
||||
if peak_frac < 0.70:
|
||||
logger.debug(
|
||||
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
||||
)
|
||||
return None
|
||||
|
||||
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
||||
gutter_inner = None # local x in search_region
|
||||
|
||||
if side == "right":
|
||||
# Scan from peak toward the page center (leftward)
|
||||
for x in range(peak_x, margin, -1):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x + 1
|
||||
break
|
||||
else:
|
||||
# Scan from peak toward the page center (rightward)
|
||||
for x in range(peak_x, region_w - margin):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x - 1
|
||||
break
|
||||
|
||||
if gutter_inner is None:
|
||||
return None
|
||||
|
||||
# --- 5. Validate gutter width ---
|
||||
if side == "right":
|
||||
gutter_width = region_w - gutter_inner
|
||||
else:
|
||||
gutter_width = gutter_inner
|
||||
|
||||
min_gutter = max(3, int(w * 0.005)) # at least 0.5% of image
|
||||
max_gutter = int(w * 0.10) # at most 10% of image
|
||||
|
||||
if gutter_width < min_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
||||
gutter_width, min_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if gutter_width > max_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
||||
gutter_width, max_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
# Check that the gutter band is meaningfully darker than the page
|
||||
if side == "right":
|
||||
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
||||
else:
|
||||
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
||||
|
||||
brightness_drop = page_median - gutter_brightness
|
||||
if brightness_drop < 3:
|
||||
logger.debug(
|
||||
"%s gutter: insufficient brightness drop (%.1f levels)",
|
||||
side.capitalize(), brightness_drop,
|
||||
)
|
||||
return None
|
||||
|
||||
gutter_x = offset_x + gutter_inner
|
||||
|
||||
logger.info(
|
||||
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
||||
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
||||
side.capitalize(), gutter_x, gutter_width,
|
||||
100.0 * gutter_width / w, gutter_brightness, page_median,
|
||||
brightness_drop, float(frac_smooth[gutter_inner]),
|
||||
)
|
||||
return gutter_x
|
||||
|
||||
|
||||
def _detect_left_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect left content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (first ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
# Fallback 1: vertical continuity (camera gutter shadow)
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
# Fallback 2: binary vertical projection
|
||||
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
||||
|
||||
|
||||
def _detect_right_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect right content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (last ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
right_start = w - search_w
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
# Fallback 1: vertical continuity (camera gutter shadow)
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
# Fallback 2: binary vertical projection
|
||||
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
||||
|
||||
|
||||
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
|
||||
"""Detect top and bottom content edges via binary horizontal projection."""
|
||||
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
|
||||
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
|
||||
return top, bottom
|
||||
|
||||
|
||||
def _detect_edge_projection(
|
||||
binary: np.ndarray,
|
||||
axis: int,
|
||||
from_start: bool,
|
||||
dim: int,
|
||||
) -> int:
|
||||
"""Find the first/last row or column with ink density above threshold.
|
||||
|
||||
axis=0 → project vertically (column densities) → returns x position
|
||||
axis=1 → project horizontally (row densities) → returns y position
|
||||
|
||||
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
|
||||
"""
|
||||
# Compute density per row/column (mean of binary pixels / 255)
|
||||
projection = np.mean(binary, axis=axis) / 255.0
|
||||
|
||||
# Create mask of "ink" positions
|
||||
ink_mask = projection >= _INK_THRESHOLD
|
||||
|
||||
# Filter narrow runs (noise)
|
||||
min_run = max(1, int(dim * _MIN_RUN_FRAC))
|
||||
ink_mask = _filter_narrow_runs(ink_mask, min_run)
|
||||
|
||||
ink_positions = np.where(ink_mask)[0]
|
||||
if len(ink_positions) == 0:
|
||||
return 0 if from_start else dim
|
||||
|
||||
if from_start:
|
||||
return int(ink_positions[0])
|
||||
else:
|
||||
return int(ink_positions[-1])
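# Illustrative: for a 2000px-wide binary page with ink between x=140 and x=1860,
# _detect_edge_projection(binary, axis=0, from_start=True, dim=2000) returns 140
# and the from_start=False call returns 1860; isolated specks narrower than
# _MIN_RUN_FRAC * 2000 = 10px are discarded by _filter_narrow_runs first.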
|
||||
|
||||
|
||||
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
||||
"""Remove True-runs shorter than min_run pixels."""
|
||||
if min_run <= 1:
|
||||
return mask
|
||||
|
||||
result = mask.copy()
|
||||
n = len(result)
|
||||
i = 0
|
||||
while i < n:
|
||||
if result[i]:
|
||||
start = i
|
||||
while i < n and result[i]:
|
||||
i += 1
|
||||
if i - start < min_run:
|
||||
result[start:i] = False
|
||||
else:
|
||||
i += 1
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format detection (kept as optional metadata)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_format(width: int, height: int) -> Tuple[str, float]:
|
||||
"""Detect paper format from dimensions by comparing aspect ratios."""
|
||||
if width <= 0 or height <= 0:
|
||||
return "unknown", 0.0
|
||||
|
||||
aspect = max(width, height) / min(width, height)
|
||||
|
||||
best_format = "unknown"
|
||||
best_diff = float("inf")
|
||||
|
||||
for fmt, expected_ratio in PAPER_FORMATS.items():
|
||||
diff = abs(aspect - expected_ratio)
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_format = fmt
|
||||
|
||||
confidence = max(0.0, 1.0 - best_diff * 5.0)
|
||||
|
||||
if confidence < 0.3:
|
||||
return "unknown", 0.0
|
||||
|
||||
return best_format, round(confidence, 3)
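# Illustrative: _detect_format(2480, 3508) -- a 300-dpi A4 scan -- returns
# ("A4", 0.999), while a square 1000x1000 image is closest to Letter but with
# confidence below 0.3 and therefore returns ("unknown", 0.0).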
|
||||
from page_crop_edges import ( # noqa: F401
|
||||
_INK_THRESHOLD,
|
||||
_MIN_RUN_FRAC,
|
||||
_detect_spine_shadow,
|
||||
_detect_gutter_continuity,
|
||||
_detect_left_edge_shadow,
|
||||
_detect_right_edge_shadow,
|
||||
_detect_top_bottom_edges,
|
||||
_detect_edge_projection,
|
||||
_filter_narrow_runs,
|
||||
)
|
||||
|
||||
342
klausur-service/backend/page_crop_core.py
Normal file
@@ -0,0 +1,342 @@
"""
Page Crop - Core Crop and Format Detection

Content-based crop for scanned pages and book scans. Detects the content
boundary by analysing ink density projections and (for book scans) the
spine shadow gradient.

Extracted from page_crop.py to keep files under 500 LOC.
License: Apache 2.0
"""

import logging
from typing import Dict, Any, Tuple

import cv2
import numpy as np

from page_crop_edges import (
    _detect_left_edge_shadow,
    _detect_right_edge_shadow,
    _detect_top_bottom_edges,
)

logger = logging.getLogger(__name__)

# Known paper format aspect ratios (height / width, portrait orientation)
PAPER_FORMATS = {
    "A4": 297.0 / 210.0,      # 1.4143
    "A5": 210.0 / 148.0,      # 1.4189
    "Letter": 11.0 / 8.5,     # 1.2941
    "Legal": 14.0 / 8.5,      # 1.6471
    "A3": 420.0 / 297.0,      # 1.4141
}


def detect_page_splits(
    img_bgr: np.ndarray,
) -> list:
    """Detect if the image is a multi-page spread and return split rectangles.

    Uses **brightness** (not ink density) to find the spine area:
    the scanner bed produces a characteristic gray strip where pages meet,
    which is darker than the white paper on either side.

    Returns a list of page dicts ``{x, y, width, height, page_index}``
    or an empty list if only one page is detected.
    """
    h, w = img_bgr.shape[:2]

    # Only check landscape-ish images (width > height * 1.15)
    if w < h * 1.15:
        return []

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Column-mean brightness (0-255) — the spine is darker (gray scanner bed)
    col_brightness = np.mean(gray, axis=0).astype(np.float64)

    # Heavy smoothing to ignore individual text lines
    kern = max(11, w // 50)
    if kern % 2 == 0:
        kern += 1
    brightness_smooth = np.convolve(col_brightness, np.ones(kern) / kern, mode="same")

    # Page paper is bright (typically > 200), spine/scanner bed is darker
    page_brightness = float(np.max(brightness_smooth))
    if page_brightness < 100:
        return []  # Very dark image, skip

    # Spine threshold: significantly darker than the page
    spine_thresh = page_brightness * 0.88

    # Search in center region (30-70% of width)
    center_lo = int(w * 0.30)
    center_hi = int(w * 0.70)

    # Find the darkest valley in the center region
    center_brightness = brightness_smooth[center_lo:center_hi]
    darkest_val = float(np.min(center_brightness))

    if darkest_val >= spine_thresh:
        logger.debug("No spine detected: min brightness %.0f >= threshold %.0f",
                     darkest_val, spine_thresh)
        return []

    # Find ALL contiguous dark runs in the center region
    is_dark = center_brightness < spine_thresh
    dark_runs: list = []
    run_start = -1
    for i in range(len(is_dark)):
        if is_dark[i]:
            if run_start < 0:
                run_start = i
        else:
            if run_start >= 0:
                dark_runs.append((run_start, i))
                run_start = -1
    if run_start >= 0:
        dark_runs.append((run_start, len(is_dark)))

    # Filter out runs that are too narrow (< 1% of image width)
    min_spine_px = int(w * 0.01)
    dark_runs = [(s, e) for s, e in dark_runs if e - s >= min_spine_px]

    if not dark_runs:
        logger.debug("No dark runs wider than %dpx in center region", min_spine_px)
        return []

    # Score each dark run: prefer centered, dark, narrow valleys
    center_region_len = center_hi - center_lo
    image_center_in_region = (w * 0.5 - center_lo)
    best_score = -1.0
    best_start, best_end = dark_runs[0]

    for rs, re in dark_runs:
        run_width = re - rs
        run_center = (rs + re) / 2.0

        sigma = center_region_len * 0.15
        dist = abs(run_center - image_center_in_region)
        center_factor = float(np.exp(-0.5 * (dist / sigma) ** 2))

        run_brightness = float(np.mean(center_brightness[rs:re]))
        darkness_factor = max(0.0, (spine_thresh - run_brightness) / spine_thresh)

        width_frac = run_width / w
        if width_frac <= 0.05:
            narrowness_bonus = 1.0
        elif width_frac <= 0.15:
            narrowness_bonus = 1.0 - (width_frac - 0.05) / 0.10
        else:
            narrowness_bonus = 0.0

        score = center_factor * darkness_factor * (0.3 + 0.7 * narrowness_bonus)

        logger.debug(
            "Dark run x=%d..%d (w=%d): center_f=%.3f dark_f=%.3f narrow_b=%.3f -> score=%.4f",
            center_lo + rs, center_lo + re, run_width,
            center_factor, darkness_factor, narrowness_bonus, score,
        )

        if score > best_score:
            best_score = score
            best_start, best_end = rs, re

    spine_w = best_end - best_start
    spine_x = center_lo + best_start
    spine_center = spine_x + spine_w // 2

    logger.debug(
        "Best spine candidate: x=%d..%d (w=%d), score=%.4f",
        spine_x, spine_x + spine_w, spine_w, best_score,
    )

    # Verify: must have bright (paper) content on BOTH sides
    left_brightness = float(np.mean(brightness_smooth[max(0, spine_x - w // 10):spine_x]))
    right_end = center_lo + best_end
    right_brightness = float(np.mean(brightness_smooth[right_end:min(w, right_end + w // 10)]))

    if left_brightness < spine_thresh or right_brightness < spine_thresh:
        logger.debug("No bright paper flanking spine: left=%.0f right=%.0f thresh=%.0f",
                     left_brightness, right_brightness, spine_thresh)
        return []

    logger.info(
        "Spine detected: x=%d..%d (w=%d), brightness=%.0f vs paper=%.0f, "
        "left_paper=%.0f, right_paper=%.0f",
        spine_x, right_end, spine_w, darkest_val, page_brightness,
        left_brightness, right_brightness,
    )

    # Split at the spine center
    split_points = [spine_center]

    # Build page rectangles
    pages: list = []
    prev_x = 0
    for i, sx in enumerate(split_points):
        pages.append({"x": prev_x, "y": 0, "width": sx - prev_x,
                      "height": h, "page_index": i})
        prev_x = sx
    pages.append({"x": prev_x, "y": 0, "width": w - prev_x,
                  "height": h, "page_index": len(split_points)})

    # Filter out tiny pages (< 15% of total width)
    pages = [p for p in pages if p["width"] >= w * 0.15]
    if len(pages) < 2:
        return []

    # Re-index
    for i, p in enumerate(pages):
        p["page_index"] = i

    logger.info(
        "Page split detected: %d pages, spine_w=%d, split_points=%s",
        len(pages), spine_w, split_points,
    )
    return pages


def detect_and_crop_page(
    img_bgr: np.ndarray,
    margin_frac: float = 0.01,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Detect content boundary and crop scanner/book borders.

    Algorithm (4-edge detection):
    1. Adaptive threshold -> binary (text=255, bg=0)
    2. Left edge: spine-shadow detection via grayscale column means,
       fallback to binary vertical projection
    3. Right edge: binary vertical projection (last ink column)
    4. Top/bottom edges: binary horizontal projection
    5. Sanity checks, then crop with configurable margin

    Args:
        img_bgr: Input BGR image (should already be deskewed/dewarped)
        margin_frac: Extra margin around content (fraction of dimension, default 1%)

    Returns:
        Tuple of (cropped_image, result_dict)
    """
    h, w = img_bgr.shape[:2]
    total_area = h * w

    result: Dict[str, Any] = {
        "crop_applied": False,
        "crop_rect": None,
        "crop_rect_pct": None,
        "original_size": {"width": w, "height": h},
        "cropped_size": {"width": w, "height": h},
        "detected_format": None,
        "format_confidence": 0.0,
        "aspect_ratio": round(max(h, w) / max(min(h, w), 1), 4),
        "border_fractions": {"top": 0.0, "bottom": 0.0, "left": 0.0, "right": 0.0},
    }

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # --- Binarise with adaptive threshold ---
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, blockSize=51, C=15,
    )

    # --- Edge detection ---
    left_edge = _detect_left_edge_shadow(gray, binary, w, h)
    right_edge = _detect_right_edge_shadow(gray, binary, w, h)
    top_edge, bottom_edge = _detect_top_bottom_edges(binary, w, h)

    # Compute border fractions
    border_top = top_edge / h
    border_bottom = (h - bottom_edge) / h
    border_left = left_edge / w
    border_right = (w - right_edge) / w

    result["border_fractions"] = {
        "top": round(border_top, 4),
        "bottom": round(border_bottom, 4),
        "left": round(border_left, 4),
        "right": round(border_right, 4),
    }

    # Sanity: only crop if at least one edge has > 2% border
    min_border = 0.02
    if all(f < min_border for f in [border_top, border_bottom, border_left, border_right]):
        logger.info("All borders < %.0f%% — no crop needed", min_border * 100)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    # Add margin
    margin_x = int(w * margin_frac)
    margin_y = int(h * margin_frac)

    crop_x = max(0, left_edge - margin_x)
    crop_y = max(0, top_edge - margin_y)
    crop_x2 = min(w, right_edge + margin_x)
    crop_y2 = min(h, bottom_edge + margin_y)

    crop_w = crop_x2 - crop_x
    crop_h = crop_y2 - crop_y

    # Sanity: cropped area must be >= 40% of original
    if crop_w * crop_h < 0.40 * total_area:
        logger.warning("Cropped area too small (%.0f%%) — skipping crop",
                       100.0 * crop_w * crop_h / total_area)
        result["detected_format"], result["format_confidence"] = _detect_format(w, h)
        return img_bgr, result

    cropped = img_bgr[crop_y:crop_y2, crop_x:crop_x2].copy()

    detected_format, format_confidence = _detect_format(crop_w, crop_h)

    result["crop_applied"] = True
    result["crop_rect"] = {"x": crop_x, "y": crop_y, "width": crop_w, "height": crop_h}
    result["crop_rect_pct"] = {
        "x": round(100.0 * crop_x / w, 2),
        "y": round(100.0 * crop_y / h, 2),
        "width": round(100.0 * crop_w / w, 2),
        "height": round(100.0 * crop_h / h, 2),
    }
    result["cropped_size"] = {"width": crop_w, "height": crop_h}
    result["detected_format"] = detected_format
    result["format_confidence"] = format_confidence
    result["aspect_ratio"] = round(max(crop_w, crop_h) / max(min(crop_w, crop_h), 1), 4)

    logger.info(
        "Page cropped: %dx%d -> %dx%d, format=%s (%.0f%%), "
        "borders: T=%.1f%% B=%.1f%% L=%.1f%% R=%.1f%%",
        w, h, crop_w, crop_h, detected_format, format_confidence * 100,
        border_top * 100, border_bottom * 100,
        border_left * 100, border_right * 100,
    )

    return cropped, result


# ---------------------------------------------------------------------------
# Format detection (kept as optional metadata)
# ---------------------------------------------------------------------------

def _detect_format(width: int, height: int) -> Tuple[str, float]:
    """Detect paper format from dimensions by comparing aspect ratios."""
    if width <= 0 or height <= 0:
        return "unknown", 0.0

    aspect = max(width, height) / min(width, height)

    best_format = "unknown"
    best_diff = float("inf")

    for fmt, expected_ratio in PAPER_FORMATS.items():
        diff = abs(aspect - expected_ratio)
        if diff < best_diff:
            best_diff = diff
            best_format = fmt

    confidence = max(0.0, 1.0 - best_diff * 5.0)

    if confidence < 0.3:
        return "unknown", 0.0

    return best_format, round(confidence, 3)
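A minimal usage sketch for the two entry points above (illustrative, not part of the commit; the file name scan.jpg and the fallback rectangle are assumptions, only detect_page_splits and detect_and_crop_page come from page_crop_core):

# Illustrative sketch only; the file path and the glue logic are assumptions.
import cv2

from page_crop_core import detect_page_splits, detect_and_crop_page

img = cv2.imread("scan.jpg")          # BGR, as both functions expect

pages = detect_page_splits(img)       # [] for single pages, >= 2 rects for spreads
if not pages:
    pages = [{"x": 0, "y": 0, "width": img.shape[1], "height": img.shape[0],
              "page_index": 0}]

for page in pages:
    x, y = page["x"], page["y"]
    sub = img[y:y + page["height"], x:x + page["width"]]
    cropped, info = detect_and_crop_page(sub, margin_frac=0.01)
    print(page["page_index"], info["crop_applied"], info["detected_format"],
          info["border_fractions"])
    cv2.imwrite(f"page_{page['page_index']}.png", cropped)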
388
klausur-service/backend/page_crop_edges.py
Normal file
388
klausur-service/backend/page_crop_edges.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Page Crop - Edge Detection Helpers
|
||||
|
||||
Spine shadow detection, gutter continuity analysis, projection-based
|
||||
edge detection, and narrow-run filtering for content cropping.
|
||||
|
||||
Extracted from page_crop.py to keep files under 500 LOC.
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum ink density (fraction of pixels) to count a row/column as "content"
|
||||
_INK_THRESHOLD = 0.003 # 0.3%
|
||||
|
||||
# Minimum run length (fraction of dimension) to keep — shorter runs are noise
|
||||
_MIN_RUN_FRAC = 0.005 # 0.5%
|
||||
|
||||
|
||||
def _detect_spine_shadow(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Find the book spine center (darkest point) in a scanner shadow.
|
||||
|
||||
The scanner produces a gray strip where the book spine presses against
|
||||
the glass. The darkest column in that strip is the spine center —
|
||||
that's where we crop.
|
||||
|
||||
Distinguishes real spine shadows from text content by checking:
|
||||
1. Strong brightness range (> 40 levels)
|
||||
2. Darkest point is genuinely dark (< 180 mean brightness)
|
||||
3. The dark area is a NARROW valley, not a text-content plateau
|
||||
4. Brightness rises significantly toward the page content side
|
||||
|
||||
Args:
|
||||
gray: Full grayscale image (for context).
|
||||
search_region: Column slice of the grayscale image to search in.
|
||||
offset_x: X offset of search_region relative to full image.
|
||||
w: Full image width.
|
||||
side: 'left' or 'right' (for logging).
|
||||
|
||||
Returns:
|
||||
X coordinate (in full image) of the spine center, or None.
|
||||
"""
|
||||
region_w = search_region.shape[1]
|
||||
if region_w < 10:
|
||||
return None
|
||||
|
||||
# Column-mean brightness in the search region
|
||||
col_means = np.mean(search_region, axis=0).astype(np.float64)
|
||||
|
||||
# Smooth with boxcar kernel (width = 1% of image width, min 5)
|
||||
kernel_size = max(5, w // 100)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
kernel = np.ones(kernel_size) / kernel_size
|
||||
smoothed_raw = np.convolve(col_means, kernel, mode="same")
|
||||
|
||||
# Trim convolution edge artifacts (edges are zero-padded -> artificially low)
|
||||
margin = kernel_size // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
smoothed = smoothed_raw[margin:region_w - margin]
|
||||
trim_offset = margin # offset of smoothed[0] relative to search_region
|
||||
|
||||
val_min = float(np.min(smoothed))
|
||||
val_max = float(np.max(smoothed))
|
||||
shadow_range = val_max - val_min
|
||||
|
||||
# --- Check 1: Strong brightness gradient ---
|
||||
if shadow_range <= 40:
|
||||
logger.debug(
|
||||
"%s edge: no spine (range=%.0f <= 40)", side.capitalize(), shadow_range,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 2: Darkest point must be genuinely dark ---
|
||||
if val_min > 180:
|
||||
logger.debug(
|
||||
"%s edge: no spine (darkest=%.0f > 180, likely text)", side.capitalize(), val_min,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_idx = int(np.argmin(smoothed)) # index in trimmed array
|
||||
spine_local = spine_idx + trim_offset # index in search_region
|
||||
trimmed_len = len(smoothed)
|
||||
|
||||
# --- Check 3: Valley width (spine is narrow, text plateau is wide) ---
|
||||
valley_thresh = val_min + shadow_range * 0.20
|
||||
valley_mask = smoothed < valley_thresh
|
||||
valley_width = int(np.sum(valley_mask))
|
||||
max_valley_frac = 0.50
|
||||
if valley_width > trimmed_len * max_valley_frac:
|
||||
logger.debug(
|
||||
"%s edge: no spine (valley too wide: %d/%d = %.0f%%)",
|
||||
side.capitalize(), valley_width, trimmed_len,
|
||||
100.0 * valley_width / trimmed_len,
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Check 4: Brightness must rise toward page content ---
|
||||
rise_check_w = max(5, trimmed_len // 5)
|
||||
if side == "left":
|
||||
right_start = min(spine_idx + 5, trimmed_len - 1)
|
||||
right_end = min(right_start + rise_check_w, trimmed_len)
|
||||
if right_end > right_start:
|
||||
rise_brightness = float(np.mean(smoothed[right_start:right_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
else: # right
|
||||
left_end = max(spine_idx - 5, 0)
|
||||
left_start = max(left_end - rise_check_w, 0)
|
||||
if left_end > left_start:
|
||||
rise_brightness = float(np.mean(smoothed[left_start:left_end]))
|
||||
rise = rise_brightness - val_min
|
||||
if rise < shadow_range * 0.3:
|
||||
logger.debug(
|
||||
"%s edge: no spine (insufficient rise: %.0f, need %.0f)",
|
||||
side.capitalize(), rise, shadow_range * 0.3,
|
||||
)
|
||||
return None
|
||||
|
||||
spine_x = offset_x + spine_local
|
||||
|
||||
logger.info(
|
||||
"%s edge: spine center at x=%d (brightness=%.0f, range=%.0f, valley=%dpx)",
|
||||
side.capitalize(), spine_x, val_min, shadow_range, valley_width,
|
||||
)
|
||||
return spine_x
|
||||
|
||||
|
||||
def _detect_gutter_continuity(
|
||||
gray: np.ndarray,
|
||||
search_region: np.ndarray,
|
||||
offset_x: int,
|
||||
w: int,
|
||||
side: str,
|
||||
) -> Optional[int]:
|
||||
"""Detect gutter shadow via vertical continuity analysis.
|
||||
|
||||
Camera book scans produce a subtle brightness gradient at the gutter
|
||||
that is too faint for scanner-shadow detection (range < 40). However,
|
||||
the gutter shadow has a unique property: it runs **continuously from
|
||||
top to bottom** without interruption.
|
||||
|
||||
Algorithm:
|
||||
1. Divide image into N horizontal strips (~60px each)
|
||||
2. For each column, compute what fraction of strips are darker than
|
||||
the page median (from the center 50% of the full image)
|
||||
3. A "gutter column" has >= 75% of strips darker than page_median - d
|
||||
4. Smooth the dark-fraction profile and find the transition point
|
||||
5. Validate: gutter band must be 0.5%-10% of image width
|
||||
"""
|
||||
region_h, region_w = search_region.shape[:2]
|
||||
if region_w < 20 or region_h < 100:
|
||||
return None
|
||||
|
||||
# --- 1. Divide into horizontal strips ---
|
||||
strip_target_h = 60
|
||||
n_strips = max(10, region_h // strip_target_h)
|
||||
strip_h = region_h // n_strips
|
||||
|
||||
strip_means = np.zeros((n_strips, region_w), dtype=np.float64)
|
||||
for s in range(n_strips):
|
||||
y0 = s * strip_h
|
||||
y1 = min((s + 1) * strip_h, region_h)
|
||||
strip_means[s] = np.mean(search_region[y0:y1, :], axis=0)
|
||||
|
||||
# --- 2. Page median from center 50% of full image ---
|
||||
center_lo = w // 4
|
||||
center_hi = 3 * w // 4
|
||||
page_median = float(np.median(gray[:, center_lo:center_hi]))
|
||||
|
||||
dark_thresh = page_median - 5.0
|
||||
|
||||
if page_median < 180:
|
||||
return None
|
||||
|
||||
# --- 3. Per-column dark fraction ---
|
||||
dark_count = np.sum(strip_means < dark_thresh, axis=0).astype(np.float64)
|
||||
dark_frac = dark_count / n_strips
|
||||
|
||||
# --- 4. Smooth and find transition ---
|
||||
smooth_w = max(5, w // 100)
|
||||
if smooth_w % 2 == 0:
|
||||
smooth_w += 1
|
||||
kernel = np.ones(smooth_w) / smooth_w
|
||||
frac_smooth = np.convolve(dark_frac, kernel, mode="same")
|
||||
|
||||
margin = smooth_w // 2
|
||||
if region_w <= 2 * margin + 10:
|
||||
return None
|
||||
|
||||
transition_thresh = 0.50
|
||||
peak_frac = float(np.max(frac_smooth[margin:region_w - margin]))
|
||||
|
||||
if peak_frac < 0.70:
|
||||
logger.debug(
|
||||
"%s gutter: peak dark fraction %.2f < 0.70", side.capitalize(), peak_frac,
|
||||
)
|
||||
return None
|
||||
|
||||
peak_x = int(np.argmax(frac_smooth[margin:region_w - margin])) + margin
|
||||
gutter_inner = None
|
||||
|
||||
if side == "right":
|
||||
for x in range(peak_x, margin, -1):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x + 1
|
||||
break
|
||||
else:
|
||||
for x in range(peak_x, region_w - margin):
|
||||
if frac_smooth[x] < transition_thresh:
|
||||
gutter_inner = x - 1
|
||||
break
|
||||
|
||||
if gutter_inner is None:
|
||||
return None
|
||||
|
||||
# --- 5. Validate gutter width ---
|
||||
if side == "right":
|
||||
gutter_width = region_w - gutter_inner
|
||||
else:
|
||||
gutter_width = gutter_inner
|
||||
|
||||
min_gutter = max(3, int(w * 0.005))
|
||||
max_gutter = int(w * 0.10)
|
||||
|
||||
if gutter_width < min_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too narrow (%dpx < %dpx)", side.capitalize(),
|
||||
gutter_width, min_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if gutter_width > max_gutter:
|
||||
logger.debug(
|
||||
"%s gutter: too wide (%dpx > %dpx)", side.capitalize(),
|
||||
gutter_width, max_gutter,
|
||||
)
|
||||
return None
|
||||
|
||||
if side == "right":
|
||||
gutter_brightness = float(np.mean(strip_means[:, gutter_inner:]))
|
||||
else:
|
||||
gutter_brightness = float(np.mean(strip_means[:, :gutter_inner]))
|
||||
|
||||
brightness_drop = page_median - gutter_brightness
|
||||
if brightness_drop < 3:
|
||||
logger.debug(
|
||||
"%s gutter: insufficient brightness drop (%.1f levels)",
|
||||
side.capitalize(), brightness_drop,
|
||||
)
|
||||
return None
|
||||
|
||||
gutter_x = offset_x + gutter_inner
|
||||
|
||||
logger.info(
|
||||
"%s gutter (continuity): x=%d, width=%dpx (%.1f%%), "
|
||||
"brightness=%.0f vs page=%.0f (drop=%.0f), frac@edge=%.2f",
|
||||
side.capitalize(), gutter_x, gutter_width,
|
||||
100.0 * gutter_width / w, gutter_brightness, page_median,
|
||||
brightness_drop, float(frac_smooth[gutter_inner]),
|
||||
)
|
||||
return gutter_x
|
||||
|
||||
|
||||
def _detect_left_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect left content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (first ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, :search_w], 0, w, "left")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, :search_w], 0, w, "left")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
return _detect_edge_projection(binary, axis=0, from_start=True, dim=w)
|
||||
|
||||
|
||||
def _detect_right_edge_shadow(
|
||||
gray: np.ndarray,
|
||||
binary: np.ndarray,
|
||||
w: int,
|
||||
h: int,
|
||||
) -> int:
|
||||
"""Detect right content edge, accounting for book-spine shadow.
|
||||
|
||||
Tries three methods in order:
|
||||
1. Scanner spine-shadow (dark gradient, range > 40)
|
||||
2. Camera gutter continuity (subtle shadow running top-to-bottom)
|
||||
3. Binary projection fallback (last ink column)
|
||||
"""
|
||||
search_w = max(1, w // 4)
|
||||
right_start = w - search_w
|
||||
spine_x = _detect_spine_shadow(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if spine_x is not None:
|
||||
return spine_x
|
||||
|
||||
gutter_x = _detect_gutter_continuity(gray, gray[:, right_start:], right_start, w, "right")
|
||||
if gutter_x is not None:
|
||||
return gutter_x
|
||||
|
||||
return _detect_edge_projection(binary, axis=0, from_start=False, dim=w)
|
||||
|
||||
|
||||
def _detect_top_bottom_edges(binary: np.ndarray, w: int, h: int) -> Tuple[int, int]:
|
||||
"""Detect top and bottom content edges via binary horizontal projection."""
|
||||
top = _detect_edge_projection(binary, axis=1, from_start=True, dim=h)
|
||||
bottom = _detect_edge_projection(binary, axis=1, from_start=False, dim=h)
|
||||
return top, bottom
|
||||
|
||||
|
||||
def _detect_edge_projection(
|
||||
binary: np.ndarray,
|
||||
axis: int,
|
||||
from_start: bool,
|
||||
dim: int,
|
||||
) -> int:
|
||||
"""Find the first/last row or column with ink density above threshold.
|
||||
|
||||
axis=0 -> project vertically (column densities) -> returns x position
|
||||
axis=1 -> project horizontally (row densities) -> returns y position
|
||||
|
||||
Filters out narrow noise runs shorter than _MIN_RUN_FRAC of the dimension.
|
||||
"""
|
||||
projection = np.mean(binary, axis=axis) / 255.0
|
||||
|
||||
ink_mask = projection >= _INK_THRESHOLD
|
||||
|
||||
min_run = max(1, int(dim * _MIN_RUN_FRAC))
|
||||
ink_mask = _filter_narrow_runs(ink_mask, min_run)
|
||||
|
||||
ink_positions = np.where(ink_mask)[0]
|
||||
if len(ink_positions) == 0:
|
||||
return 0 if from_start else dim
|
||||
|
||||
if from_start:
|
||||
return int(ink_positions[0])
|
||||
else:
|
||||
return int(ink_positions[-1])
|
||||
|
||||
|
||||
def _filter_narrow_runs(mask: np.ndarray, min_run: int) -> np.ndarray:
|
||||
"""Remove True-runs shorter than min_run pixels."""
|
||||
if min_run <= 1:
|
||||
return mask
|
||||
|
||||
result = mask.copy()
|
||||
n = len(result)
|
||||
i = 0
|
||||
while i < n:
|
||||
if result[i]:
|
||||
start = i
|
||||
while i < n and result[i]:
|
||||
i += 1
|
||||
if i - start < min_run:
|
||||
result[start:i] = False
|
||||
else:
|
||||
i += 1
|
||||
return result
|
||||
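To make the three-stage fallback concrete, a small synthetic smoke test (illustrative, not from the commit): a white page with a dark vertical band near the left edge should be caught by the spine-shadow stage rather than the projection fallback.

# Illustrative smoke test, not part of the commit.
import cv2
import numpy as np

from page_crop_edges import _detect_left_edge_shadow

h, w = 800, 1000
page = np.full((h, w), 255, dtype=np.uint8)   # white paper
page[:, 60:80] = 100                          # dark vertical strip = spine shadow

binary = cv2.adaptiveThreshold(page, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                               cv2.THRESH_BINARY_INV, blockSize=51, C=15)

left = _detect_left_edge_shadow(page, binary, w, h)
print(left)   # expected: an x position near the 60..80 band (spine-shadow stage)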
160
klausur-service/backend/services/trocr_batch.py
Normal file
160
klausur-service/backend/services/trocr_batch.py
Normal file
@@ -0,0 +1,160 @@
"""
TrOCR Batch Processing & Streaming

Batch OCR and SSE streaming for multiple images.

DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""

import asyncio
import logging
import time
from typing import Optional, List, Dict, Any

from .trocr_models import OCRResult, BatchOCRResult
from .trocr_ocr import run_trocr_ocr_enhanced

logger = logging.getLogger(__name__)


async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        BatchOCRResult with all results
    """
    start_time = time.time()
    results = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
            results.append(result)

            if result.from_cache:
                cached_count += 1

            # Report progress
            if progress_callback:
                progress_callback(idx + 1, len(images))

        except Exception as e:
            logger.error(f"Batch OCR error for image {idx}: {e}")
            error_count += 1
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))

    total_time_ms = int((time.time() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=len(images),
        cached_count=cached_count,
        error_count=error_count
    )


# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images and yield progress updates for SSE streaming.

    Yields:
        dict with current progress and result
    """
    start_time = time.time()
    total = len(images)

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            elapsed_ms = int((time.time() - start_time) * 1000)
            avg_time_per_image = elapsed_ms / (idx + 1)
            estimated_remaining = int(avg_time_per_image * (total - idx - 1))

            yield {
                "type": "progress",
                "current": idx + 1,
                "total": total,
                "progress_percent": ((idx + 1) / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": estimated_remaining,
                "result": {
                    "text": result.text,
                    "confidence": result.confidence,
                    "processing_time_ms": result.processing_time_ms,
                    "from_cache": result.from_cache
                }
            }

        except Exception as e:
            logger.error(f"Stream OCR error for image {idx}: {e}")
            yield {
                "type": "error",
                "current": idx + 1,
                "total": total,
                "error": str(e)
            }

    total_time_ms = int((time.time() - start_time) * 1000)
    yield {
        "type": "complete",
        "total_time_ms": total_time_ms,
        "processed_count": total
    }


# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Test TrOCR on a local image file."""
    from .trocr_ocr import run_trocr_ocr

    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    print(f"\n=== TrOCR Test ===")
    print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
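The streaming generator yields plain dicts; turning them into SSE frames is left to the caller. One possible wiring (illustrative only: the route path, request model and SSE framing are assumptions, not part of this commit):

# Illustrative sketch, not part of the commit; assumes FastAPI is the host app.
import base64
import json
from typing import List

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from services.trocr_batch import run_trocr_batch_stream   # import path is an assumption

router = APIRouter()


class BatchOCRRequest(BaseModel):
    images_b64: List[str]
    handwritten: bool = True


@router.post("/ocr/batch/stream")
async def batch_ocr_stream(req: BatchOCRRequest):
    images = [base64.b64decode(s) for s in req.images_b64]

    async def event_stream():
        # One SSE frame per "progress" / "error" / "complete" dict
        async for update in run_trocr_batch_stream(images, handwritten=req.handwritten):
            yield f"data: {json.dumps(update)}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")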
278
klausur-service/backend/services/trocr_models.py
Normal file
278
klausur-service/backend/services/trocr_models.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
TrOCR Models & Cache
|
||||
|
||||
Dataclasses, LRU cache, and model loading for TrOCR service.
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
_trocr_processor = None # backwards-compat alias -> base-printed
|
||||
_trocr_model = None # backwards-compat alias -> base-printed
|
||||
_trocr_available = None
|
||||
_model_loaded_at = None
|
||||
|
||||
# Simple in-memory cache with LRU eviction
|
||||
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
_cache_max_size = 100
|
||||
_cache_ttl_seconds = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Enhanced OCR result with detailed information."""
|
||||
text: str
|
||||
confidence: float
|
||||
processing_time_ms: int
|
||||
model: str
|
||||
has_lora_adapter: bool = False
|
||||
char_confidences: List[float] = field(default_factory=list)
|
||||
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
from_cache: bool = False
|
||||
image_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchOCRResult:
|
||||
"""Result for batch processing."""
|
||||
results: List[OCRResult]
|
||||
total_time_ms: int
|
||||
processed_count: int
|
||||
cached_count: int
|
||||
error_count: int
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached OCR result if available and not expired."""
|
||||
if image_hash in _ocr_cache:
|
||||
entry = _ocr_cache[image_hash]
|
||||
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
|
||||
# Move to end (LRU)
|
||||
_ocr_cache.move_to_end(image_hash)
|
||||
return entry["result"]
|
||||
else:
|
||||
# Expired, remove
|
||||
del _ocr_cache[image_hash]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
|
||||
"""Store OCR result in cache."""
|
||||
# Evict oldest if at capacity
|
||||
while len(_ocr_cache) >= _cache_max_size:
|
||||
_ocr_cache.popitem(last=False)
|
||||
|
||||
_ocr_cache[image_hash] = {
|
||||
"result": result,
|
||||
"cached_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"size": len(_ocr_cache),
|
||||
"max_size": _cache_max_size,
|
||||
"ttl_seconds": _cache_ttl_seconds,
|
||||
"hit_rate": 0 # Could track this with additional counters
|
||||
}
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
|
||||
"""Check if TrOCR dependencies are available."""
|
||||
global _trocr_available
|
||||
if _trocr_available is not None:
|
||||
return _trocr_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
_trocr_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"TrOCR dependencies not available: {e}")
|
||||
_trocr_available = False
|
||||
|
||||
return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||
"""
|
||||
Lazy load TrOCR model and processor.
|
||||
|
||||
Args:
|
||||
handwritten: Use handwritten model instead of printed model
|
||||
size: Model size -- "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||
for exam HTR). Only applies to handwritten variant.
|
||||
|
||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _trocr_processor, _trocr_model
|
||||
|
||||
if not _check_trocr_available():
|
||||
return None, None
|
||||
|
||||
# Select model name
|
||||
if size == "large" and handwritten:
|
||||
model_name = "microsoft/trocr-large-handwritten"
|
||||
elif handwritten:
|
||||
model_name = "microsoft/trocr-base-handwritten"
|
||||
else:
|
||||
model_name = "microsoft/trocr-base-printed"
|
||||
|
||||
if model_name in _trocr_models:
|
||||
return _trocr_models[model_name]
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
logger.info(f"Loading TrOCR model: {model_name}")
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
|
||||
# Use GPU if available
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model.to(device)
|
||||
logger.info(f"TrOCR model loaded on device: {device}")
|
||||
|
||||
_trocr_models[model_name] = (processor, model)
|
||||
|
||||
# Keep backwards-compat globals pointing at base-printed
|
||||
if model_name == "microsoft/trocr-base-printed":
|
||||
_trocr_processor = processor
|
||||
_trocr_model = model
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload TrOCR model at startup for faster first request.
|
||||
|
||||
Call this from your FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_trocr_model()
|
||||
"""
|
||||
global _model_loaded_at
|
||||
logger.info("Preloading TrOCR model...")
|
||||
processor, model = get_trocr_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
_model_loaded_at = datetime.now()
|
||||
logger.info("TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
|
||||
"""Get current model status information."""
|
||||
processor, model = get_trocr_model(handwritten=True)
|
||||
is_loaded = processor is not None and model is not None
|
||||
|
||||
status = {
|
||||
"status": "available" if is_loaded else "not_installed",
|
||||
"is_loaded": is_loaded,
|
||||
"model_name": "trocr-base-handwritten" if is_loaded else None,
|
||||
"loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
|
||||
}
|
||||
|
||||
if is_loaded:
|
||||
import torch
|
||||
device = next(model.parameters()).device
|
||||
status["device"] = str(device)
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def get_active_backend() -> str:
|
||||
"""
|
||||
Return which TrOCR backend is configured.
|
||||
|
||||
Possible values: "auto", "pytorch", "onnx".
|
||||
"""
|
||||
return _trocr_backend
|
||||
|
||||
|
||||
def _split_into_lines(image) -> list:
|
||||
"""
|
||||
Split an image into text lines using simple projection-based segmentation.
|
||||
|
||||
This is a basic implementation - for production use, consider using
|
||||
a dedicated line detection model.
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# Convert to grayscale
|
||||
gray = image.convert('L')
|
||||
img_array = np.array(gray)
|
||||
|
||||
# Binarize (simple threshold)
|
||||
threshold = 200
|
||||
binary = img_array < threshold
|
||||
|
||||
# Horizontal projection (sum of dark pixels per row)
|
||||
h_proj = np.sum(binary, axis=1)
|
||||
|
||||
# Find line boundaries (where projection drops below threshold)
|
||||
line_threshold = img_array.shape[1] * 0.02 # 2% of width
|
||||
in_line = False
|
||||
line_start = 0
|
||||
lines = []
|
||||
|
||||
for i, val in enumerate(h_proj):
|
||||
if val > line_threshold and not in_line:
|
||||
# Start of line
|
||||
in_line = True
|
||||
line_start = i
|
||||
elif val <= line_threshold and in_line:
|
||||
# End of line
|
||||
in_line = False
|
||||
# Add padding
|
||||
start = max(0, line_start - 5)
|
||||
end = min(img_array.shape[0], i + 5)
|
||||
if end - start > 10: # Minimum line height
|
||||
lines.append(image.crop((0, start, image.width, end)))
|
||||
|
||||
# Handle last line if still in_line
|
||||
if in_line:
|
||||
start = max(0, line_start - 5)
|
||||
lines.append(image.crop((0, start, image.width, image.height)))
|
||||
|
||||
logger.info(f"Split image into {len(lines)} lines")
|
||||
return lines
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Line splitting failed: {e}")
|
||||
return []
|
||||
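A short sketch of the hash-keyed cache round trip (illustrative, not part of the commit; it pokes the module-private helpers directly, which normal callers never need to do):

# Illustrative only: exercises the private cache helpers to show the contract.
from services.trocr_models import (            # import path is an assumption
    _cache_get, _cache_set, _compute_image_hash, get_cache_stats,
)

image_data = b"...raw image bytes..."
key = _compute_image_hash(image_data)          # first 16 hex chars of the SHA256

assert _cache_get(key) is None                 # cold cache
_cache_set(key, {"text": "Hallo", "confidence": 0.9,
                 "model": "trocr-base-handwritten"})
assert _cache_get(key)["text"] == "Hallo"      # hit; entry moved to the LRU tail

print(get_cache_stats())                       # {'size': 1, 'max_size': 100, ...}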
309
klausur-service/backend/services/trocr_ocr.py
Normal file
309
klausur-service/backend/services/trocr_ocr.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
TrOCR OCR Execution
|
||||
|
||||
Core OCR inference routines (PyTorch, ONNX routing, enhanced mode).
|
||||
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Tuple, Optional, List, Dict, Any, Callable
|
||||
|
||||
from .trocr_models import (
|
||||
OCRResult,
|
||||
_trocr_backend,
|
||||
_compute_image_hash,
|
||||
_cache_get,
|
||||
_cache_set,
|
||||
get_trocr_model,
|
||||
_split_into_lines,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_onnx_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
) -> Optional[Callable]:
    """
    Check whether ONNX inference is available. Returns the async
    run_trocr_onnx coroutine function on success (the caller awaits it),
    or None if ONNX is not available / fails to load.
    """
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
# run_trocr_onnx is async -- return the coroutine's awaitable result
|
||||
# The caller (run_trocr_ocr) will await it.
|
||||
return run_trocr_onnx # sentinel: caller checks callable
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def _run_pytorch_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Original PyTorch inference path (extracted for routing).
|
||||
"""
|
||||
processor, model = get_trocr_model(handwritten=handwritten, size=size)
|
||||
|
||||
if processor is None or model is None:
|
||||
logger.error("TrOCR PyTorch model not available")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Load image
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
||||
|
||||
if split_lines:
|
||||
lines = _split_into_lines(image)
|
||||
if not lines:
|
||||
lines = [image]
|
||||
else:
|
||||
lines = [image]
|
||||
|
||||
all_text = []
|
||||
confidences = []
|
||||
|
||||
for line_image in lines:
|
||||
pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
|
||||
|
||||
device = next(model.parameters()).device
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(pixel_values, max_length=128)
|
||||
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
if generated_text.strip():
|
||||
all_text.append(generated_text.strip())
|
||||
confidences.append(0.85 if len(generated_text) > 3 else 0.5)
|
||||
|
||||
text = "\n".join(all_text)
|
||||
confidence = sum(confidences) / len(confidences) if confidences else 0.0
|
||||
|
||||
logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
|
||||
return text, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TrOCR PyTorch failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None, 0.0
|
||||
|
||||
|
||||
async def run_trocr_ocr(
|
||||
image_data: bytes,
|
||||
handwritten: bool = False,
|
||||
split_lines: bool = True,
|
||||
size: str = "base",
|
||||
) -> Tuple[Optional[str], float]:
|
||||
"""
|
||||
Run TrOCR on an image.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
- "onnx" -- always use ONNX (raises RuntimeError if unavailable)
|
||||
- "pytorch" -- always use PyTorch (original behaviour)
|
||||
- "auto" -- try ONNX first, fall back to PyTorch
|
||||
|
||||
TrOCR is optimized for single-line text recognition, so for full-page
|
||||
images we need to either:
|
||||
1. Split into lines first (using line detection)
|
||||
2. Process the whole image and get partial results
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model (slower but better for handwriting)
|
||||
split_lines: Whether to split image into lines first
|
||||
size: "base" or "large" (only for handwritten variant)
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, confidence)
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is None or not callable(onnx_fn):
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
# --- PyTorch-only mode ---
|
||||
if backend == "pytorch":
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first, then PyTorch ---
|
||||
onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if onnx_fn is not None and callable(onnx_fn):
|
||||
try:
|
||||
result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
if result[0] is not None:
|
||||
return result
|
||||
logger.warning("ONNX returned None text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")
|
||||
|
||||
return await _run_pytorch_ocr(
|
||||
image_data, handwritten=handwritten, split_lines=split_lines, size=size,
|
||||
)
|
||||
|
||||
|
||||
def _try_onnx_enhanced(
|
||||
handwritten: bool = True,
|
||||
):
|
||||
"""
|
||||
Return the ONNX enhanced coroutine function, or None if unavailable.
|
||||
"""
|
||||
try:
|
||||
from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced
|
||||
|
||||
if not is_onnx_available(handwritten=handwritten):
|
||||
return None
|
||||
return run_trocr_onnx_enhanced
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
async def run_trocr_ocr_enhanced(
|
||||
image_data: bytes,
|
||||
handwritten: bool = True,
|
||||
split_lines: bool = True,
|
||||
use_cache: bool = True
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Enhanced TrOCR OCR with caching and detailed results.
|
||||
|
||||
Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
|
||||
environment variable (default: "auto").
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
handwritten: Use handwritten model
|
||||
split_lines: Whether to split image into lines first
|
||||
use_cache: Whether to use caching
|
||||
|
||||
Returns:
|
||||
OCRResult with detailed information
|
||||
"""
|
||||
backend = _trocr_backend
|
||||
|
||||
# --- ONNX-only mode ---
|
||||
if backend == "onnx":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is None:
|
||||
raise RuntimeError(
|
||||
"ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
|
||||
"Ensure onnxruntime + optimum are installed and ONNX model files exist."
|
||||
)
|
||||
return await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
|
||||
# --- Auto mode: try ONNX first ---
|
||||
if backend == "auto":
|
||||
onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
|
||||
if onnx_fn is not None:
|
||||
try:
|
||||
result = await onnx_fn(
|
||||
image_data, handwritten=handwritten,
|
||||
split_lines=split_lines, use_cache=use_cache,
|
||||
)
|
||||
if result.text:
|
||||
return result
|
||||
logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
|
||||
except Exception as e:
|
||||
logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")
|
||||
|
||||
# --- PyTorch path (backend == "pytorch" or auto fallback) ---
|
||||
start_time = time.time()
|
||||
|
||||
# Check cache first
|
||||
image_hash = _compute_image_hash(image_data)
|
||||
if use_cache:
|
||||
cached = _cache_get(image_hash)
|
||||
if cached:
|
||||
return OCRResult(
|
||||
text=cached["text"],
|
||||
confidence=cached["confidence"],
|
||||
processing_time_ms=0,
|
||||
model=cached["model"],
|
||||
has_lora_adapter=cached.get("has_lora_adapter", False),
|
||||
char_confidences=cached.get("char_confidences", []),
|
||||
word_boxes=cached.get("word_boxes", []),
|
||||
from_cache=True,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Run OCR via PyTorch
|
||||
text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Generate word boxes with simulated confidences
|
||||
word_boxes = []
|
||||
if text:
|
||||
words = text.split()
|
||||
for idx, word in enumerate(words):
|
||||
# Simulate word confidence (slightly varied around overall confidence)
|
||||
word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
|
||||
word_boxes.append({
|
||||
"text": word,
|
||||
"confidence": word_conf,
|
||||
"bbox": [0, 0, 0, 0] # Would need actual bounding box detection
|
||||
})
|
||||
|
||||
# Generate character confidences
|
||||
char_confidences = []
|
||||
if text:
|
||||
for char in text:
|
||||
# Simulate per-character confidence
|
||||
char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
|
||||
char_confidences.append(char_conf)
|
||||
|
||||
result = OCRResult(
|
||||
text=text or "",
|
||||
confidence=confidence,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
|
||||
has_lora_adapter=False, # Would check actual adapter status
|
||||
char_confidences=char_confidences,
|
||||
word_boxes=word_boxes,
|
||||
from_cache=False,
|
||||
image_hash=image_hash
|
||||
)
|
||||
|
||||
# Cache result
|
||||
if use_cache and text:
|
||||
_cache_set(image_hash, {
|
||||
"text": result.text,
|
||||
"confidence": result.confidence,
|
||||
"model": result.model,
|
||||
"has_lora_adapter": result.has_lora_adapter,
|
||||
"char_confidences": result.char_confidences,
|
||||
"word_boxes": result.word_boxes
|
||||
})
|
||||
|
||||
return result
|
||||
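Backend selection is driven entirely by the TROCR_BACKEND environment variable, which is read once at import time. A minimal driver script (illustrative; file names and import paths are assumptions) that forces the PyTorch path:

# Illustrative driver, not part of the commit. TROCR_BACKEND must be set
# before the module is imported, because the value is read at import time.
import asyncio
import os

os.environ["TROCR_BACKEND"] = "pytorch"        # or "onnx" / "auto" (default)

from services.trocr_ocr import run_trocr_ocr  # import path is an assumption


async def main():
    with open("exam_page.jpg", "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=True, size="large")
    print(f"confidence={confidence:.2f}")
    print(text)


asyncio.run(main())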
@@ -1,720 +1,70 @@
"""
TrOCR Service
TrOCR Service — Barrel Re-export

Microsoft's Transformer-based OCR for text recognition.
Besonders geeignet fuer:
- Gedruckten Text
- Saubere Scans
- Schnelle Verarbeitung

Model: microsoft/trocr-base-printed (oder handwritten Variante)
Split into submodules:
- trocr_models.py — Dataclasses, cache, model loading, line splitting
- trocr_ocr.py — Core OCR inference (PyTorch/ONNX routing, enhanced)
- trocr_batch.py — Batch processing and SSE streaming

DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.

Phase 2 Enhancements:
- Batch processing for multiple images
- SHA256-based caching for repeated requests
- Model preloading for faster first request
- Word-level bounding boxes with confidence
"""
import io
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend routing: auto | pytorch | onnx
|
||||
# ---------------------------------------------------------------------------
|
||||
_trocr_backend = os.environ.get("TROCR_BACKEND", "auto") # auto | pytorch | onnx
|
||||
|
||||
# Lazy loading for heavy dependencies
|
||||
# Cache keyed by model_name to support base and large variants simultaneously
|
||||
_trocr_models: dict = {} # {model_name: (processor, model)}
|
||||
_trocr_processor = None # backwards-compat alias → base-printed
|
||||
_trocr_model = None # backwards-compat alias → base-printed
|
||||
_trocr_available = None
|
||||
_model_loaded_at = None
|
||||
|
||||
# Simple in-memory cache with LRU eviction
|
||||
_ocr_cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
_cache_max_size = 100
|
||||
_cache_ttl_seconds = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Enhanced OCR result with detailed information."""
|
||||
text: str
|
||||
confidence: float
|
||||
processing_time_ms: int
|
||||
model: str
|
||||
has_lora_adapter: bool = False
|
||||
char_confidences: List[float] = field(default_factory=list)
|
||||
word_boxes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
from_cache: bool = False
|
||||
image_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchOCRResult:
|
||||
"""Result for batch processing."""
|
||||
results: List[OCRResult]
|
||||
total_time_ms: int
|
||||
processed_count: int
|
||||
cached_count: int
|
||||
error_count: int
|
||||
|
||||
|
||||
def _compute_image_hash(image_data: bytes) -> str:
|
||||
"""Compute SHA256 hash of image data for caching."""
|
||||
return hashlib.sha256(image_data).hexdigest()[:16]
|
||||
|
||||
|
||||
def _cache_get(image_hash: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached OCR result if available and not expired."""
|
||||
if image_hash in _ocr_cache:
|
||||
entry = _ocr_cache[image_hash]
|
||||
if datetime.now() - entry["cached_at"] < timedelta(seconds=_cache_ttl_seconds):
|
||||
# Move to end (LRU)
|
||||
_ocr_cache.move_to_end(image_hash)
|
||||
return entry["result"]
|
||||
else:
|
||||
# Expired, remove
|
||||
del _ocr_cache[image_hash]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(image_hash: str, result: Dict[str, Any]) -> None:
|
||||
"""Store OCR result in cache."""
|
||||
# Evict oldest if at capacity
|
||||
while len(_ocr_cache) >= _cache_max_size:
|
||||
_ocr_cache.popitem(last=False)
|
||||
|
||||
_ocr_cache[image_hash] = {
|
||||
"result": result,
|
||||
"cached_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
def get_cache_stats() -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"size": len(_ocr_cache),
|
||||
"max_size": _cache_max_size,
|
||||
"ttl_seconds": _cache_ttl_seconds,
|
||||
"hit_rate": 0 # Could track this with additional counters
|
||||
}
|
||||
|
||||
|
||||
def _check_trocr_available() -> bool:
|
||||
"""Check if TrOCR dependencies are available."""
|
||||
global _trocr_available
|
||||
if _trocr_available is not None:
|
||||
return _trocr_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
_trocr_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"TrOCR dependencies not available: {e}")
|
||||
_trocr_available = False
|
||||
|
||||
return _trocr_available
|
||||
|
||||
|
||||
def get_trocr_model(handwritten: bool = False, size: str = "base"):
|
||||
"""
|
||||
Lazy load TrOCR model and processor.
|
||||
|
||||
Args:
|
||||
handwritten: Use handwritten model instead of printed model
|
||||
size: Model size — "base" (300 MB) or "large" (340 MB, higher accuracy
|
||||
for exam HTR). Only applies to handwritten variant.
|
||||
|
||||
Returns tuple of (processor, model) or (None, None) if unavailable.
|
||||
"""
|
||||
global _trocr_processor, _trocr_model
|
||||
|
||||
if not _check_trocr_available():
|
||||
return None, None
|
||||
|
||||
# Select model name
|
||||
if size == "large" and handwritten:
|
||||
model_name = "microsoft/trocr-large-handwritten"
|
||||
elif handwritten:
|
||||
model_name = "microsoft/trocr-base-handwritten"
|
||||
else:
|
||||
model_name = "microsoft/trocr-base-printed"
|
||||
|
||||
if model_name in _trocr_models:
|
||||
return _trocr_models[model_name]
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
|
||||
logger.info(f"Loading TrOCR model: {model_name}")
|
||||
processor = TrOCRProcessor.from_pretrained(model_name)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
||||
|
||||
# Use GPU if available
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
model.to(device)
|
||||
logger.info(f"TrOCR model loaded on device: {device}")
|
||||
|
||||
_trocr_models[model_name] = (processor, model)
|
||||
|
||||
# Keep backwards-compat globals pointing at base-printed
|
||||
if model_name == "microsoft/trocr-base-printed":
|
||||
_trocr_processor = processor
|
||||
_trocr_model = model
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load TrOCR model {model_name}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def preload_trocr_model(handwritten: bool = True) -> bool:
|
||||
"""
|
||||
Preload TrOCR model at startup for faster first request.
|
||||
|
||||
Call this from your FastAPI startup event:
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
preload_trocr_model()
|
||||
"""
|
||||
global _model_loaded_at
|
||||
logger.info("Preloading TrOCR model...")
|
||||
processor, model = get_trocr_model(handwritten=handwritten)
|
||||
if processor is not None and model is not None:
|
||||
_model_loaded_at = datetime.now()
|
||||
logger.info("TrOCR model preloaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning("TrOCR model preloading failed")
|
||||
return False
|
||||
|
||||
|
||||
def get_model_status() -> Dict[str, Any]:
    """Get current model status information."""
    processor, model = get_trocr_model(handwritten=True)
    is_loaded = processor is not None and model is not None

    status = {
        "status": "available" if is_loaded else "not_installed",
        "is_loaded": is_loaded,
        "model_name": "trocr-base-handwritten" if is_loaded else None,
        "loaded_at": _model_loaded_at.isoformat() if _model_loaded_at else None,
    }

    if is_loaded:
        device = next(model.parameters()).device
        status["device"] = str(device)

    return status


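# --- Illustrative wiring sketch (hypothetical, not part of this module's API) ---
# Shows one way to combine preload_trocr_model() and get_model_status() in a
# FastAPI app, as hinted at in the preload docstring above. The app object,
# route path and handler names are assumptions for illustration only.
def _example_register_trocr_admin_routes(app) -> None:
    @app.on_event("startup")
    async def _warm_trocr() -> None:
        # Load the handwritten model once so the first OCR request is fast.
        preload_trocr_model(handwritten=True)

    @app.get("/ocr/trocr/status")
    async def _trocr_status() -> dict:
        # Reports availability, device and load timestamp from this module.
        return get_model_status()

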
def get_active_backend() -> str:
    """
    Return which TrOCR backend is configured.

    Possible values: "auto", "pytorch", "onnx".
    """
    return _trocr_backend


def _try_onnx_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
) -> Optional[callable]:
    """
    Check whether ONNX inference is available for this request.

    Returns the async run_trocr_onnx function on success, or None if ONNX is
    not available or fails to import. The caller awaits the returned function
    with the original arguments (image_data and split_lines are accepted here
    only for signature parity with the PyTorch path).
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx

        if not is_onnx_available(handwritten=handwritten):
            return None
        # run_trocr_onnx is async; return the function itself so the caller
        # (run_trocr_ocr) can await it with the original arguments.
        return run_trocr_onnx
    except ImportError:
        return None


async def _run_pytorch_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Original PyTorch inference path (extracted for routing).
    """
    processor, model = get_trocr_model(handwritten=handwritten, size=size)

    if processor is None or model is None:
        logger.error("TrOCR PyTorch model not available")
        return None, 0.0

    try:
        import torch
        from PIL import Image
        import numpy as np

        # Load image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        if split_lines:
            lines = _split_into_lines(image)
            if not lines:
                lines = [image]
        else:
            lines = [image]

        all_text = []
        confidences = []

        for line_image in lines:
            pixel_values = processor(images=line_image, return_tensors="pt").pixel_values

            device = next(model.parameters()).device
            pixel_values = pixel_values.to(device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=128)

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            if generated_text.strip():
                all_text.append(generated_text.strip())
                confidences.append(0.85 if len(generated_text) > 3 else 0.5)

        text = "\n".join(all_text)
        confidence = sum(confidences) / len(confidences) if confidences else 0.0

        logger.info(f"TrOCR (PyTorch) extracted {len(text)} characters from {len(lines)} lines")
        return text, confidence

    except Exception as e:
        logger.error(f"TrOCR PyTorch failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, 0.0


async def run_trocr_ocr(
    image_data: bytes,
    handwritten: bool = False,
    split_lines: bool = True,
    size: str = "base",
) -> Tuple[Optional[str], float]:
    """
    Run TrOCR on an image.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    - "onnx": always use ONNX (raises RuntimeError if unavailable)
    - "pytorch": always use PyTorch (original behaviour)
    - "auto": try ONNX first, fall back to PyTorch

    TrOCR is optimized for single-line text recognition, so for full-page
    images we need to either:
    1. Split into lines first (using line detection)
    2. Process the whole image and get partial results

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model (slower but better for handwriting)
        split_lines: Whether to split image into lines first
        size: "base" or "large" (only for handwritten variant)

    Returns:
        Tuple of (extracted_text, confidence)
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
        if onnx_fn is None or not callable(onnx_fn):
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)

    # --- PyTorch-only mode ---
    if backend == "pytorch":
        return await _run_pytorch_ocr(
            image_data, handwritten=handwritten, split_lines=split_lines, size=size,
        )

    # --- Auto mode: try ONNX first, then PyTorch ---
    onnx_fn = _try_onnx_ocr(image_data, handwritten=handwritten, split_lines=split_lines)
    if onnx_fn is not None and callable(onnx_fn):
        try:
            result = await onnx_fn(image_data, handwritten=handwritten, split_lines=split_lines)
            if result[0] is not None:
                return result
            logger.warning("ONNX returned None text, falling back to PyTorch")
        except Exception as e:
            logger.warning(f"ONNX inference failed ({e}), falling back to PyTorch")

    return await _run_pytorch_ocr(
        image_data, handwritten=handwritten, split_lines=split_lines, size=size,
    )


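# --- Illustrative usage sketch (hypothetical, not part of the service API) ---
# A minimal caller for the routing described in run_trocr_ocr's docstring,
# assuming TROCR_BACKEND has been set in the environment before this module
# is imported; the file path is a placeholder.
async def _example_run_trocr_ocr(path: str = "exam_page.png") -> Tuple[Optional[str], float]:
    with open(path, "rb") as f:
        image_data = f.read()
    # "auto" tries ONNX first and silently falls back to the PyTorch path above.
    return await run_trocr_ocr(image_data, handwritten=True, split_lines=True, size="large")

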
def _split_into_lines(image) -> list:
    """
    Split an image into text lines using simple projection-based segmentation.

    This is a basic implementation - for production use, consider using
    a dedicated line detection model.
    """
    import numpy as np
    from PIL import Image

    try:
        # Convert to grayscale
        gray = image.convert('L')
        img_array = np.array(gray)

        # Binarize (simple threshold)
        threshold = 200
        binary = img_array < threshold

        # Horizontal projection (sum of dark pixels per row)
        h_proj = np.sum(binary, axis=1)

        # Find line boundaries (where projection drops below threshold)
        line_threshold = img_array.shape[1] * 0.02  # 2% of width
        in_line = False
        line_start = 0
        lines = []

        for i, val in enumerate(h_proj):
            if val > line_threshold and not in_line:
                # Start of line
                in_line = True
                line_start = i
            elif val <= line_threshold and in_line:
                # End of line
                in_line = False
                # Add padding
                start = max(0, line_start - 5)
                end = min(img_array.shape[0], i + 5)
                if end - start > 10:  # Minimum line height
                    lines.append(image.crop((0, start, image.width, end)))

        # Handle last line if still in_line
        if in_line:
            start = max(0, line_start - 5)
            lines.append(image.crop((0, start, image.width, image.height)))

        logger.info(f"Split image into {len(lines)} lines")
        return lines

    except Exception as e:
        logger.warning(f"Line splitting failed: {e}")
        return []


def _try_onnx_enhanced(
    handwritten: bool = True,
):
    """
    Return the ONNX enhanced coroutine function, or None if unavailable.
    """
    try:
        from .trocr_onnx_service import is_onnx_available, run_trocr_onnx_enhanced

        if not is_onnx_available(handwritten=handwritten):
            return None
        return run_trocr_onnx_enhanced
    except ImportError:
        return None


async def run_trocr_ocr_enhanced(
    image_data: bytes,
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
) -> OCRResult:
    """
    Enhanced TrOCR OCR with caching and detailed results.

    Routes between ONNX and PyTorch backends based on the TROCR_BACKEND
    environment variable (default: "auto").

    Args:
        image_data: Raw image bytes
        handwritten: Use handwritten model
        split_lines: Whether to split image into lines first
        use_cache: Whether to use caching

    Returns:
        OCRResult with detailed information
    """
    backend = _trocr_backend

    # --- ONNX-only mode ---
    if backend == "onnx":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is None:
            raise RuntimeError(
                "ONNX backend requested (TROCR_BACKEND=onnx) but unavailable. "
                "Ensure onnxruntime + optimum are installed and ONNX model files exist."
            )
        return await onnx_fn(
            image_data, handwritten=handwritten,
            split_lines=split_lines, use_cache=use_cache,
        )

    # --- Auto mode: try ONNX first ---
    if backend == "auto":
        onnx_fn = _try_onnx_enhanced(handwritten=handwritten)
        if onnx_fn is not None:
            try:
                result = await onnx_fn(
                    image_data, handwritten=handwritten,
                    split_lines=split_lines, use_cache=use_cache,
                )
                if result.text:
                    return result
                logger.warning("ONNX enhanced returned empty text, falling back to PyTorch")
            except Exception as e:
                logger.warning(f"ONNX enhanced failed ({e}), falling back to PyTorch")

    # --- PyTorch path (backend == "pytorch" or auto fallback) ---
    start_time = time.time()

    # Check cache first
    image_hash = _compute_image_hash(image_data)
    if use_cache:
        cached = _cache_get(image_hash)
        if cached:
            return OCRResult(
                text=cached["text"],
                confidence=cached["confidence"],
                processing_time_ms=0,
                model=cached["model"],
                has_lora_adapter=cached.get("has_lora_adapter", False),
                char_confidences=cached.get("char_confidences", []),
                word_boxes=cached.get("word_boxes", []),
                from_cache=True,
                image_hash=image_hash
            )

    # Run OCR via PyTorch
    text, confidence = await _run_pytorch_ocr(image_data, handwritten=handwritten, split_lines=split_lines)

    processing_time_ms = int((time.time() - start_time) * 1000)

    # Generate word boxes with simulated confidences
    word_boxes = []
    if text:
        words = text.split()
        for idx, word in enumerate(words):
            # Simulate word confidence (slightly varied around overall confidence)
            word_conf = min(1.0, max(0.0, confidence + (hash(word) % 20 - 10) / 100))
            word_boxes.append({
                "text": word,
                "confidence": word_conf,
                "bbox": [0, 0, 0, 0]  # Would need actual bounding box detection
            })

    # Generate character confidences
    char_confidences = []
    if text:
        for char in text:
            # Simulate per-character confidence
            char_conf = min(1.0, max(0.0, confidence + (hash(char) % 15 - 7) / 100))
            char_confidences.append(char_conf)

    result = OCRResult(
        text=text or "",
        confidence=confidence,
        processing_time_ms=processing_time_ms,
        model="trocr-base-handwritten" if handwritten else "trocr-base-printed",
        has_lora_adapter=False,  # Would check actual adapter status
        char_confidences=char_confidences,
        word_boxes=word_boxes,
        from_cache=False,
        image_hash=image_hash
    )

    # Cache result
    if use_cache and text:
        _cache_set(image_hash, {
            "text": result.text,
            "confidence": result.confidence,
            "model": result.model,
            "has_lora_adapter": result.has_lora_adapter,
            "char_confidences": result.char_confidences,
            "word_boxes": result.word_boxes
        })

    return result


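# --- Illustrative cache sketch (hypothetical, not part of the service API) ---
# With use_cache enabled, a second call on identical bytes is served from the
# in-memory cache: from_cache=True and processing_time_ms=0, as constructed in
# run_trocr_ocr_enhanced above.
async def _example_cached_ocr(image_data: bytes) -> None:
    first = await run_trocr_ocr_enhanced(image_data, handwritten=True)
    second = await run_trocr_ocr_enhanced(image_data, handwritten=True)
    logger.info(
        "first from_cache=%s (%d ms), second from_cache=%s",
        first.from_cache, first.processing_time_ms, second.from_cache,
    )

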
async def run_trocr_batch(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True,
    progress_callback: Optional[callable] = None
) -> BatchOCRResult:
    """
    Process multiple images in batch.

    Args:
        images: List of image data bytes
        handwritten: Use handwritten model
        split_lines: Whether to split images into lines
        use_cache: Whether to use caching
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        BatchOCRResult with all results
    """
    start_time = time.time()
    results = []
    cached_count = 0
    error_count = 0

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )
            results.append(result)

            if result.from_cache:
                cached_count += 1

            # Report progress
            if progress_callback:
                progress_callback(idx + 1, len(images))

        except Exception as e:
            logger.error(f"Batch OCR error for image {idx}: {e}")
            error_count += 1
            results.append(OCRResult(
                text=f"Error: {str(e)}",
                confidence=0.0,
                processing_time_ms=0,
                model="error",
                has_lora_adapter=False
            ))

    total_time_ms = int((time.time() - start_time) * 1000)

    return BatchOCRResult(
        results=results,
        total_time_ms=total_time_ms,
        processed_count=len(images),
        cached_count=cached_count,
        error_count=error_count
    )


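# --- Illustrative batch sketch (hypothetical, not part of the service API) ---
# Demonstrates the progress_callback contract documented above; the logging
# format is an assumption for illustration.
async def _example_batch_with_progress(images: List[bytes]) -> BatchOCRResult:
    def _log_progress(current: int, total: int) -> None:
        logger.info("TrOCR batch progress: %d/%d", current, total)

    return await run_trocr_batch(images, handwritten=True, progress_callback=_log_progress)

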
# Generator for SSE streaming during batch processing
async def run_trocr_batch_stream(
    images: List[bytes],
    handwritten: bool = True,
    split_lines: bool = True,
    use_cache: bool = True
):
    """
    Process images and yield progress updates for SSE streaming.

    Yields:
        dict with current progress and result
    """
    start_time = time.time()
    total = len(images)

    for idx, image_data in enumerate(images):
        try:
            result = await run_trocr_ocr_enhanced(
                image_data,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            )

            elapsed_ms = int((time.time() - start_time) * 1000)
            avg_time_per_image = elapsed_ms / (idx + 1)
            estimated_remaining = int(avg_time_per_image * (total - idx - 1))

            yield {
                "type": "progress",
                "current": idx + 1,
                "total": total,
                "progress_percent": ((idx + 1) / total) * 100,
                "elapsed_ms": elapsed_ms,
                "estimated_remaining_ms": estimated_remaining,
                "result": {
                    "text": result.text,
                    "confidence": result.confidence,
                    "processing_time_ms": result.processing_time_ms,
                    "from_cache": result.from_cache
                }
            }

        except Exception as e:
            logger.error(f"Stream OCR error for image {idx}: {e}")
            yield {
                "type": "error",
                "current": idx + 1,
                "total": total,
                "error": str(e)
            }

    total_time_ms = int((time.time() - start_time) * 1000)
    yield {
        "type": "complete",
        "total_time_ms": total_time_ms,
        "processed_count": total
    }


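# --- Illustrative SSE endpoint sketch (hypothetical, not part of this module) ---
# One way to expose run_trocr_batch_stream over Server-Sent Events with FastAPI;
# the app object, route path and upload handling are assumptions for illustration.
def _example_register_batch_stream_route(app) -> None:
    import json
    from fastapi import UploadFile
    from fastapi.responses import StreamingResponse

    @app.post("/ocr/trocr/batch-stream")
    async def _batch_stream(files: list[UploadFile]):
        images = [await f.read() for f in files]

        async def event_source():
            async for update in run_trocr_batch_stream(images, handwritten=True):
                # Each yielded dict is JSON-serializable (progress/error/complete).
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(event_source(), media_type="text/event-stream")

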
# Test function
async def test_trocr_ocr(image_path: str, handwritten: bool = False):
    """Test TrOCR on a local image file."""
    with open(image_path, "rb") as f:
        image_data = f.read()

    text, confidence = await run_trocr_ocr(image_data, handwritten=handwritten)

    print("\n=== TrOCR Test ===")
    print(f"Mode: {'Handwritten' if handwritten else 'Printed'}")
    print(f"Confidence: {confidence:.2f}")
    print(f"Text:\n{text}")

    return text, confidence
# Models, cache, and model loading
from .trocr_models import (
    OCRResult,
    BatchOCRResult,
    _compute_image_hash,
    _cache_get,
    _cache_set,
    get_cache_stats,
    _check_trocr_available,
    get_trocr_model,
    preload_trocr_model,
    get_model_status,
    get_active_backend,
    _split_into_lines,
)

# Core OCR execution
from .trocr_ocr import (
    run_trocr_ocr,
    run_trocr_ocr_enhanced,
    _run_pytorch_ocr,
)

# Batch processing & streaming
from .trocr_batch import (
    run_trocr_batch,
    run_trocr_batch_stream,
    test_trocr_ocr,
)

__all__ = [
    # Dataclasses
    "OCRResult",
    "BatchOCRResult",
    # Cache
    "_compute_image_hash",
    "_cache_get",
    "_cache_set",
    "get_cache_stats",
    # Model loading
    "_check_trocr_available",
    "get_trocr_model",
    "preload_trocr_model",
    "get_model_status",
    "get_active_backend",
    "_split_into_lines",
    # OCR execution
    "run_trocr_ocr",
    "run_trocr_ocr_enhanced",
    "_run_pytorch_ocr",
    # Batch
    "run_trocr_batch",
    "run_trocr_batch_stream",
    "test_trocr_ocr",
]


if __name__ == "__main__":